Skip to content

Commit 63d0912

Browse files
- Updated the monitoring dashboard with a more modern look and refactored for better maintainbility.
1 parent dd03c61 commit 63d0912

7 files changed

Lines changed: 1074 additions & 308 deletions

File tree

scripts/monitor.py

Lines changed: 45 additions & 294 deletions
Original file line numberDiff line numberDiff line change
@@ -1,303 +1,54 @@
1-
import dash
2-
from dash import dcc, html
3-
from dash.dependencies import Input, Output
4-
import plotly.graph_objs as go
5-
import pandas as pd
1+
#!/usr/bin/env python
2+
"""
3+
Real-time Training Monitor Entrypoint for SpatialTranscriptFormer.
4+
"""
65
import argparse
7-
import os
6+
import logging
7+
from spatial_transcript_former.dashboard.app import init_app, app
8+
9+
10+
def parse_args():
11+
parser = argparse.ArgumentParser(description="Real-time Training Monitor")
12+
parser.add_argument(
13+
"--run-dir",
14+
type=str,
15+
default=None,
16+
help="Path to the experiment run directory containing training_logs.sqlite",
17+
)
18+
parser.add_argument(
19+
"--runs-dir",
20+
type=str,
21+
default=None,
22+
help="Path to a directory containing MULTIPLE experiment run directories for comparison",
23+
)
24+
parser.add_argument(
25+
"--port", type=int, default=8050, help="Port to host the dashboard on"
26+
)
27+
parser.add_argument(
28+
"--interval",
29+
type=int,
30+
default=5000,
31+
help="Log polling interval in milliseconds",
32+
)
33+
args = parser.parse_args()
34+
if not args.run_dir and not args.runs_dir:
35+
parser.error("Must provide either --run-dir or --runs-dir")
36+
return args
837

9-
import flask
10-
import glob
11-
12-
# Set up argument parsing for the target run directory
13-
parser = argparse.ArgumentParser(description="Real-time Training Monitor")
14-
parser.add_argument(
15-
"--run-dir",
16-
type=str,
17-
required=True,
18-
help="Path to the experiment run directory containing training_log.csv",
19-
)
20-
parser.add_argument(
21-
"--port", type=int, default=8050, help="Port to host the dashboard on"
22-
)
23-
parser.add_argument(
24-
"--interval", type=int, default=5000, help="Log polling interval in milliseconds"
25-
)
26-
args = parser.parse_args()
27-
28-
log_path = os.path.join(args.run_dir, "training_log.csv")
29-
30-
# Configure Flask to serve images from the run directory
31-
server = flask.Flask(__name__)
32-
33-
34-
@server.route("/images/<path:filename>")
35-
def serve_image(filename):
36-
return flask.send_from_directory(os.path.abspath(args.run_dir), filename)
37-
38-
39-
# Initialize Dash application
40-
app = dash.Dash(
41-
__name__,
42-
server=server,
43-
title=f"Training Monitor - {os.path.basename(args.run_dir)}",
44-
)
45-
46-
# Define application layout
47-
app.layout = html.Div(
48-
style={"fontFamily": "sans-serif", "padding": "20px"},
49-
children=[
50-
html.H1(f"Real-time Training Monitor: {os.path.basename(args.run_dir)}"),
51-
html.Div(
52-
id="last-updated",
53-
style={"color": "gray", "fontStyle": "italic", "marginBottom": "10px"},
54-
),
55-
html.Button(
56-
"Pause Updates",
57-
id="pause-button",
58-
n_clicks=0,
59-
style={
60-
"marginBottom": "20px",
61-
"padding": "10px",
62-
"fontSize": "16px",
63-
"cursor": "pointer",
64-
"backgroundColor": "#f0f0f0",
65-
"border": "1px solid #ccc",
66-
"borderRadius": "5px",
67-
},
68-
),
69-
html.Div(
70-
[
71-
html.Label(
72-
"Smoothing Window (Epochs):",
73-
style={"fontWeight": "bold", "marginRight": "10px"},
74-
),
75-
dcc.Slider(
76-
id="smoothing-slider",
77-
min=1,
78-
max=50,
79-
step=1,
80-
value=1,
81-
marks={1: "1 (None)", 10: "10", 25: "25", 50: "50"},
82-
),
83-
],
84-
style={
85-
"marginBottom": "20px",
86-
"padding": "10px",
87-
"backgroundColor": "#f9f9f9",
88-
"borderRadius": "5px",
89-
},
90-
),
91-
html.Div(
92-
[
93-
# Row 1: Losses + Correlation
94-
html.Div(
95-
[dcc.Graph(id="live-loss-graph", animate=False)],
96-
style={
97-
"width": "48%",
98-
"display": "inline-block",
99-
"verticalAlign": "top",
100-
},
101-
),
102-
html.Div(
103-
[dcc.Graph(id="live-pcc-graph", animate=False)],
104-
style={
105-
"width": "48%",
106-
"display": "inline-block",
107-
"verticalAlign": "top",
108-
},
109-
),
110-
# Row 2: Variance + Learning Rate
111-
html.Div(
112-
[dcc.Graph(id="live-variance-graph", animate=False)],
113-
style={
114-
"width": "48%",
115-
"display": "inline-block",
116-
"verticalAlign": "top",
117-
},
118-
),
119-
html.Div(
120-
[dcc.Graph(id="live-lr-graph", animate=False)],
121-
style={
122-
"width": "48%",
123-
"display": "inline-block",
124-
"verticalAlign": "top",
125-
},
126-
),
127-
]
128-
),
129-
html.Div(
130-
[
131-
html.H3("Latest Inference Plot (Truth vs Pred)"),
132-
html.Div(id="image-container"),
133-
],
134-
style={
135-
"marginTop": "40px",
136-
"textAlign": "center",
137-
"backgroundColor": "#1a1a2e",
138-
"padding": "20px",
139-
"borderRadius": "10px",
140-
},
141-
),
142-
# Hidden interval component for polling
143-
dcc.Interval(
144-
id="interval-component",
145-
interval=args.interval, # in milliseconds
146-
n_intervals=0,
147-
disabled=False,
148-
),
149-
],
150-
)
151-
152-
153-
@app.callback(
154-
Output("image-container", "children"), [Input("interval-component", "n_intervals")]
155-
)
156-
def update_image(n):
157-
search_pattern = os.path.join(args.run_dir, "*.png")
158-
list_of_files = glob.glob(search_pattern)
159-
if not list_of_files:
160-
return html.P(
161-
"No inference plots found yet. Make sure to run training with --plot-pathways.",
162-
style={"color": "red"},
163-
)
164-
165-
# Get the newest file
166-
latest_file = max(list_of_files, key=os.path.getmtime)
167-
filename = os.path.basename(latest_file)
168-
169-
# Force reload by appending modifying timestamp query
170-
mtime = os.path.getmtime(latest_file)
171-
url = f"/images/{filename}?t={mtime}"
172-
173-
return html.Img(src=url, style={"maxWidth": "100%", "height": "auto"})
174-
175-
176-
def _make_traces(df, cols, smoothing_window):
177-
"""Create Plotly traces for the given columns with optional smoothing."""
178-
traces = []
179-
for col in cols:
180-
if col not in df.columns:
181-
continue
182-
y_data = df[col].dropna()
183-
epochs = df.loc[y_data.index, "epoch"]
184-
if smoothing_window and smoothing_window > 1:
185-
y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean()
186-
traces.append(go.Scatter(x=epochs, y=y_data, mode="lines", name=col))
187-
return traces
188-
189-
190-
@app.callback(
191-
[
192-
Output("live-loss-graph", "figure"),
193-
Output("live-pcc-graph", "figure"),
194-
Output("live-variance-graph", "figure"),
195-
Output("live-lr-graph", "figure"),
196-
Output("last-updated", "children"),
197-
],
198-
[Input("interval-component", "n_intervals"), Input("smoothing-slider", "value")],
199-
)
200-
def update_graphs(n, smoothing_window):
201-
empty = dash.no_update
202-
if not os.path.exists(log_path):
203-
return empty, empty, empty, empty, "Waiting for training_log.csv..."
204-
205-
try:
206-
df = pd.read_csv(log_path)
207-
except Exception as e:
208-
return empty, empty, empty, empty, f"Error reading log: {e}"
209-
210-
if df.empty or "epoch" not in df.columns:
211-
return empty, empty, empty, empty, "Log empty or missing 'epoch'."
212-
213-
margin = dict(l=40, r=40, t=40, b=40)
214-
215-
# Chart 1: Losses (log scale)
216-
loss_cols = [c for c in df.columns if "loss" in c.lower()]
217-
loss_fig = {
218-
"data": _make_traces(df, loss_cols, smoothing_window),
219-
"layout": go.Layout(
220-
title="Loss",
221-
xaxis=dict(title="Epoch"),
222-
yaxis=dict(title="Loss", type="log"),
223-
margin=margin,
224-
),
225-
}
226-
227-
# Chart 2: Correlation (PCC, MAE)
228-
corr_cols = [c for c in ["val_pcc", "val_mae"] if c in df.columns]
229-
pcc_fig = {
230-
"data": _make_traces(df, corr_cols, smoothing_window),
231-
"layout": go.Layout(
232-
title="Correlation & Error",
233-
xaxis=dict(title="Epoch"),
234-
yaxis=dict(title="Score"),
235-
margin=margin,
236-
),
237-
}
238-
239-
# Chart 3: Prediction Variance
240-
var_cols = [c for c in ["pred_variance"] if c in df.columns]
241-
var_fig = {
242-
"data": _make_traces(df, var_cols, smoothing_window),
243-
"layout": go.Layout(
244-
title="Prediction Variance (collapse detector)",
245-
xaxis=dict(title="Epoch"),
246-
yaxis=dict(title="Variance", type="log"),
247-
margin=margin,
248-
),
249-
}
250-
251-
# Chart 4: Learning Rate
252-
lr_cols = [c for c in ["lr"] if c in df.columns]
253-
lr_fig = {
254-
"data": _make_traces(df, lr_cols, smoothing_window),
255-
"layout": go.Layout(
256-
title="Learning Rate Schedule",
257-
xaxis=dict(title="Epoch"),
258-
yaxis=dict(title="LR", type="log"),
259-
margin=margin,
260-
),
261-
}
262-
263-
last_epoch = df["epoch"].iloc[-1]
264-
update_text = f"Last updated: Epoch {last_epoch} (Polled automatically)"
265-
266-
return loss_fig, pcc_fig, var_fig, lr_fig, update_text
26738

39+
if __name__ == "__main__":
40+
args = parse_args()
26841

269-
@app.callback(
270-
[
271-
Output("interval-component", "disabled"),
272-
Output("pause-button", "children"),
273-
Output("pause-button", "style"),
274-
],
275-
[Input("pause-button", "n_clicks")],
276-
)
277-
def toggle_pause(n_clicks):
278-
base_style = {
279-
"marginBottom": "20px",
280-
"padding": "10px",
281-
"fontSize": "16px",
282-
"cursor": "pointer",
283-
"borderRadius": "5px",
284-
"border": "1px solid #ccc",
285-
}
286-
if n_clicks % 2 == 1:
287-
# Paused state
288-
active_style = {
289-
**base_style,
290-
"backgroundColor": "#ffcccc",
291-
"borderColor": "#ff0000",
292-
}
293-
return True, "Resume Updates", active_style
294-
# Active state
295-
active_style = {**base_style, "backgroundColor": "#f0f0f0"}
296-
return False, "Pause Updates", active_style
42+
# Initialize the dash app
43+
init_app(args)
29744

45+
if getattr(args, "runs_dir", None):
46+
print(f"Tracking multiple runs in: {args.runs_dir}")
47+
else:
48+
print(f"Tracking single run at: {args.run_dir}")
29849

299-
if __name__ == "__main__":
300-
print(f"Tracking log at: {log_path}")
30150
print(f"Starting dashboard on http://127.0.0.1:{args.port}/")
51+
52+
# Run the server
30253
# Turn off debug to prevent double-reloading the data parser during polling
30354
app.run(debug=False, port=args.port)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
Dashboard package for SpatialTranscriptFormer model monitoring.
3+
"""
4+
5+
from .app import app, server
6+
7+
__all__ = ["app", "server"]

0 commit comments

Comments
 (0)