|
1 | | -import dash |
2 | | -from dash import dcc, html |
3 | | -from dash.dependencies import Input, Output |
4 | | -import plotly.graph_objs as go |
5 | | -import pandas as pd |
| 1 | +#!/usr/bin/env python |
| 2 | +""" |
| 3 | +Real-time Training Monitor Entrypoint for SpatialTranscriptFormer. |
| 4 | +""" |
6 | 5 | import argparse |
7 | | -import os |
| 6 | +import logging |
| 7 | +from spatial_transcript_former.dashboard.app import init_app, app |
| 8 | + |
| 9 | + |
| 10 | +def parse_args(): |
| 11 | + parser = argparse.ArgumentParser(description="Real-time Training Monitor") |
| 12 | + parser.add_argument( |
| 13 | + "--run-dir", |
| 14 | + type=str, |
| 15 | + default=None, |
| 16 | + help="Path to the experiment run directory containing training_logs.sqlite", |
| 17 | + ) |
| 18 | + parser.add_argument( |
| 19 | + "--runs-dir", |
| 20 | + type=str, |
| 21 | + default=None, |
| 22 | + help="Path to a directory containing MULTIPLE experiment run directories for comparison", |
| 23 | + ) |
| 24 | + parser.add_argument( |
| 25 | + "--port", type=int, default=8050, help="Port to host the dashboard on" |
| 26 | + ) |
| 27 | + parser.add_argument( |
| 28 | + "--interval", |
| 29 | + type=int, |
| 30 | + default=5000, |
| 31 | + help="Log polling interval in milliseconds", |
| 32 | + ) |
| 33 | + args = parser.parse_args() |
| 34 | + if not args.run_dir and not args.runs_dir: |
| 35 | + parser.error("Must provide either --run-dir or --runs-dir") |
| 36 | + return args |
8 | 37 |
|
9 | | -import flask |
10 | | -import glob |
11 | | - |
12 | | -# Set up argument parsing for the target run directory |
13 | | -parser = argparse.ArgumentParser(description="Real-time Training Monitor") |
14 | | -parser.add_argument( |
15 | | - "--run-dir", |
16 | | - type=str, |
17 | | - required=True, |
18 | | - help="Path to the experiment run directory containing training_log.csv", |
19 | | -) |
20 | | -parser.add_argument( |
21 | | - "--port", type=int, default=8050, help="Port to host the dashboard on" |
22 | | -) |
23 | | -parser.add_argument( |
24 | | - "--interval", type=int, default=5000, help="Log polling interval in milliseconds" |
25 | | -) |
26 | | -args = parser.parse_args() |
27 | | - |
28 | | -log_path = os.path.join(args.run_dir, "training_log.csv") |
29 | | - |
30 | | -# Configure Flask to serve images from the run directory |
31 | | -server = flask.Flask(__name__) |
32 | | - |
33 | | - |
34 | | -@server.route("/images/<path:filename>") |
35 | | -def serve_image(filename): |
36 | | - return flask.send_from_directory(os.path.abspath(args.run_dir), filename) |
37 | | - |
38 | | - |
39 | | -# Initialize Dash application |
40 | | -app = dash.Dash( |
41 | | - __name__, |
42 | | - server=server, |
43 | | - title=f"Training Monitor - {os.path.basename(args.run_dir)}", |
44 | | -) |
45 | | - |
46 | | -# Define application layout |
47 | | -app.layout = html.Div( |
48 | | - style={"fontFamily": "sans-serif", "padding": "20px"}, |
49 | | - children=[ |
50 | | - html.H1(f"Real-time Training Monitor: {os.path.basename(args.run_dir)}"), |
51 | | - html.Div( |
52 | | - id="last-updated", |
53 | | - style={"color": "gray", "fontStyle": "italic", "marginBottom": "10px"}, |
54 | | - ), |
55 | | - html.Button( |
56 | | - "Pause Updates", |
57 | | - id="pause-button", |
58 | | - n_clicks=0, |
59 | | - style={ |
60 | | - "marginBottom": "20px", |
61 | | - "padding": "10px", |
62 | | - "fontSize": "16px", |
63 | | - "cursor": "pointer", |
64 | | - "backgroundColor": "#f0f0f0", |
65 | | - "border": "1px solid #ccc", |
66 | | - "borderRadius": "5px", |
67 | | - }, |
68 | | - ), |
69 | | - html.Div( |
70 | | - [ |
71 | | - html.Label( |
72 | | - "Smoothing Window (Epochs):", |
73 | | - style={"fontWeight": "bold", "marginRight": "10px"}, |
74 | | - ), |
75 | | - dcc.Slider( |
76 | | - id="smoothing-slider", |
77 | | - min=1, |
78 | | - max=50, |
79 | | - step=1, |
80 | | - value=1, |
81 | | - marks={1: "1 (None)", 10: "10", 25: "25", 50: "50"}, |
82 | | - ), |
83 | | - ], |
84 | | - style={ |
85 | | - "marginBottom": "20px", |
86 | | - "padding": "10px", |
87 | | - "backgroundColor": "#f9f9f9", |
88 | | - "borderRadius": "5px", |
89 | | - }, |
90 | | - ), |
91 | | - html.Div( |
92 | | - [ |
93 | | - # Row 1: Losses + Correlation |
94 | | - html.Div( |
95 | | - [dcc.Graph(id="live-loss-graph", animate=False)], |
96 | | - style={ |
97 | | - "width": "48%", |
98 | | - "display": "inline-block", |
99 | | - "verticalAlign": "top", |
100 | | - }, |
101 | | - ), |
102 | | - html.Div( |
103 | | - [dcc.Graph(id="live-pcc-graph", animate=False)], |
104 | | - style={ |
105 | | - "width": "48%", |
106 | | - "display": "inline-block", |
107 | | - "verticalAlign": "top", |
108 | | - }, |
109 | | - ), |
110 | | - # Row 2: Variance + Learning Rate |
111 | | - html.Div( |
112 | | - [dcc.Graph(id="live-variance-graph", animate=False)], |
113 | | - style={ |
114 | | - "width": "48%", |
115 | | - "display": "inline-block", |
116 | | - "verticalAlign": "top", |
117 | | - }, |
118 | | - ), |
119 | | - html.Div( |
120 | | - [dcc.Graph(id="live-lr-graph", animate=False)], |
121 | | - style={ |
122 | | - "width": "48%", |
123 | | - "display": "inline-block", |
124 | | - "verticalAlign": "top", |
125 | | - }, |
126 | | - ), |
127 | | - ] |
128 | | - ), |
129 | | - html.Div( |
130 | | - [ |
131 | | - html.H3("Latest Inference Plot (Truth vs Pred)"), |
132 | | - html.Div(id="image-container"), |
133 | | - ], |
134 | | - style={ |
135 | | - "marginTop": "40px", |
136 | | - "textAlign": "center", |
137 | | - "backgroundColor": "#1a1a2e", |
138 | | - "padding": "20px", |
139 | | - "borderRadius": "10px", |
140 | | - }, |
141 | | - ), |
142 | | - # Hidden interval component for polling |
143 | | - dcc.Interval( |
144 | | - id="interval-component", |
145 | | - interval=args.interval, # in milliseconds |
146 | | - n_intervals=0, |
147 | | - disabled=False, |
148 | | - ), |
149 | | - ], |
150 | | -) |
151 | | - |
152 | | - |
153 | | -@app.callback( |
154 | | - Output("image-container", "children"), [Input("interval-component", "n_intervals")] |
155 | | -) |
156 | | -def update_image(n): |
157 | | - search_pattern = os.path.join(args.run_dir, "*.png") |
158 | | - list_of_files = glob.glob(search_pattern) |
159 | | - if not list_of_files: |
160 | | - return html.P( |
161 | | - "No inference plots found yet. Make sure to run training with --plot-pathways.", |
162 | | - style={"color": "red"}, |
163 | | - ) |
164 | | - |
165 | | - # Get the newest file |
166 | | - latest_file = max(list_of_files, key=os.path.getmtime) |
167 | | - filename = os.path.basename(latest_file) |
168 | | - |
169 | | - # Force reload by appending modifying timestamp query |
170 | | - mtime = os.path.getmtime(latest_file) |
171 | | - url = f"/images/{filename}?t={mtime}" |
172 | | - |
173 | | - return html.Img(src=url, style={"maxWidth": "100%", "height": "auto"}) |
174 | | - |
175 | | - |
176 | | -def _make_traces(df, cols, smoothing_window): |
177 | | - """Create Plotly traces for the given columns with optional smoothing.""" |
178 | | - traces = [] |
179 | | - for col in cols: |
180 | | - if col not in df.columns: |
181 | | - continue |
182 | | - y_data = df[col].dropna() |
183 | | - epochs = df.loc[y_data.index, "epoch"] |
184 | | - if smoothing_window and smoothing_window > 1: |
185 | | - y_data = y_data.rolling(window=smoothing_window, min_periods=1).mean() |
186 | | - traces.append(go.Scatter(x=epochs, y=y_data, mode="lines", name=col)) |
187 | | - return traces |
188 | | - |
189 | | - |
190 | | -@app.callback( |
191 | | - [ |
192 | | - Output("live-loss-graph", "figure"), |
193 | | - Output("live-pcc-graph", "figure"), |
194 | | - Output("live-variance-graph", "figure"), |
195 | | - Output("live-lr-graph", "figure"), |
196 | | - Output("last-updated", "children"), |
197 | | - ], |
198 | | - [Input("interval-component", "n_intervals"), Input("smoothing-slider", "value")], |
199 | | -) |
200 | | -def update_graphs(n, smoothing_window): |
201 | | - empty = dash.no_update |
202 | | - if not os.path.exists(log_path): |
203 | | - return empty, empty, empty, empty, "Waiting for training_log.csv..." |
204 | | - |
205 | | - try: |
206 | | - df = pd.read_csv(log_path) |
207 | | - except Exception as e: |
208 | | - return empty, empty, empty, empty, f"Error reading log: {e}" |
209 | | - |
210 | | - if df.empty or "epoch" not in df.columns: |
211 | | - return empty, empty, empty, empty, "Log empty or missing 'epoch'." |
212 | | - |
213 | | - margin = dict(l=40, r=40, t=40, b=40) |
214 | | - |
215 | | - # Chart 1: Losses (log scale) |
216 | | - loss_cols = [c for c in df.columns if "loss" in c.lower()] |
217 | | - loss_fig = { |
218 | | - "data": _make_traces(df, loss_cols, smoothing_window), |
219 | | - "layout": go.Layout( |
220 | | - title="Loss", |
221 | | - xaxis=dict(title="Epoch"), |
222 | | - yaxis=dict(title="Loss", type="log"), |
223 | | - margin=margin, |
224 | | - ), |
225 | | - } |
226 | | - |
227 | | - # Chart 2: Correlation (PCC, MAE) |
228 | | - corr_cols = [c for c in ["val_pcc", "val_mae"] if c in df.columns] |
229 | | - pcc_fig = { |
230 | | - "data": _make_traces(df, corr_cols, smoothing_window), |
231 | | - "layout": go.Layout( |
232 | | - title="Correlation & Error", |
233 | | - xaxis=dict(title="Epoch"), |
234 | | - yaxis=dict(title="Score"), |
235 | | - margin=margin, |
236 | | - ), |
237 | | - } |
238 | | - |
239 | | - # Chart 3: Prediction Variance |
240 | | - var_cols = [c for c in ["pred_variance"] if c in df.columns] |
241 | | - var_fig = { |
242 | | - "data": _make_traces(df, var_cols, smoothing_window), |
243 | | - "layout": go.Layout( |
244 | | - title="Prediction Variance (collapse detector)", |
245 | | - xaxis=dict(title="Epoch"), |
246 | | - yaxis=dict(title="Variance", type="log"), |
247 | | - margin=margin, |
248 | | - ), |
249 | | - } |
250 | | - |
251 | | - # Chart 4: Learning Rate |
252 | | - lr_cols = [c for c in ["lr"] if c in df.columns] |
253 | | - lr_fig = { |
254 | | - "data": _make_traces(df, lr_cols, smoothing_window), |
255 | | - "layout": go.Layout( |
256 | | - title="Learning Rate Schedule", |
257 | | - xaxis=dict(title="Epoch"), |
258 | | - yaxis=dict(title="LR", type="log"), |
259 | | - margin=margin, |
260 | | - ), |
261 | | - } |
262 | | - |
263 | | - last_epoch = df["epoch"].iloc[-1] |
264 | | - update_text = f"Last updated: Epoch {last_epoch} (Polled automatically)" |
265 | | - |
266 | | - return loss_fig, pcc_fig, var_fig, lr_fig, update_text |
267 | 38 |
|
| 39 | +if __name__ == "__main__": |
| 40 | + args = parse_args() |
268 | 41 |
|
269 | | -@app.callback( |
270 | | - [ |
271 | | - Output("interval-component", "disabled"), |
272 | | - Output("pause-button", "children"), |
273 | | - Output("pause-button", "style"), |
274 | | - ], |
275 | | - [Input("pause-button", "n_clicks")], |
276 | | -) |
277 | | -def toggle_pause(n_clicks): |
278 | | - base_style = { |
279 | | - "marginBottom": "20px", |
280 | | - "padding": "10px", |
281 | | - "fontSize": "16px", |
282 | | - "cursor": "pointer", |
283 | | - "borderRadius": "5px", |
284 | | - "border": "1px solid #ccc", |
285 | | - } |
286 | | - if n_clicks % 2 == 1: |
287 | | - # Paused state |
288 | | - active_style = { |
289 | | - **base_style, |
290 | | - "backgroundColor": "#ffcccc", |
291 | | - "borderColor": "#ff0000", |
292 | | - } |
293 | | - return True, "Resume Updates", active_style |
294 | | - # Active state |
295 | | - active_style = {**base_style, "backgroundColor": "#f0f0f0"} |
296 | | - return False, "Pause Updates", active_style |
| 42 | + # Initialize the dash app |
| 43 | + init_app(args) |
297 | 44 |
|
| 45 | + if getattr(args, "runs_dir", None): |
| 46 | + print(f"Tracking multiple runs in: {args.runs_dir}") |
| 47 | + else: |
| 48 | + print(f"Tracking single run at: {args.run_dir}") |
298 | 49 |
|
299 | | -if __name__ == "__main__": |
300 | | - print(f"Tracking log at: {log_path}") |
301 | 50 | print(f"Starting dashboard on http://127.0.0.1:{args.port}/") |
| 51 | + |
| 52 | + # Run the server |
302 | 53 | # Turn off debug to prevent double-reloading the data parser during polling |
303 | 54 | app.run(debug=False, port=args.port) |
0 commit comments