Motivation
Model misspecification is a critical blind spot in SBI: if the simulator is wrong, the posterior is wrong. The MIST paper proposes training neural distortion detectors alongside inference to test for model misspecification. More generally, there's a class of diagnostics that need to consume simulation data during training without producing output for other nodes (coverage checks, calibration diagnostics, anomaly detectors, etc.).
The current graph architecture is ~95% of the way there. This issue proposes a small extension to support "observer nodes" — nodes that train on buffered data but don't participate in forward simulation or inference.
Proposed Design
Config: observe key in the graph: section
graph:
  theta:
    evidence: [x]
    simulator: ...
    estimator: ...
  x:
    parents: [theta]
    simulator: model.Simulate
    observed: "./data/obs.npz['x']"
  mist:                               # Observer node
    observe: [x]                      # Declares inputs, marks as observer
    estimator:
      _target_: model.MistDetector    # User-provided estimator
      # ... user config ...
    ray:
      num_gpus: 0
A node with observe is an observer: no simulator, no parents, no evidence. It reads data from the buffer and trains alongside other nodes.
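The observer rules can be sketched as a small check over the parsed config. This is a minimal illustration, assuming the graph: section has already been loaded into plain dictionaries; the names is_observer, validate_observer, and FORBIDDEN_FOR_OBSERVERS are hypothetical and not part of falcon.

```python
# Hypothetical sketch of the observer rules; not falcon's actual validation code.
FORBIDDEN_FOR_OBSERVERS = ("simulator", "parents", "evidence")


def is_observer(node_cfg: dict) -> bool:
    """A node counts as an observer when it declares an `observe` key."""
    return "observe" in node_cfg


def validate_observer(name: str, node_cfg: dict) -> None:
    """Raise ValueError if an observer node breaks the rules stated above."""
    present = [key for key in FORBIDDEN_FOR_OBSERVERS if key in node_cfg]
    if present:
        raise ValueError(f"Observer node '{name}' must not define: {', '.join(present)}")
    if "estimator" not in node_cfg:
        raise ValueError(f"Observer node '{name}' must define an estimator")


# Dictionary form of the example config (elided values kept as placeholders).
graph_cfg = {
    "theta": {"evidence": ["x"], "simulator": "...", "estimator": "..."},
    "x": {"parents": ["theta"], "simulator": "model.Simulate"},
    "mist": {"observe": ["x"], "estimator": {"_target_": "model.MistDetector"}},
}

observers = {name: cfg for name, cfg in graph_cfg.items() if is_observer(cfg)}
for name, cfg in observers.items():
    validate_observer(name, cfg)

print(sorted(observers))
```

Failing fast at config-load time keeps a malformed observer from reaching the training loop.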
Changes (~50 lines across 3 files)
falcon/core/graph.py:
- Add observe to Node and _VALID_NODE_KEYS
- Validation: observer must NOT have simulator/parents/evidence, MUST have estimator
- Exclude observers from both topological sorts (forward + inference)
- Add Graph.observer_dict for tracking
falcon/core/deployed_graph.py:
- NodeWrapper.__init__: guard for simulator_cls=None, set condition_keys from observe, pass theta_key=None for observers
- Add NodeWrapper.evaluate() method (calls estimator.evaluate() if it exists)
- DeployedGraph._launch: evaluate observers after training completes, log results
falcon/estimators/base.py:
- One-line guard: self.param_dim = simulator_instance.param_dim if simulator_instance is not None else None
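Two of the changes listed above can be sketched in isolation: leaving observers out of a topological sort, and calling evaluate() only when the estimator defines it. This is a rough illustration built on the standard library's graphlib. The function names are hypothetical, only the forward sort along parents edges is shown, and falcon's own sorting code may look quite different.

```python
from graphlib import TopologicalSorter


def forward_order(graph_cfg: dict) -> list[str]:
    """Topologically sort non-observer nodes along their `parents` edges."""
    regular = {name: cfg for name, cfg in graph_cfg.items() if "observe" not in cfg}
    dependencies = {name: set(cfg.get("parents", [])) for name, cfg in regular.items()}
    return list(TopologicalSorter(dependencies).static_order())


def evaluate_if_supported(estimator, observations):
    """Call estimator.evaluate() only when the estimator defines it."""
    evaluate = getattr(estimator, "evaluate", None)
    return evaluate(observations) if callable(evaluate) else None


graph_cfg = {
    "theta": {"evidence": ["x"]},
    "x": {"parents": ["theta"]},
    "mist": {"observe": ["x"]},  # observer: skipped by the sort
}

print(forward_order(graph_cfg))
```

Because the observer never enters the dependency map, it cannot change the ordering that forward simulation relies on.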
No new base classes
Users subclass StepwiseEstimator directly. Apart from the detector-specific train_step() and an optional evaluate(), the only thing they override is train(), which passes condition-only keys to the existing _train():
class MistDetector(StepwiseEstimator):
    async def train(self, buffer):
        keys = [f"{k}.value" for k in self.condition_keys]
        await self._train(buffer, self.loop_config, keys)

    def train_step(self, batch):
        x = batch[f"{self.condition_keys[0]}.value"]
        # ... generate distortions, compute loss ...
        return {"loss": loss.item()}

    def evaluate(self, observations):
        # ... run best model on x_obs ...
        return {"snr_0": ..., "t_sum": ...}
Everything else (epoch loop, early stopping, pause/resume, monitoring, save/load) comes for free from StepwiseEstimator.
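To make the key selection concrete, here is a toy, self-contained run of the same pattern. StubStepwiseEstimator is a hypothetical stand-in written only for this illustration. The real StepwiseEstimator, its _train() internals, and the buffer interface are not shown in this issue, so the batch handling below is an assumption.

```python
import asyncio


class StubStepwiseEstimator:
    """Hypothetical stand-in for StepwiseEstimator; the real class does far more."""

    def __init__(self, condition_keys, loop_config=None):
        self.condition_keys = condition_keys
        self.loop_config = loop_config or {"num_epochs": 2}
        self.history = []

    async def _train(self, buffer, loop_config, keys):
        # Assumed behaviour: pass train_step() batches restricted to `keys`.
        for _ in range(loop_config["num_epochs"]):
            for sample in buffer:
                batch = {key: sample[key] for key in keys}
                self.history.append(self.train_step(batch))


class MeanTracker(StubStepwiseEstimator):
    """Toy observer whose 'loss' is simply the mean of x in each batch."""

    async def train(self, buffer):
        # Condition-only keys: the observer never asks for theta.
        keys = [f"{k}.value" for k in self.condition_keys]
        await self._train(buffer, self.loop_config, keys)

    def train_step(self, batch):
        x = batch[f"{self.condition_keys[0]}.value"]
        return {"loss": sum(x) / len(x)}


# A plain list stands in for the simulation buffer (assumption).
buffer = [
    {"theta.value": [0.1], "x.value": [1.0, 3.0]},
    {"theta.value": [0.2], "x.value": [2.0, 4.0]},
]

observer = MeanTracker(condition_keys=["x"])
asyncio.run(observer.train(buffer))
print(len(observer.history), observer.history[0])
```

The observer sees only x.value entries, which is the point of passing condition-only keys: theta stays in the buffer but is never requested.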
Example
A working MIST example (examples/07_mist/) with config, MistDetector subclass, distortion class, and mock data.
Key Properties
- Minimal core changes: 3 files, ~50 lines — no new modules or base classes
- Falcon provides plumbing, users provide implementations: same pattern as simulators/estimators
- Generalizes beyond MIST: any train-time diagnostic fits the observer pattern
- Observers get monitoring, save/load, pause/resume for free via existing StepwiseEstimator infrastructure
References