diff --git a/docs/configuration.md b/docs/configuration.md index 309c85e..834ee13 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -42,6 +42,28 @@ paths: | `graph` | str | `${run_dir}/graph` | Trained models directory | | `samples` | str | `${run_dir}/samples` | Output samples directory | +### `sample` + +Configure sampling for each sample type: + +```yaml +sample: + prior: + n: 64 + posterior: + n: 100 + resimulate: false # if true, also forward-simulate from each posterior sample + exclude_keys: [] # node names to drop from saved NPZ files + add_keys: [] # node names to add beyond the default set +``` + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| `n` | int | — | Number of samples to generate | +| `resimulate` | bool | `false` | **(posterior only)** After drawing posterior samples, run each through the forward simulator to produce predicted observations. Saves all node values (latent + simulated). Useful for posterior predictive checks. | +| `exclude_keys` | list/str | `[]` | Node names to exclude from saved NPZ files | +| `add_keys` | list/str | `[]` | Node names to add beyond the default set | + ### `buffer` Configure the rolling sample buffer that feeds training. Falcon continuously simulates new samples in the background while training runs concurrently. 
diff --git a/falcon/cli.py b/falcon/cli.py index b161301..4ee703b 100644 --- a/falcon/cli.py +++ b/falcon/cli.py @@ -414,11 +414,15 @@ def _save_samples(samples, sample_cfg, sample_type, graph, cfg, info_fn=print): # Default: save all .value keys default_keys = set(node_keys) elif sample_type == "posterior": - # Default: save only posterior nodes (nodes with evidence) - default_keys = { - f"{k}.value" for k, node in graph.node_dict.items() - if node.evidence and f"{k}.value" in samples - } + if sample_cfg.get("resimulate", False): + # resimulate=True: save all nodes (latent + forward-simulated) + default_keys = set(node_keys) + else: + # Default: save only posterior nodes (nodes with evidence) + default_keys = { + f"{k}.value" for k, node in graph.node_dict.items() + if node.evidence and f"{k}.value" in samples + } else: default_keys = set(node_keys) @@ -688,6 +692,9 @@ def stop_check(): info(f"Generating {num_posterior_samples} posterior samples...") sample_refs = deployed_graph.sample_posterior(num_posterior_samples, observations) + if sample_cfg.get("resimulate", False): + info("Resimulating forward model from posterior samples...") + sample_refs = deployed_graph.resimulate_posterior(sample_refs) samples = deployed_graph._refs_to_arrays(sample_refs) # Save posterior samples @@ -813,6 +820,9 @@ def sample_mode(cfg, sample_type: str) -> None: elif sample_type == "posterior": deployed_graph.load(Path(cfg.paths.graph)) sample_refs = deployed_graph.sample_posterior(num_samples, observations) + if sample_cfg.get("resimulate", False): + info("Resimulating forward model from posterior samples...") + sample_refs = deployed_graph.resimulate_posterior(sample_refs) elif sample_type == "proposal": deployed_graph.load(Path(cfg.paths.graph)) diff --git a/falcon/core/deployed_graph.py b/falcon/core/deployed_graph.py index d4d7726..6afdc12 100644 --- a/falcon/core/deployed_graph.py +++ b/falcon/core/deployed_graph.py @@ -621,6 +621,39 @@ def _extract_value_refs(self, 
def resimulate_posterior(self, posterior_refs):
    """Forward-simulate non-latent nodes conditioned on posterior samples.

    Mirrors the adaptive training loop pattern (see launch() adaptive loop).

    Latent nodes are the graph nodes whose ``estimator_cls`` is not None.
    Their values are taken from ``posterior_refs`` and passed to
    ``_execute_graph`` as conditioning; the graph is then executed in
    ``self.graph.forward_order`` with the "sample" method.

    Args:
        posterior_refs: List[Dict[str, ObjectRef]] from sample_posterior().
            Keys are expected to look like "<node_name>.<field>"; the part
            before the first "." is matched against latent node names.

    Returns:
        List[Dict[str, ObjectRef]] with all node values:
        latent nodes (from posterior) + non-latent nodes (freshly simulated).
        An empty list is returned when ``posterior_refs`` is empty.
    """
    # Nothing to condition on: skip graph execution entirely instead of
    # asking the helpers to handle a zero-sample request.
    if not posterior_refs:
        return []

    # Latent nodes are the ones that carry an estimator.
    latent_nodes = {
        node.name
        for node in self.graph.node_list
        if node.estimator_cls is not None
    }

    # Only latent node values are used as conditioning; every other value
    # present in the posterior refs is dropped here and re-simulated below.
    condition_refs = {
        name: refs
        for name, refs in self._extract_value_refs(posterior_refs).items()
        if name in latent_nodes
    }

    # Re-simulate non-latent nodes (forward model) conditioned on posterior values
    fwd_refs = self._execute_graph(
        len(posterior_refs), self.graph.forward_order, condition_refs, "sample"
    )

    # _execute_graph skips pre-conditioned nodes (FIXME in _execute_graph),
    # so merge latent node outputs back from the original posterior refs.
    # Indexing (rather than zip) keeps a length mismatch loud.
    for i, post_ref in enumerate(posterior_refs):
        fwd_refs[i].update({
            key: ref
            for key, ref in post_ref.items()
            if key.split(".")[0] in latent_nodes
        })
    return fwd_refs