Refactor graph sampling: DAG class, _run primitive, _resimulate composition #45

@cweniger

Context

The proposal re-simulation path in deployed_graph.py had a data corruption bug (fixed in 020fa4f): observation-based intermediate values for deterministic nodes (e.g. tokens) overwrote the freshly re-simulated ones. The fix works, but the underlying design is hacky: it relies on manual filtering by estimator_cls is not None, string-splitting of keys, and post-hoc stitching of results.

This issue tracks a refactoring to clean up the abstractions.

Design decisions

  • Use graph structure (n.evidence) to identify inference nodes, not an implementation detail (estimator_cls)
  • _execute_graph renamed to _run — a simple primitive (num_samples, order, conditions, method)
  • _resimulate composes two explicit _run calls: backward (proposal) → forward (re-simulation)
  • No keep parameter on _run — the caller (_resimulate) handles result composition explicitly
  • Extract DAG class from Graph for deps + order + subgraph(), used as graph.forward and graph.backward

Changes

1. falcon/core/graph.py — Extract DAG class, refactor Graph

DAG class: General-purpose directed acyclic graph owning one dependency structure.

  • __init__(names, deps) — stores deps dict, computes topo-sorted order
  • subgraph(targets, conditions) — returns minimal topo-sorted order to compute targets given conditions
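
The two methods above could be sketched roughly as follows. This is a minimal illustration, not the actual falcon implementation: it assumes deps maps each node name to the names it depends on, uses the standard-library graphlib for the topological sort, and uses made-up node names (theta, tokens, x).

```python
from graphlib import TopologicalSorter


class DAG:
    """Illustrative DAG: one dependency structure plus its topological order."""

    def __init__(self, names, deps):
        # Assumption: deps maps a node name to the names it depends on,
        # and every dependency also appears in names.
        self.deps = {name: list(deps.get(name, [])) for name in names}
        # static_order() yields dependencies before dependents
        # and raises graphlib.CycleError if the graph has a cycle.
        self.order = list(TopologicalSorter(self.deps).static_order())

    def subgraph(self, targets, conditions=()):
        """Return the minimal topo-sorted node list needed to compute targets."""
        conditions = set(conditions)
        needed = set()
        stack = [t for t in targets if t not in conditions]
        while stack:
            name = stack.pop()
            if name in needed:
                continue
            needed.add(name)
            # Conditioned nodes are already given, so traversal stops at them.
            stack.extend(p for p in self.deps[name] if p not in conditions)
        return [name for name in self.order if name in needed]


dag = DAG(["theta", "tokens", "x"], {"tokens": ["theta"], "x": ["tokens"]})
print(dag.order)
print(dag.subgraph(["x"], conditions=["tokens"]))
```

With tokens supplied as a condition, only x has to be computed; without conditions, the full chain is returned.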

Graph refactored to use two DAG instances:

  • self.forward = DAG(names, {n.name: n.parents ...}) — simulation direction
  • self.backward = DAG(backward_names, backward_deps) — inference direction
  • Add inference_nodes property: {n.name for n in self.node_list if n.evidence}
  • Backward construction logic moved to _build_backward() method (same algorithm)
  • Compatibility aliases: forward_order, backward_order, forward_deps as properties
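
A rough sketch of how the inference_nodes property and the compatibility aliases might look. The Node dataclass and the stub DAG here are hypothetical stand-ins for illustration only, and the backward DAG is omitted because its construction algorithm is not described in this issue.

```python
from dataclasses import dataclass, field
from graphlib import TopologicalSorter


@dataclass
class Node:
    # Hypothetical stand-in for the real node class; only these fields are used.
    name: str
    parents: list = field(default_factory=list)
    evidence: list = field(default_factory=list)


class DAG:
    # Minimal stub: deps plus topo-sorted order (subgraph() omitted).
    def __init__(self, names, deps):
        self.deps = {name: list(deps.get(name, [])) for name in names}
        self.order = list(TopologicalSorter(self.deps).static_order())


class Graph:
    def __init__(self, node_list):
        self.node_list = node_list
        names = [n.name for n in node_list]
        self.forward = DAG(names, {n.name: n.parents for n in node_list})
        # self.backward would come from _build_backward(); not sketched here.

    @property
    def inference_nodes(self):
        # Inference nodes are identified by graph structure (non-empty evidence).
        return {n.name for n in self.node_list if n.evidence}

    @property
    def forward_order(self):
        # Compatibility alias for code that predates the DAG class.
        return self.forward.order

    @property
    def forward_deps(self):
        return self.forward.deps


graph = Graph([
    Node("theta", evidence=["x"]),
    Node("tokens", parents=["theta"]),
    Node("x", parents=["tokens"]),
])
print(graph.inference_nodes)
print(graph.forward_order)
```

Keeping the aliases as read-only properties means existing callers keep working while new code goes through graph.forward directly.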

2. falcon/core/deployed_graph.py: _run + _resimulate

Rename _execute_graph → _run: same body, cleaner name. Conditioned nodes are skipped (they do not appear in the output).

Add _resimulate(num_samples, observations):

def _resimulate(self, num_samples, observations):
    inference_nodes = self.graph.inference_nodes
    obs_refs = self._arrays_to_condition_refs(observations, num_samples)

    # Backward: run inference graph
    backward_refs = self._run(
        num_samples, self.graph.backward.order, obs_refs, "sample_proposal"
    )

    # Extract inference node values as conditions for forward pass
    inference_refs = {
        name: [d[f'{name}.value'] for d in backward_refs]
        for name in inference_nodes
    }

    # Forward: re-simulate everything conditioned on inference values
    sample_refs = self._run(
        num_samples, self.graph.forward.order, inference_refs, "sample"
    )

    # Add inference node results (value + log_prob) from backward pass
    for i in range(len(sample_refs)):
        for name in inference_nodes:
            for suffix in ('.value', '.log_prob'):
                key = f'{name}{suffix}'
                if key in backward_refs[i]:
                    sample_refs[i][key] = backward_refs[i][key]

    return sample_refs
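
To make the final composition step concrete, here is a toy version that uses plain dicts in place of Ray object refs. The compose helper and all values are made up for illustration; the point is that only the inference nodes' value and log_prob entries are copied from the backward pass, so a stale intermediate such as tokens.value cannot overwrite the re-simulated one.

```python
def compose(sample_refs, backward_refs, inference_nodes):
    """Copy value/log_prob of inference nodes from backward into forward results."""
    for fwd, bwd in zip(sample_refs, backward_refs):
        for name in inference_nodes:
            for suffix in (".value", ".log_prob"):
                key = f"{name}{suffix}"
                if key in bwd:
                    fwd[key] = bwd[key]
    return sample_refs


# Hypothetical per-sample results; real entries would be Ray object refs.
backward = [{"theta.value": 0.3, "theta.log_prob": -1.2, "tokens.value": "stale"}]
forward = [{"tokens.value": "fresh", "x.value": 1.7}]

merged = compose(forward, backward, {"theta"})
print(merged[0])
```

Because the forward pass is conditioned on the inference nodes, those nodes are skipped by _run and absent from its output, which is why the caller adds them back explicitly.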

Simplify _launch re-simulation block (lines ~786-807) to:

sample_refs = self._resimulate(num_new_samples, observations)
ray.get(dataset_manager.append_refs.remote(sample_refs))

Remove _extract_value_refs — no longer needed.

Update callers of _execute_graph → _run in sample(), sample_posterior(), sample_proposal().

Verification

  1. falcon launch --config-name config2.yml in example 06 — training converges after proposal re-simulation
  2. falcon launch in example 06 with config.yml — no regression
  3. falcon sample prior / falcon sample posterior — correct keys in saved NPZ

🤖 Generated with Claude Code
