Add observer nodes for training-time diagnostics (e.g. MIST misspecification tests) #40

@cweniger

Motivation

Model misspecification is a critical blind spot in SBI: if the simulator is wrong, the posterior is wrong. The MIST paper proposes training neural distortion detectors alongside inference to test for model misspecification. More generally, there's a class of diagnostics that need to consume simulation data during training without producing output for other nodes (coverage checks, calibration diagnostics, anomaly detectors, etc.).

The current graph architecture is ~95% of the way there. This issue proposes a small extension to support "observer nodes" — nodes that train on buffered data but don't participate in forward simulation or inference.

Proposed Design

Config: observe key in the graph: section

graph:
  theta:
    evidence: [x]
    simulator: ...
    estimator: ...

  x:
    parents: [theta]
    simulator: model.Simulate
    observed: "./data/obs.npz['x']"

  mist:                              # Observer node
    observe: [x]                     # Declares inputs, marks as observer
    estimator:
      _target_: model.MistDetector   # User-provided estimator
      # ... user config ...
    ray:
      num_gpus: 0

A node with an observe key is an observer: it has no simulator, no parents, and no evidence. It reads data from the buffer and trains alongside the other nodes.
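To make the observer rules concrete, here is a minimal validation sketch that works on the graph config loaded as a plain Python dict. The helper names (is_observer, validate_observer, validate_graph) are hypothetical and not falcon's actual graph.py code. The check that observed nodes exist is an extra assumption, not one of the rules listed in this issue.

```python
from typing import Any, Dict

# Hypothetical helper names; falcon's actual graph.py may differ.
FORBIDDEN_FOR_OBSERVERS = ("simulator", "parents", "evidence")


def is_observer(spec: Dict[str, Any]) -> bool:
    """A node is an observer iff its config declares an `observe` key."""
    return "observe" in spec


def validate_observer(name: str, spec: Dict[str, Any],
                      graph: Dict[str, Dict[str, Any]]) -> None:
    """Raise ValueError if an observer node breaks the proposed rules."""
    bad = [key for key in FORBIDDEN_FOR_OBSERVERS if key in spec]
    if bad:
        raise ValueError(f"observer '{name}' must not define: {', '.join(bad)}")
    if "estimator" not in spec:
        raise ValueError(f"observer '{name}' must define an estimator")
    # Assumed extra check (not listed in the issue): observed nodes must exist.
    missing = [key for key in spec["observe"] if key not in graph]
    if missing:
        raise ValueError(f"observer '{name}' observes unknown node(s): {', '.join(missing)}")


def validate_graph(graph: Dict[str, Dict[str, Any]]) -> None:
    """Validate every observer node; regular nodes are left untouched here."""
    for name, spec in graph.items():
        if is_observer(spec):
            validate_observer(name, spec, graph)


# Plain-dict mirror of the YAML config (simulator/estimator values elided).
graph = {
    "theta": {"evidence": ["x"], "simulator": "...", "estimator": "..."},
    "x": {"parents": ["theta"], "simulator": "model.Simulate"},
    "mist": {"observe": ["x"], "estimator": {"_target_": "model.MistDetector"}},
}
validate_graph(graph)  # passes: "mist" declares only observe + estimator
```

Keying the observer status off the presence of observe keeps the config self-describing: no separate type field is needed, and misconfigured observers fail at graph construction rather than at launch.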

Changes (~50 lines across 3 files)

falcon/core/graph.py:

  • Add observe to Node and _VALID_NODE_KEYS
  • Validation: observer must NOT have simulator/parents/evidence, MUST have estimator
  • Exclude observers from both topological sorts (forward + inference)
  • Add Graph.observer_dict for tracking
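The graph.py changes above amount to filtering observers out before sorting and keeping them in a side table. The sketch below uses the standard-library graphlib.TopologicalSorter (Python 3.9+); sorted_nodes and observer_dict are hypothetical names. Purely for illustration, it assumes the inference order can be derived from evidence edges the same way the forward order is derived from parents; falcon's real ordering logic may differ.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+
from typing import Any, Dict, List

# Hypothetical helpers; names are illustrative, not falcon's API.


def observer_dict(graph: Dict[str, Dict[str, Any]]) -> Dict[str, List[str]]:
    """Map each observer node to the nodes it reads from the buffer."""
    return {name: list(spec["observe"]) for name, spec in graph.items() if "observe" in spec}


def sorted_nodes(graph: Dict[str, Dict[str, Any]], edge_key: str) -> List[str]:
    """Topologically order non-observer nodes along `edge_key` edges.

    Each node is placed after every node listed under spec[edge_key]
    (e.g. "parents" for the forward pass). Observers are dropped entirely.
    """
    observers = observer_dict(graph)
    deps = {
        name: spec.get(edge_key, [])
        for name, spec in graph.items()
        if name not in observers
    }
    return list(TopologicalSorter(deps).static_order())


graph = {
    "theta": {"evidence": ["x"]},
    "x": {"parents": ["theta"]},
    "mist": {"observe": ["x"]},
}
print(sorted_nodes(graph, "parents"))   # ['theta', 'x']
print(sorted_nodes(graph, "evidence"))  # ['x', 'theta']
print(observer_dict(graph))             # {'mist': ['x']}
```

Because observers never appear in either ordering, they cannot affect simulation or inference scheduling; they only need the data that those passes already write to the buffer.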

falcon/core/deployed_graph.py:

  • NodeWrapper.__init__: guard for simulator_cls=None, set condition_keys from observe, pass theta_key=None for observers
  • Add NodeWrapper.evaluate() method (calls estimator.evaluate() if it exists)
  • DeployedGraph._launch: evaluate observers after training completes, log results
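The deployed_graph.py changes can be sketched with plain asyncio standing in for the Ray actors falcon actually uses. NodeWrapperSketch and launch are hypothetical stand-ins, not the real NodeWrapper or DeployedGraph._launch; they only show the three behaviours listed above: the simulator_cls=None guard, observer condition keys with theta_key=None, and evaluating observers (when their estimator defines evaluate()) once training has finished.

```python
import asyncio
import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger("observer_sketch")


class NodeWrapperSketch:
    """Hypothetical stand-in for NodeWrapper; not falcon's real class."""

    def __init__(self, name: str, estimator: Any, simulator_cls: Optional[type] = None,
                 observe: Optional[List[str]] = None, theta_key: Optional[str] = None):
        self.name = name
        self.estimator = estimator
        self.is_observer = observe is not None
        # Guard: observers have no simulator, so simulator_cls may be None.
        self.simulator = simulator_cls() if simulator_cls is not None else None
        # Observers take condition keys from `observe` (left empty for regular nodes here).
        self.condition_keys = list(observe) if observe else []
        # Observers never have a theta_key.
        self.theta_key = None if self.is_observer else theta_key

    async def train(self, buffer: Any) -> None:
        await self.estimator.train(buffer)

    def evaluate(self, observations: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        # Call estimator.evaluate() only if the estimator defines it.
        fn = getattr(self.estimator, "evaluate", None)
        return fn(observations) if callable(fn) else None


async def launch(wrappers: List[NodeWrapperSketch], buffer: Any,
                 observations: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    # Train every node, observers included, concurrently.
    await asyncio.gather(*(w.train(buffer) for w in wrappers))
    # Once training has finished, evaluate observers and log their results.
    results = {}
    for w in wrappers:
        if not w.is_observer:
            continue
        res = w.evaluate(observations)
        if res is not None:
            logger.info("observer %s: %s", w.name, res)
            results[w.name] = res
    return results
```

Using getattr for evaluate() keeps the method optional: observers that only need training-time monitoring simply do not define it and are skipped.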

falcon/estimators/base.py:

  • One-line guard: self.param_dim = simulator_instance.param_dim if simulator_instance is not None else None

No new base classes

Users subclass StepwiseEstimator directly. Besides the usual train_step() (and, optionally, evaluate()), the only method they override is train(), which passes condition-only keys to the existing _train():

class MistDetector(StepwiseEstimator):
    async def train(self, buffer):
        keys = [f"{k}.value" for k in self.condition_keys]
        await self._train(buffer, self.loop_config, keys)

    def train_step(self, batch):
        x = batch[f"{self.condition_keys[0]}.value"]
        # ... generate distortions, compute loss ...
        return {"loss": loss.item()}

    def evaluate(self, observations):
        # ... run best model on x_obs ...
        return {"snr_0": ..., "t_sum": ...}

Everything else (epoch loop, early stopping, pause/resume, monitoring, save/load) comes for free from StepwiseEstimator.
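The division of labour can be illustrated with a self-contained toy. StepwiseEstimatorSketch below is a hypothetical stand-in that mimics only the shape described here (train() delegating to _train(buffer, loop_config, keys), which owns an epoch loop with patience-based early stopping); it is not falcon's StepwiseEstimator, and the max_epochs/patience config keys are assumptions. ToyDetector is a placeholder "detector" that learns a single location parameter, not a MIST model.

```python
import asyncio
from typing import Any, Dict, List


class StepwiseEstimatorSketch:
    """Hypothetical stand-in for StepwiseEstimator, NOT falcon's implementation."""

    def __init__(self, condition_keys: List[str], loop_config: Dict[str, Any]):
        self.condition_keys = condition_keys
        self.loop_config = loop_config
        self.history: List[float] = []

    def train_step(self, batch: Dict[str, Any]) -> Dict[str, float]:
        raise NotImplementedError

    async def _train(self, buffer: Dict[str, Any], loop_config: Dict[str, Any],
                     keys: List[str]) -> None:
        best, bad_epochs = float("inf"), 0
        for _ in range(loop_config["max_epochs"]):
            # Toy batching: take the full buffer column for each requested key.
            batch = {k: buffer[k] for k in keys}
            loss = self.train_step(batch)["loss"]
            self.history.append(loss)
            if loss < best:
                best, bad_epochs = loss, 0
            else:
                bad_epochs += 1
                if bad_epochs >= loop_config["patience"]:
                    break  # early stopping
            await asyncio.sleep(0)  # yield to the event loop between epochs


class ToyDetector(StepwiseEstimatorSketch):
    mu = 0.0  # toy "model": a single learned location parameter

    async def train(self, buffer):
        # Condition-only keys: no theta key is requested from the buffer.
        keys = [f"{k}.value" for k in self.condition_keys]
        await self._train(buffer, self.loop_config, keys)

    def train_step(self, batch):
        x = batch[f"{self.condition_keys[0]}.value"]
        target = sum(x) / len(x)
        self.mu += 0.5 * (target - self.mu)  # move halfway toward the batch mean
        return {"loss": (target - self.mu) ** 2}

    def evaluate(self, observations):
        x_obs = observations[self.condition_keys[0]]
        # Toy statistic: squared gap between observed mean and learned mean.
        return {"t_stat": (sum(x_obs) / len(x_obs) - self.mu) ** 2}


buffer = {"x.value": [1.0, 2.0, 3.0], "theta.value": [0.1, 0.2, 0.3]}
det = ToyDetector(["x"], {"max_epochs": 5, "patience": 2})
asyncio.run(det.train(buffer))
print(det.history)  # [1.0, 0.25, 0.0625, 0.015625, 0.00390625]
```

The subclass never touches the epoch loop or the stopping logic; it only narrows the requested keys to the observed node, so theta.value stays in the buffer but is never read.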

Example

A working MIST example (examples/07_mist/) includes the config, the MistDetector subclass, a distortion class, and mock data.

Key Properties

  • Minimal core changes: 3 files, ~50 lines — no new modules or base classes
  • Falcon provides plumbing, users provide implementations: same pattern as simulators/estimators
  • Generalizes beyond MIST: any train-time diagnostic fits the observer pattern
  • Observers get monitoring, save/load, pause/resume for free via existing StepwiseEstimator infrastructure

References

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
