diff --git a/backend/main/views.py b/backend/main/views.py index fc9f0ec20..a7138fcad 100644 --- a/backend/main/views.py +++ b/backend/main/views.py @@ -4,9 +4,7 @@ from zipfile import ZipFile import re import logging -from typing import Any -import numpy as np from plotly.io import to_json import pandas as pd @@ -42,6 +40,7 @@ get_displayed_steps, parameters_from_post, sanitize_name, + _dataframe_as_datagrid_rows, ) from backend.protzilla.all_steps import get_all_possible_steps @@ -683,51 +682,6 @@ def get_step_plots(request): ) -# TODO: Move somewhere else -def _step_output_as_serialised_table( - label: str, _data: pd.DataFrame | Any, index_delims: tuple[int, int] = (None, None) -) -> list[dict]: - """ - Returns the output data of a step as a list of dicts in "records" orientaion, like this: - [{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}] - Also delimits the return according to index_delims. - If the output could not be serialised, None is returned - - :param label: The label of the step output to serialise - :param _data: The data associated with the output - :param index_delims: tuple used as slice begin and end indices to delimit the output - """ - start_index = index_delims[0] - end_index = index_delims[1] - - # Note: using [None:None] as a slice returns the entire collection - if isinstance(_data, pd.DataFrame): - data = _data.iloc[start_index:end_index].copy() - - # Safer than just adding the new column. We assume __id_col is not - # a column name anyone would use - if "id" in data.columns: - data.rename(columns={"id": "__id_col"}, inplace=True) - - data["id"] = data.index - cleaned_data = data.replace(np.nan, None) - return cleaned_data.to_dict(orient="records") - - # Serialise compatible lists - # TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds) - # Potential fix: Just do not use lists bro??? - elif ( - ("_df" not in label) and (label not in hidden_outputs) and (type(_data) == list) - ): - data = pd.DataFrame({label: _data[start_index:end_index]}) - data["id"] = data.index - cleaned_data = data.replace(np.nan, None) - return cleaned_data.to_dict(orient="records") - - else: - return None - - def get_png_from_step(request: HttpRequest): """ API call. Returns a base64-encoded PNG of a step output to the front-end @@ -760,8 +714,8 @@ def get_png_from_step(request: HttpRequest): def get_current_step_table_data(request): """ - API call. Returns a specific delimited slice of data from a specified table - of the current step's outputs. + API call. Returns a specific delimited and optionally filtered and/or sorted slice of data + from a specified table of the current step's outputs. """ if request.method != "POST": return JsonResponse( @@ -774,9 +728,17 @@ def get_current_step_table_data(request): table_label = data.get("table_label") start_index = data.get("start_index") end_index = data.get("end_index") - index_delims = (start_index, end_index) + sort_field = data.get("sort_field") + sort_direction = data.get("sort_direction", "asc") + filters_raw = data.get("filters", "[]") + filters = json.loads(filters_raw) - response = {"success": False, "message": None, "rows": None, "total_row_count": 0} + response: dict[str, object | None] = { + "success": False, + "message": None, + "rows": None, + "total_row_count": 0, + } run = Run(run_name) @@ -789,9 +751,55 @@ def get_current_step_table_data(request): response["message"] = "Requested step output not found" return JsonResponse(response, status=404) - serialised_output = _step_output_as_serialised_table( - table_label, step_output, index_delims - ) + # Serialise compatible lists + # TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds) + # Potential fix: Do not use lists? + if ( + ("_df" not in table_label) + and (table_label not in hidden_outputs) + and (type(step_output) == list) + ): + step_output = pd.DataFrame({table_label: step_output}) + + if isinstance(step_output, pd.DataFrame): + for f in filters: + field = f.get("field") + operator = f.get("operator") + value = f.get("value") + + if not field or value is None: + continue + + col = step_output[field] + if operator == "contains": + step_output = step_output[ + col.astype(str).str.contains(str(value), case=False, na=False) + ] + elif operator == "equals": + step_output = step_output[ + col.astype(str).str.lower() == str(value).lower() + ] + elif operator == "=": + step_output = step_output[col == float(value)] + elif operator == ">": + step_output = step_output[col > float(value)] + elif operator == "<": + step_output = step_output[col < float(value)] + + if sort_field: + step_output = step_output.sort_values( + by=sort_field, + ascending=(sort_direction == "asc"), + na_position="last", + ) + + response["total_row_count"] = len(step_output) + + paginated_output = step_output.iloc[start_index:end_index] + + serialised_output = _dataframe_as_datagrid_rows(paginated_output) + else: + serialised_output = None if serialised_output is None: response["rows"] = [ @@ -803,8 +811,6 @@ def get_current_step_table_data(request): response["success"] = True response["rows"] = serialised_output - response["total_row_count"] = len(step_output) - return JsonResponse(response) diff --git a/backend/main/views_helper.py b/backend/main/views_helper.py index d8fa417c1..d33202cb3 100644 --- a/backend/main/views_helper.py +++ b/backend/main/views_helper.py @@ -1,7 +1,9 @@ import re from pathlib import Path +from typing import Any import numpy as np +import pandas as pd from backend.protzilla.constants.paths import SETTINGS_PATH from backend.protzilla.disk_operator import YamlOperator @@ -176,3 +178,28 @@ def load_yaml_from_file(path: Path) -> str: raise FileNotFoundError(f"File {path} does not exist.") with path.open("r") as f: return f.read() + + +def _dataframe_as_datagrid_rows(_data: pd.DataFrame) -> list[dict] | None: + """ + Converts dataframes from step outputs into a DataGrid-compatible row format for the frontend. + Returns the output data of a step as a list of dicts in "records" orientaion, like this: + [{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}] + If the output could not be serialised, None is returned. + An id column based on index will be added. + + :param _data: The data associated with the output + """ + if isinstance(_data, pd.DataFrame): + data = _data.copy() + + # Safer than just adding the new column. We assume __id_col is not + # a column name anyone would use + if "id" in data.columns: + data.rename(columns={"id": "__id_col"}, inplace=True) + + data["id"] = data.index + cleaned_data = data.replace(np.nan, None) + return cleaned_data.to_dict(orient="records") + else: + return None diff --git a/frontend/src/components/app/run-screen/run-screen.tsx b/frontend/src/components/app/run-screen/run-screen.tsx index 3d747c1c9..649f947d8 100644 --- a/frontend/src/components/app/run-screen/run-screen.tsx +++ b/frontend/src/components/app/run-screen/run-screen.tsx @@ -1,6 +1,5 @@ import { Navbar, NodeEditor, PlotDownloadSettings } from "@protzilla/app"; import { - CSVButton, DataTable, FlexColumn, FlexRow, @@ -67,12 +66,6 @@ const StyledContentDiv = styled.div` flex-direction: column; `; -const StyledCSVButton = styled(CSVButton)` - width: auto; - align-self: flex-end; - margin-top: ${spacing("buttonGap")}; -`; - const FooterText = styled.div` text-align: center; padding: ${spacing("small")}; @@ -281,7 +274,6 @@ export const RunScreen: React.FC = () => { const singleTableComponent = (tableLabel: string) => ( - ); diff --git a/frontend/src/components/core/data-table/data-table.tsx b/frontend/src/components/core/data-table/data-table.tsx index 4e70d523f..619188862 100644 --- a/frontend/src/components/core/data-table/data-table.tsx +++ b/frontend/src/components/core/data-table/data-table.tsx @@ -2,18 +2,30 @@ import { Box } from "@mui/material"; import { ThemeProvider } from "@mui/material/styles"; import { DataGrid, + getGridNumericOperators, + getGridStringOperators, GridColDef, GridColumnVisibilityModel, + GridFilterModel, GridFooterContainer, GridFooterContainerProps, GridPagination, GridPaginationModel, + GridSortModel, } from "@mui/x-data-grid"; -import { baseTheme, getMuiTheme } from "@protzilla/theme"; +import { baseTheme, getMuiTheme, spacing } from "@protzilla/theme"; import { callApiWithParameters, TableRecord } from "@protzilla/utils"; -import React, { useEffect, useMemo, useState } from "react"; +import React, { useEffect, useMemo, useRef, useState } from "react"; +import { styled } from "styled-components"; import { DataTableProps } from "./data-table.props"; +import { CSVButton } from "../shared"; + +const StyledCSVButton = styled(CSVButton)` + width: auto; + align-self: flex-end; + margin-top: ${spacing("buttonGap")}; +`; export const CustomFooter: React.FC = () => { return ( @@ -37,6 +49,14 @@ const FALLBACK_TOO_MANY_COLUMNS = [ ]; const MAX_COLUMNS = 25; +const stringOperators = getGridStringOperators().filter( + (op) => op.value === "contains" || op.value === "equals", +); + +const numericOperators = getGridNumericOperators().filter( + (op) => op.value === "=" || op.value === ">" || op.value === "<", +); + export const DataTable: React.FC = ({ runName, tableLabel, @@ -53,6 +73,20 @@ export const DataTable: React.FC = ({ const [currentRows, setCurrentRows] = useState([]); const [totalRowCount, setTotalRowCount] = useState(0); const [isLoading, setLoading] = useState(false); + const [sortModel, setSortModel] = useState([]); + const [filterModel, setFilterModel] = useState({ + items: [], + }); + const [columns, setColumns] = useState([]); + const columnsInitializedRef = useRef(false); + + // necessary for updating which columns exist when switching between tables + useEffect(() => { + setColumns([]); + setFilterModel({ items: [] }); + setSortModel([]); + columnsInitializedRef.current = false; + }, [tableLabel]); // Fetch data when pagination changes useEffect(() => { @@ -68,8 +102,34 @@ export const DataTable: React.FC = ({ table_label: tableLabel, start_index: startIndex, end_index: endIndex, + sort_field: sortModel[0]?.field, + sort_direction: sortModel[0]?.sort ?? "asc", + filters: JSON.stringify(filterModel.items), }); + if (response.rows.length > 0 && !columnsInitializedRef.current) { + const generatedColumns = Object.keys(response.rows[0]).map((key) => { + const isNumeric = response.rows.every( + (row: TableRecord) => typeof row[key] === "number" || row[key] === null, + ); + + return { + field: key, + headerName: key, + flex: 1, + type: isNumeric ? "number" : "string", + align: "left", + headerAlign: "left", + filterable: true, + filterOperators: isNumeric ? numericOperators : stringOperators, + valueFormatter: (value: unknown) => value ?? "NaN", + } as GridColDef; + }); + + setColumns(generatedColumns); + columnsInitializedRef.current = true; + } + if (response.rows.length > 0 && Object.keys(response.rows[0]).length > MAX_COLUMNS) { setCurrentRows(FALLBACK_TOO_MANY_COLUMNS); setTotalRowCount(FALLBACK_TOO_MANY_COLUMNS.length); @@ -85,27 +145,7 @@ export const DataTable: React.FC = ({ }; void fetchData(); - }, [paginationModel, tableLabel, runName]); - - const columns = useMemo(() => { - if (currentRows.length === 0) return []; - - return Object.keys(currentRows[0]).map((key) => { - const isNumeric = currentRows.every( - (row) => typeof row[key] === "number" || row[key] === null, - ); - return { - field: key, - headerName: key, - flex: 1, - type: isNumeric ? "number" : "string", - align: "left", - headerAlign: "left", - // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition - valueFormatter: (value) => value ?? "NaN", - } as GridColDef; - }); - }, [currentRows]); + }, [paginationModel, sortModel, filterModel, tableLabel, runName]); const theme = useMemo(() => getMuiTheme(), []); const height = parseInt(baseTheme.sizes.tableRow, 10); @@ -123,6 +163,12 @@ export const DataTable: React.FC = ({ paginationModel={paginationModel} onPaginationModelChange={setPaginationModel} pageSizeOptions={pageSizeOptions} + sortingMode="server" + sortModel={sortModel} + onSortModelChange={setSortModel} + filterMode="server" + filterModel={filterModel} + onFilterModelChange={setFilterModel} sx={{ width: "100%", height: "100%", @@ -132,8 +178,13 @@ export const DataTable: React.FC = ({ slots={{ footer: CustomFooter, }} - disableColumnSorting - disableColumnFilter + /> + ); diff --git a/frontend/src/components/core/shared/button/button.props.ts b/frontend/src/components/core/shared/button/button.props.ts index 33876bd95..b343f9f6c 100644 --- a/frontend/src/components/core/shared/button/button.props.ts +++ b/frontend/src/components/core/shared/button/button.props.ts @@ -1,3 +1,4 @@ +import { GridFilterModel, GridSortModel } from "@mui/x-data-grid"; import type { Color } from "@protzilla/theme"; import type { UIStateProps } from "@protzilla/utils"; import type React from "react"; @@ -74,4 +75,14 @@ export interface CSVButtonProps extends ButtonProps { runName: string; tableLabel: string; fileName?: string; + sortModel: GridSortModel; + filterModel: GridFilterModel; +} + +type TableValue = string | number | boolean | null | undefined | object; + +type TableRow = Record; + +export interface TableDataResponse { + rows: TableRow[]; } diff --git a/frontend/src/components/core/shared/button/button.tsx b/frontend/src/components/core/shared/button/button.tsx index 1442fd77a..3cd8945c4 100644 --- a/frontend/src/components/core/shared/button/button.tsx +++ b/frontend/src/components/core/shared/button/button.tsx @@ -14,6 +14,7 @@ import { ButtonRef, CSVButtonProps, StatusButtonProps, + TableDataResponse, ToggleableButtonProps, } from "./button.props"; @@ -557,19 +558,27 @@ export const CSVButton: React.FC = ({ runName, tableLabel, fileName = "data.csv", + sortModel, + filterModel, ...params }) => { const [isLoading, setLoading] = useState(false); const downloadCSV = async () => { setLoading(true); - let fetchedRows: any[] = []; + let fetchedRows: TableDataResponse["rows"] = []; try { - const response = await callApiWithParameters("get_current_step_table_data/", { - run_name: runName, - table_label: tableLabel, - }); + const response: TableDataResponse = await callApiWithParameters( + "get_current_step_table_data/", + { + run_name: runName, + table_label: tableLabel, + sort_field: sortModel[0]?.field, + sort_direction: sortModel[0]?.sort ?? "asc", + filters: JSON.stringify(filterModel.items), + }, + ); fetchedRows = response.rows; } catch (error) { console.error("Failed to fetch table data:", error); @@ -584,7 +593,6 @@ export const CSVButton: React.FC = ({ header .map((key) => { const value = row[key]; - // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition if (value == null) return "NaN"; // Value will be explicitly converted via String() const stringified = typeof value === "object" ? JSON.stringify(value) : String(value);