Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ paths:
| `graph` | str | `${run_dir}/graph` | Trained models directory |
| `samples` | str | `${run_dir}/samples` | Output samples directory |

### `sample`

Configure sampling for each sample type:

```yaml
sample:
prior:
n: 64
posterior:
n: 100
resimulate: false # if true, also forward-simulate from each posterior sample
exclude_keys: [] # node names to drop from saved NPZ files
add_keys: [] # node names to add beyond the default set
```

| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `n` | int | — | Number of samples to generate |
| `resimulate` | bool | `false` | **(posterior only)** After drawing posterior samples, run each through the forward simulator to produce predicted observations. Saves all node values (latent + simulated). Useful for posterior predictive checks. |
| `exclude_keys` | list/str | `[]` | Node names to exclude from saved NPZ files |
| `add_keys` | list/str | `[]` | Node names to add beyond the default set |

### `buffer`

Configure the rolling sample buffer that feeds training. Falcon continuously simulates new samples in the background while training runs concurrently.
Expand Down
20 changes: 15 additions & 5 deletions falcon/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,15 @@ def _save_samples(samples, sample_cfg, sample_type, graph, cfg, info_fn=print):
# Default: save all .value keys
default_keys = set(node_keys)
elif sample_type == "posterior":
# Default: save only posterior nodes (nodes with evidence)
default_keys = {
f"{k}.value" for k, node in graph.node_dict.items()
if node.evidence and f"{k}.value" in samples
}
if sample_cfg.get("resimulate", False):
# resimulate=True: save all nodes (latent + forward-simulated)
default_keys = set(node_keys)
else:
# Default: save only posterior nodes (nodes with evidence)
default_keys = {
f"{k}.value" for k, node in graph.node_dict.items()
if node.evidence and f"{k}.value" in samples
}
else:
default_keys = set(node_keys)

Expand Down Expand Up @@ -688,6 +692,9 @@ def stop_check():
info(f"Generating {num_posterior_samples} posterior samples...")

sample_refs = deployed_graph.sample_posterior(num_posterior_samples, observations)
if sample_cfg.get("resimulate", False):
info("Resimulating forward model from posterior samples...")
sample_refs = deployed_graph.resimulate_posterior(sample_refs)
samples = deployed_graph._refs_to_arrays(sample_refs)

# Save posterior samples
Expand Down Expand Up @@ -813,6 +820,9 @@ def sample_mode(cfg, sample_type: str) -> None:
elif sample_type == "posterior":
deployed_graph.load(Path(cfg.paths.graph))
sample_refs = deployed_graph.sample_posterior(num_samples, observations)
if sample_cfg.get("resimulate", False):
info("Resimulating forward model from posterior samples...")
sample_refs = deployed_graph.resimulate_posterior(sample_refs)

elif sample_type == "proposal":
deployed_graph.load(Path(cfg.paths.graph))
Expand Down
33 changes: 33 additions & 0 deletions falcon/core/deployed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,39 @@ def _extract_value_refs(self, sample_refs):
result[node_name] = [d[key] for d in sample_refs]
return result

def resimulate_posterior(self, posterior_refs):
"""Forward-simulate non-latent nodes conditioned on posterior samples.

Mirrors the adaptive training loop pattern (see launch() adaptive loop).

Args:
posterior_refs: List[Dict[str, ObjectRef]] from sample_posterior()

Returns:
List[Dict[str, ObjectRef]] with all node values:
latent nodes (from posterior) + non-latent nodes (freshly simulated)
"""
num_samples = len(posterior_refs)

# Extract latent node value refs (nodes that have an estimator)
condition_refs = self._extract_value_refs(posterior_refs)
latent_nodes = {n.name for n in self.graph.node_list
if n.estimator_cls is not None}
condition_refs = {k: v for k, v in condition_refs.items()
if k in latent_nodes}

# Re-simulate non-latent nodes (forward model) conditioned on posterior values
fwd_refs = self._execute_graph(
num_samples, self.graph.forward_order, condition_refs, "sample"
)

# _execute_graph skips pre-conditioned nodes (FIXME in _execute_graph),
# so merge latent node outputs back from the original posterior refs.
for i, post_ref in enumerate(posterior_refs):
fwd_refs[i].update({k: v for k, v in post_ref.items()
if k.split(".")[0] in latent_nodes})
return fwd_refs

def _execute_graph(self, num_samples, node_order, condition_refs, sample_method):
"""Execute graph traversal with specified sampling method.

Expand Down