Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,18 @@ Configure file paths:

```yaml
paths:
import: "."
graph: ${run_dir}/graph
imports: ["."]
graph: ${run_dir}/graph
samples: ${run_dir}/samples
buffer: ${run_dir}/buffer # optional; redirect to a separate volume (e.g. scratch)
```

| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `import` | str | `"."` | Path to import custom modules |
| `graph` | str | `${run_dir}/graph` | Trained models directory |
| `imports` | list[str] | `null` | Directories prepended to `sys.path` in Ray workers so custom modules (e.g. `model.Simulator`) can be imported |
| `graph` | str | `${run_dir}/graph` | Trained model checkpoints directory |
| `samples` | str | `${run_dir}/samples` | Output samples directory |
| `buffer` | str | `${run_dir}/buffer` | Buffer snapshots directory (`snapshots/` is appended); useful for routing large temporary simulation data to a separate scratch volume while keeping `run_dir` on persistent storage |

### `buffer`

Expand All @@ -65,7 +67,7 @@ buffer:
| `simulate_count` | int | `64` | Number of new samples generated per simulation round. For simulators taking >1s per sample, keep this small (4–16) to avoid long delays between buffer updates; for fast simulators, increase to reduce Ray overhead. |
| `simulate_interval` | float | `1` | Seconds between simulation rounds |
| `simulate_when_full` | bool | `true` | If `true`, simulation continues after `max_samples` is reached and old samples are replaced; if `false`, simulation stops once the buffer is full |
| `snapshot_every` | int | `0` | Save every Nth sample to `buffer/snapshots/` for inspection (0 = disabled, 1 = all, 10 = every 10th sample) |
| `snapshot_every` | int | `0` | Save every Nth sample to `{paths.buffer}/snapshots/` for inspection (0 = disabled, 1 = all, 10 = every 10th sample) |

### `graph`

Expand Down
2 changes: 1 addition & 1 deletion examples/01_minimal/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ logging:
# Directory configuration
# -----------------------------------------------------------------------------
paths:
import: "./src" # Local folder(s) with user-defined code (e.g. model.py)
imports: ["./src"] # Local folder(s) with user-defined code (e.g. model.py)
graph: ${run_dir}/graph # Directory for serialized graph and trained networks
samples: ${run_dir}/samples # Directory for generated samples (posterior, prior, etc.)

Expand Down
2 changes: 1 addition & 1 deletion examples/02_bimodal/config_amortized.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:

# Directory configuration
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/02_bimodal/config_regular.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:

# Directory configuration
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/02_bimodal/config_rounds_fill.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:

# Directory configuration
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/02_bimodal/config_rounds_renew.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:

# Directory configuration
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/03_composite/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:

# Directory configuration
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/04_gaussian/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ logging:
# Directory configuration
# -----------------------------------------------------------------------------
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 1 addition & 1 deletion examples/05_linear_regression/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ logging:
# Directory configuration
# -----------------------------------------------------------------------------
paths:
import: "./src"
imports: ["./src"]
graph: ${run_dir}/graph
samples: ${run_dir}/samples

Expand Down
2 changes: 2 additions & 0 deletions falcon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"Node", "Graph", "CompositeNode",
"DeployedGraph",
"get_ray_dataset_manager",
"PathConfig",
"BufferConfig",
"LazyLoader",
"Logger", "get_logger", "set_logger", "log", "debug", "info", "warning", "error",
Expand All @@ -26,6 +27,7 @@
"CompositeNode": ".core.graph",
"DeployedGraph": ".core.deployed_graph",
"get_ray_dataset_manager": ".core.raystore",
"PathConfig": ".core.raystore",
"BufferConfig": ".core.raystore",
"LazyLoader": ".core.utils",
"Logger": ".core.logger",
Expand Down
41 changes: 27 additions & 14 deletions falcon/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ def load_config(config_name: str = "config.yml", run_dir: str = None, overrides:
return cfg


def _resolve_paths(cfg):
    """Return the ``paths`` section of *cfg* as a plain, fully resolved dict.

    ``cfg.paths`` is layered on top of a structured config built from
    ``PathConfig``, so keys declared there are present in the result even
    when the user's config omits them. Interpolations are resolved while
    converting to a plain container.

    Args:
        cfg: Loaded configuration object exposing a ``paths`` attribute.

    Returns:
        The merged ``paths`` configuration as a plain container.
    """
    # Imported lazily, mirroring the function-local imports used elsewhere here.
    from omegaconf import OmegaConf
    from falcon.core.raystore import PathConfig as _PathConfig

    schema = OmegaConf.structured(_PathConfig)
    merged = OmegaConf.merge(schema, cfg.paths)
    return OmegaConf.to_container(merged, resolve=True)


class TeeOutput:
"""Write to both terminal and log file."""
def __init__(self, log_file, terminal):
Expand Down Expand Up @@ -318,9 +328,9 @@ def _build_run_summary(status, output_dir, cfg, deployed_graph, start_time=None,
lines.append("=" * 60)
lines.append(f"falcon launch {status}")
lines.append(f"Output: {output_dir}")
samples_path = cfg.paths.get("samples", f"{cfg.run_dir}/samples")
lines.append(f"Samples: {samples_path}")
graph_path = Path(cfg.paths.graph)
paths = _resolve_paths(cfg)
lines.append(f"Samples: {paths['samples'] or f'{cfg.run_dir}/samples'}")
graph_path = Path(paths['graph'])
lines.append(f"Logs: {graph_path / 'driver' / 'output.log'} (driver)")
try:
node_names = list(cfg.graph.keys())
Expand Down Expand Up @@ -453,7 +463,7 @@ def _save_samples(samples, sample_cfg, sample_type, graph, cfg, info_fn=print):
info_fn(f" {key}: {value.shape}")

# Determine output directory (flat structure)
samples_dir = cfg.paths.get("samples", f"{cfg.run_dir}/samples")
samples_dir = _resolve_paths(cfg)["samples"] or f"{cfg.run_dir}/samples"
output_dir = Path(samples_dir) / sample_type
output_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -505,6 +515,7 @@ def launch_mode(cfg, interactive: bool = False, log_lines: int = 16, auto_sample

# Get output directory from config
output_dir = Path(cfg.run_dir)
path_cfg = _resolve_paths(cfg)

# Generate wandb group if not set - use run-dir folder name
logging_cfg = OmegaConf.to_container(cfg.get("logging", {}), resolve=True)
Expand All @@ -514,7 +525,7 @@ def launch_mode(cfg, interactive: bool = False, log_lines: int = 16, auto_sample
logging_cfg.setdefault("wandb", {})["group"] = output_dir.name

# Ensure local dir is set to graph path
logging_cfg.setdefault("local", {})["dir"] = str(cfg.paths.graph)
logging_cfg.setdefault("local", {})["dir"] = path_cfg["graph"]

# Create driver logger and set as module-level logger
# This enables falcon.info(), falcon.log() etc. for DeployedGraph and other components
Expand Down Expand Up @@ -588,7 +599,7 @@ def launch_mode(cfg, interactive: bool = False, log_lines: int = 16, auto_sample

# Start status polling thread for interactive mode
status_thread = None
graph_path = Path(cfg.paths.graph)
graph_path = Path(path_cfg["graph"])
if display:
# Set log directory so display can read node output.log files
display.set_log_dir(str(graph_path))
Expand Down Expand Up @@ -660,17 +671,18 @@ def stop_check():
# 1) Deploy graph (pass logging config)
deployed_graph = falcon.DeployedGraph(
graph,
model_path=cfg.paths.get("import"),
import_dirs=path_cfg["imports"],
log_config=logging_cfg,
)

# 2) Prepare dataset manager for deployed graph and store initial samples
from omegaconf import OmegaConf as _OmegaConf
from falcon.core.raystore import BufferConfig as _BufferConfig
buffer_cfg = _OmegaConf.merge(_OmegaConf.structured(_BufferConfig), cfg.buffer)
buffer_base = path_cfg["buffer"] or str(Path(cfg.run_dir) / "buffer")
dataset_manager = falcon.get_ray_dataset_manager(
buffer_cfg,
snapshots_path=str(Path(cfg.run_dir) / "buffer" / "snapshots"),
snapshots_path=str(Path(buffer_base) / "snapshots"),
log_config=logging_cfg,
)

Expand Down Expand Up @@ -780,8 +792,9 @@ def sample_mode(cfg, sample_type: str) -> None:
from falcon.core.logger import Logger, set_logger, info

# Setup logging config
path_cfg = _resolve_paths(cfg)
logging_cfg = OmegaConf.to_container(cfg.get("logging", {}), resolve=True)
logging_cfg.setdefault("local", {})["dir"] = str(cfg.paths.graph)
logging_cfg.setdefault("local", {})["dir"] = path_cfg["graph"]

# Create driver logger and set as module-level logger
driver_logger = Logger("driver", logging_cfg, capture_exceptions=True)
Expand Down Expand Up @@ -823,7 +836,7 @@ def sample_mode(cfg, sample_type: str) -> None:
# Deploy graph for sampling
deployed_graph = falcon.DeployedGraph(
graph,
model_path=cfg.paths.get("import"),
import_dirs=path_cfg["imports"],
log_config=logging_cfg,
)

Expand All @@ -832,15 +845,15 @@ def sample_mode(cfg, sample_type: str) -> None:
sample_refs = deployed_graph.sample(num_samples)

elif sample_type == "posterior":
deployed_graph.load(Path(cfg.paths.graph))
deployed_graph.load(Path(path_cfg["graph"]))
sample_refs = deployed_graph.sample_posterior(num_samples, observations)

elif sample_type == "proposal":
deployed_graph.load(Path(cfg.paths.graph))
deployed_graph.load(Path(path_cfg["graph"]))
sample_refs = deployed_graph.sample_proposal(num_samples, observations)

elif sample_type == "ppd":
deployed_graph.load(Path(cfg.paths.graph))
deployed_graph.load(Path(path_cfg["graph"]))
sample_refs = deployed_graph.sample_ppd(num_samples, observations)

else:
Expand Down Expand Up @@ -924,7 +937,7 @@ def parse_args():
elif arg.startswith("--refresh="):
refresh = float(arg.split("=", 1)[1])
i += 1
return mode, None, None, None, None, False, 16, address, refresh
return mode, None, None, None, None, False, 16, True, None, address, refresh

sample_type = None
if mode == "sample":
Expand Down
25 changes: 12 additions & 13 deletions falcon/core/deployed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def _ray_options(actor_config):

@ray.remote
class MultiplexNodeWrapper:
def __init__(self, actor_config, node, graph, num_actors, model_path=None, log_config=None):
def __init__(self, actor_config, node, graph, num_actors, import_dirs=None, log_config=None):
self.num_actors = num_actors
self.wrapped_node_list = [
NodeWrapper.options(**_ray_options(actor_config)).remote(node, graph, model_path, log_config)
NodeWrapper.options(**_ray_options(actor_config)).remote(node, graph, import_dirs, log_config)
for _ in range(self.num_actors)
]

Expand Down Expand Up @@ -120,7 +120,7 @@ def get_output_log_tail(self, num_lines: int = 50) -> list:
# actors — sampling reads best_model which is independent of training state.
@ray.remote
class NodeWrapper:
def __init__(self, node, graph, model_path=None, log_config=None):
def __init__(self, node, graph, import_dirs=None, log_config=None):
# Suppress Ray warning about blocking ray.get in async actor.
# Ray emits this once per actor via a global flag. We set the flag
# to True before any ray.get calls to prevent the warning.
Expand All @@ -134,11 +134,10 @@ def __init__(self, node, graph, model_path=None, log_config=None):
except (ImportError, AttributeError):
pass # Ray internals changed, warning will appear

# Add model_path to sys.path if provided
if model_path:
model_path = Path(model_path).resolve()
if str(model_path) not in sys.path:
sys.path.insert(0, str(model_path))
for p in (import_dirs or []):
resolved = str(Path(p).resolve())
if resolved not in sys.path:
sys.path.insert(0, resolved)

self.node = node
self.name = node.name
Expand Down Expand Up @@ -446,14 +445,14 @@ def shutdown(self):


class DeployedGraph:
def __init__(self, graph, model_path=None, log_config=None):
def __init__(self, graph, import_dirs=None, log_config=None):
"""Initialize a DeployedGraph with the given conceptual graph of nodes.

Note: This class uses falcon.info(), falcon.warning() etc. for logging.
These functions use the module-level logger set by cli.py via set_logger().
"""
self.graph = graph
self.model_path = model_path
self.import_dirs = import_dirs or []
self.log_config = log_config or {}
self.wrapped_nodes_dict = {}
self.monitor_bridge = None
Expand All @@ -464,7 +463,7 @@ def __init__(self, graph, model_path=None, log_config=None):
def _create_monitor_bridge(self):
"""Create the MonitorBridge actor for falcon monitor TUI."""
from falcon.core.monitor_bridge import MonitorBridge
run_dir = str(self.model_path) if self.model_path else "unknown"
run_dir = str(self.import_dirs[0]) if self.import_dirs else "unknown"
try:
# Name the actor so falcon monitor can discover it
self.monitor_bridge = MonitorBridge.options(
Expand Down Expand Up @@ -532,13 +531,13 @@ def deploy_nodes(self):
node,
self.graph,
node.num_actors,
self.model_path,
self.import_dirs,
self.log_config,
)
else:
self.wrapped_nodes_dict[node.name] = NodeWrapper.options(
**_ray_options(node.actor_config)
).remote(node, self.graph, self.model_path, self.log_config)
).remote(node, self.graph, self.import_dirs, self.log_config)

# Wait for all actors to initialize and register with monitor bridge
for name, actor in self.wrapped_nodes_dict.items():
Expand Down
12 changes: 11 additions & 1 deletion falcon/core/raystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from typing import Optional
from typing import List, Optional

import ray
from omegaconf import MISSING
from falcon.core.logger import Logger, set_logger, log, info, warning, error


@dataclass
class PathConfig:
    """Configuration for file-system paths.

    Declares the keys accepted under the ``paths`` section of a config.
    Only ``graph`` has no default; every other field defaults to ``None``.
    """

    # Required: MISSING (from omegaconf) marks the value as mandatory.
    graph: str = MISSING
    # Optional; None when unset (presumably the caller supplies a fallback).
    samples: Optional[str] = None
    # Optional; None when unset (presumably the caller supplies a fallback).
    buffer: Optional[str] = None
    imports: Optional[List[str]] = None  # directories prepended to sys.path in Ray workers


@dataclass
class BufferConfig:
"""Configuration for the rolling sample buffer."""
Expand Down
Loading