diff --git a/ml_peg/app/build_app.py b/ml_peg/app/build_app.py index f63c4e4d6..77d671231 100644 --- a/ml_peg/app/build_app.py +++ b/ml_peg/app/build_app.py @@ -5,10 +5,11 @@ from importlib import import_module import warnings -from dash import Dash, Input, Output, callback +from dash import Dash, Input, Output, callback, ctx, no_update from dash.dash_table import DataTable -from dash.dcc import Store, Tab, Tabs -from dash.html import H1, H3, Div +from dash.dcc import Checklist, Store, Tab, Tabs +from dash.exceptions import PreventUpdate +from dash.html import H1, H3, Button, Details, Div, Summary from yaml import safe_load from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style @@ -276,17 +277,112 @@ def build_tabs( all_tabs = [Tab(label="Summary", value="summary-tab", id="summary-tab")] + [ Tab(label=category_name, value=category_name) for category_name in layouts ] + model_options = [{"label": model, "value": model} for model in MODELS] tabs_layout = [ Div( [ H1("ML-PEG"), Tabs(id="all-tabs", value="summary-tab", children=all_tabs), + Details( + [ + Summary("Visible models"), + Div( + [ + H3("Visible models"), + Checklist( + id="model-filter-checklist", + options=model_options, + value=MODELS, + inline=True, + style={ + "display": "flex", + "flexWrap": "wrap", + "gap": "6px", + }, + labelStyle={ + "display": "flex", + "alignItems": "center", + "gap": "6px", + "border": "1px solid #ced4da", + "borderRadius": "6px", + "padding": "4px 10px", + "margin": "2px", + }, + ), + Div( + [ + Button( + "Select all", + id="model-filter-select-all", + n_clicks=0, + style={ + "fontSize": "12px", + "padding": "4px 10px", + "backgroundColor": "#0d6efd", + "color": "#fff", + "border": "none", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + Button( + "Clear", + id="model-filter-clear-all", + n_clicks=0, + style={ + "fontSize": "12px", + "padding": "4px 10px", + "backgroundColor": "#6c757d", + "color": "#fff", + "border": "none", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + ], + style={ + "display": "flex", + "gap": "8px", + "flexWrap": "wrap", + "marginTop": "8px", + }, + ), + ], + style={ + "padding": "12px", + "border": "1px solid #dee2e6", + "borderRadius": "6px", + "background": "#f8f9fa", + }, + ), + ], + id="model-filter-details", + open=True, + style={ + "marginTop": "16px", + "marginBottom": "16px", + "padding": "0 8px 8px 8px", + "border": "1px solid #dee2e6", + "borderRadius": "6px", + "background": "#fff", + }, + ), Div(id="tabs-content"), ], style={"flex": "1", "marginBottom": "40px"}, ), build_footer(), + Store( + id="selected-models-store", + storage_type="session", + data=MODELS, + ), + Store( + id="summary-table-computed-store", + storage_type="session", + data=summary_table.data, + ), ] full_app.layout = Div( @@ -323,6 +419,79 @@ def select_tab(tab) -> Div: ) return Div([layouts[tab]]) + @callback( + Output("model-filter-checklist", "value"), + Output("selected-models-store", "data"), + Input("model-filter-checklist", "value"), + Input("model-filter-select-all", "n_clicks"), + Input("model-filter-clear-all", "n_clicks"), + Input("selected-models-store", "data"), + prevent_initial_call=False, + ) + def sync_model_filter( + checklist_value: list[str] | None, + select_all_clicks: int, + clear_clicks: int, + stored_selection: list[str] | None, + ) -> tuple[list[str], list[str] | object]: + """ + Keep the model selector and backing store synchronised. + + Parameters + ---------- + checklist_value + Current selection from the checklist. + select_all_clicks + Number of clicks on the Select all button. + clear_clicks + Number of clicks on the Clear button. + stored_selection + Previously persisted models pulled from ``selected-models-store``. + + Returns + ------- + tuple[list[str], list[str] | object] + Updated checklist value and store contents. + """ + trigger_id = ctx.triggered_id + stored_value = stored_selection if stored_selection is not None else MODELS + + if trigger_id in (None, "selected-models-store"): + return stored_value, no_update + + if trigger_id == "model-filter-select-all": + return MODELS, MODELS + + if trigger_id == "model-filter-clear-all": + return [], [] + + if trigger_id == "model-filter-checklist": + selected = checklist_value or [] + return selected, selected + + raise PreventUpdate + + @callback( + Output("model-filter-details", "open"), + Input("all-tabs", "value"), + prevent_initial_call=False, + ) + def toggle_filter_panel(tab: str) -> bool: + """ + Keep the visible-models panel expanded on the summary tab only. + + Parameters + ---------- + tab + Currently selected tab identifier. + + Returns + ------- + bool + ``True`` when the summary tab is selected, otherwise ``False``. + """ + return tab == "summary-tab" + def build_full_app(full_app: Dash, category: str = "*") -> None: """ diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py index 2c8cbbda4..b2b63a9b6 100644 --- a/ml_peg/app/utils/register_callbacks.py +++ b/ml_peg/app/utils/register_callbacks.py @@ -17,6 +17,7 @@ from ml_peg.app.utils.utils import ( Thresholds, clean_thresholds, + filter_rows_by_models, format_metric_columns, format_tooltip_headers, get_scores, @@ -29,18 +30,21 @@ def register_summary_table_callbacks() -> None: @callback( Output("summary-table", "data"), Output("summary-table", "style_data_conditional"), + Output("summary-table-computed-store", "data"), Input("all-tabs", "value"), Input("summary-table-weight-store", "data"), + Input("selected-models-store", "data"), State("summary-table-scores-store", "data"), - State("summary-table", "data"), + State("summary-table-computed-store", "data"), prevent_initial_call=False, ) def update_summary_table( tabs_value: str, - stored_weights: dict[str, float], - stored_scores: dict[str, dict[str, float]], - summary_data: list[dict], - ) -> list[dict]: + stored_weights: dict[str, float] | None, + selected_models: list[str] | None, + stored_scores: dict[str, dict[str, float]] | None, + summary_data: list[dict] | None, + ) -> tuple[list[dict], list[dict], list[dict]]: """ Update summary table when scores/weights change, and sync on tab change. @@ -50,24 +54,37 @@ def update_summary_table( Value of selected tab. Parameter unused, but required to register Input. stored_weights Stored summary weights dictionary. + selected_models + Currently selected MLIPs from the global model filter. stored_scores Stored scores for table scores. summary_data - Data from summary table to be updated. + Latest computed summary table rows. Returns ------- - list[dict] - Updated summary table data. + tuple[list[dict], list[dict], list[dict]] + Filtered table data, style, and cached unfiltered rows. """ + summary_rows = deepcopy(summary_data) if summary_data else [] + if not summary_rows: + raise PreventUpdate + # Update table from stored scores if stored_scores: - for row in summary_data: + for row in summary_rows: for tab, values in stored_scores.items(): - row[tab] = values[row["MLIP"]] + row[tab] = values.get(row["MLIP"]) + + full_rows, _ = update_score_style(summary_rows, stored_weights) + filtered_rows = filter_rows_by_models(full_rows, selected_models) + if filtered_rows: + filtered_scores = calc_metric_scores(filtered_rows) + style = get_table_style(filtered_rows, scored_data=filtered_scores) + else: + style = [] - # Update table contents - return update_score_style(summary_data, stored_weights) + return filtered_rows, style, full_rows def register_category_table_callbacks( @@ -98,6 +115,7 @@ def register_category_table_callbacks( Input(f"{table_id}-weight-store", "data"), Input(f"{table_id}-thresholds-store", "data"), Input("all-tabs", "value"), + Input("selected-models-store", "data"), Input(f"{table_id}-normalized-toggle", "value"), State(f"{table_id}-raw-data-store", "data"), State(f"{table_id}-computed-store", "data"), @@ -109,6 +127,7 @@ def update_benchmark_table_scores( stored_weights: dict[str, float] | None, stored_threshold: dict | None, _tabs_value: str, + selected_models: list[str] | None, toggle_value: list[str] | None, stored_raw_data: list[dict] | None, stored_computed_data: list[dict] | None, @@ -149,6 +168,7 @@ def update_benchmark_table_scores( # Tab switches and toggle flips reuse the cached scored rows rather than # recalculating scores, we only re-score when weights/thresholds change. + raw_rows_output = stored_raw_data if ( trigger_id in ("all-tabs", f"{table_id}-normalized-toggle") and stored_computed_data @@ -157,38 +177,50 @@ def update_benchmark_table_scores( stored_raw_data, stored_computed_data, thresholds, toggle_value ) scored_rows = calc_metric_scores(stored_raw_data, thresholds=thresholds) - style = get_table_style(display_rows, scored_data=scored_rows) columns = format_metric_columns( current_columns, thresholds, show_normalized ) tooltips = format_tooltip_headers( raw_tooltips, thresholds, show_normalized ) - return ( - display_rows, - style, - columns, - tooltips, - stored_computed_data, - stored_raw_data, + else: + # Update overall table score for new weights and thresholds + metrics_data = calc_table_scores( + stored_raw_data, stored_weights, thresholds + ) + raw_rows_output = metrics_data + # Update stored scores per metric + scored_rows = calc_metric_scores(stored_raw_data, thresholds) + # Select between unitful and unitless data + display_rows = get_scores( + metrics_data, scored_rows, thresholds, toggle_value + ) + columns = format_metric_columns( + current_columns, thresholds, show_normalized + ) + tooltips = format_tooltip_headers( + raw_tooltips, thresholds, show_normalized ) - # Update overall table score for new weights and thresholds - metrics_data = calc_table_scores( - stored_raw_data, stored_weights, thresholds + filtered_rows = filter_rows_by_models(display_rows, selected_models) + filtered_scores = ( + filter_rows_by_models(scored_rows, selected_models) + if scored_rows + else [] ) - # Update stored scores per metric - scored_rows = calc_metric_scores(stored_raw_data, thresholds) - # Select between unitful and unitless data - display_rows = get_scores( - metrics_data, scored_rows, thresholds, toggle_value + style = ( + get_table_style(filtered_rows, scored_data=filtered_scores) + if filtered_rows + else [] ) - style = get_table_style(display_rows, scored_data=scored_rows) - columns = format_metric_columns( - current_columns, thresholds, show_normalized + return ( + filtered_rows, + style, + columns, + tooltips, + scored_rows, + raw_rows_output, ) - tooltips = format_tooltip_headers(raw_tooltips, thresholds, show_normalized) - return display_rows, style, columns, tooltips, scored_rows, metrics_data else: @@ -198,6 +230,7 @@ def update_benchmark_table_scores( Output(f"{table_id}-computed-store", "data", allow_duplicate=True), Input(f"{table_id}-weight-store", "data"), Input("all-tabs", "value"), + Input("selected-models-store", "data"), State(table_id, "data"), State(f"{table_id}-computed-store", "data"), prevent_initial_call="initial_duplicate", @@ -205,31 +238,46 @@ def update_benchmark_table_scores( def update_table_scores( stored_weights: dict[str, float] | None, _tabs_value: str, + selected_models: list[str] | None, table_data: list[dict] | None, computed_store: list[dict] | None, ) -> tuple[list[dict], list[dict], list[dict]]: trigger_id = ctx.triggered_id - if trigger_id == "all-tabs" and computed_store: - # When returning to the tab we show the last scored rows instantly. - style = get_table_style(computed_store) - return computed_store, style, computed_store - - if not table_data: + source_rows = computed_store or table_data + if not source_rows: raise PreventUpdate - scored_rows, style = update_score_style(table_data, stored_weights) - return scored_rows, style, scored_rows + if trigger_id == "all-tabs" and computed_store: + # When returning to the tab we show the last scored rows instantly. + filtered_rows = filter_rows_by_models(computed_store, selected_models) + if filtered_rows: + filtered_scores = calc_metric_scores(filtered_rows) + style = get_table_style(filtered_rows, scored_data=filtered_scores) + else: + style = [] + return filtered_rows, style, computed_store + + scored_rows, _ = update_score_style(source_rows, stored_weights) + filtered_rows = filter_rows_by_models(scored_rows, selected_models) + if filtered_rows: + filtered_scores = calc_metric_scores(filtered_rows) + style = get_table_style(filtered_rows, scored_data=filtered_scores) + else: + style = [] + return filtered_rows, style, scored_rows @callback( Output("summary-table-scores-store", "data", allow_duplicate=True), Input(table_id, "data"), State("summary-table-scores-store", "data"), + State(f"{table_id}-computed-store", "data"), prevent_initial_call="initial_duplicate", ) def update_scores_store( table_data: list[dict], scores_data: dict[str, dict[str, float]], + computed_rows: list[dict] | None, ) -> dict[str, dict[str, float]]: """ Update stored scores values when weights update. @@ -240,6 +288,8 @@ def update_scores_store( Data from `table_id` to be updated. scores_data Dictionary of scores for each tab. + computed_rows + Cached unfiltered rows for the category summary. Returns ------- @@ -250,12 +300,16 @@ def update_scores_store( if not table_id.endswith("-summary-table"): return scores_data + source_rows = computed_rows or table_data + if not source_rows: + return scores_data + if not scores_data: scores_data = {} # Update scores store. Category table IDs are of form "[category]-summary-table" # Table headings are of the form "[category] Score" scores_data[table_id.removesuffix("-summary-table") + " Score"] = { - row["MLIP"]: row["Score"] for row in table_data + row["MLIP"]: row["Score"] for row in source_rows if row.get("MLIP") } return scores_data @@ -289,6 +343,7 @@ def register_benchmark_to_category_callback( Output(f"{category_table_id}-computed-store", "data", allow_duplicate=True), Input(f"{benchmark_table_id}-computed-store", "data"), Input("all-tabs", "value"), + Input("selected-models-store", "data"), State(category_table_id, "data"), State(f"{category_table_id}-weight-store", "data"), State(f"{category_table_id}-computed-store", "data"), @@ -297,6 +352,7 @@ def register_benchmark_to_category_callback( def update_category_from_benchmark( benchmark_computed_store: list[dict] | None, _tabs_value: str, + selected_models: list[str] | None, category_data: list[dict] | None, category_weights: dict[str, float] | None, category_computed_store: list[dict] | None, @@ -310,6 +366,8 @@ def update_category_from_benchmark( Latest scored benchmark rows emitted by the benchmark table. _tabs_value Current tab identifier (unused, required to trigger on tab change). + selected_models + Currently selected MLIPs from the global model filter. category_data Existing category table rows shown to the user. category_weights @@ -338,8 +396,14 @@ def update_category_from_benchmark( if mlip in benchmark_scores: row[benchmark_column] = benchmark_scores[mlip] - category_rows, style = update_score_style(category_rows, category_weights) - return category_rows, style, category_rows + category_rows, _ = update_score_style(category_rows, category_weights) + filtered_rows = filter_rows_by_models(category_rows, selected_models) + if filtered_rows: + filtered_scores = calc_metric_scores(filtered_rows) + style = get_table_style(filtered_rows, scored_data=filtered_scores) + else: + style = [] + return filtered_rows, style, category_rows def register_weight_callbacks( diff --git a/ml_peg/app/utils/utils.py b/ml_peg/app/utils/utils.py index 795faa5aa..fa0d4f643 100644 --- a/ml_peg/app/utils/utils.py +++ b/ml_peg/app/utils/utils.py @@ -173,6 +173,37 @@ def clean_weights(raw_weights: dict[str, float] | None) -> dict[str, float]: return weights +def filter_rows_by_models( + rows: list[dict] | None, selected_models: Sequence[str] | None +) -> list[dict]: + """ + Filter table rows by the selected MLIP identifiers. + + Parameters + ---------- + rows + Table rows that include an ``MLIP`` entry. + selected_models + Iterable of model identifiers to keep. ``None`` returns the original rows. + + Returns + ------- + list[dict] + Filtered rows that match ``selected_models`` while preserving order. + """ + if not rows: + return [] + + if not selected_models: + return [] + + selected = {model for model in selected_models if model} + if not selected: + return [] + + return [row for row in rows if row.get("MLIP") in selected] + + def get_scores( raw_rows: list[dict], scored_rows: list[dict],