diff --git a/backend/main/views.py b/backend/main/views.py
index fc9f0ec20..a7138fcad 100644
--- a/backend/main/views.py
+++ b/backend/main/views.py
@@ -4,9 +4,7 @@
from zipfile import ZipFile
import re
import logging
-from typing import Any
-import numpy as np
from plotly.io import to_json
import pandas as pd
@@ -42,6 +40,7 @@
get_displayed_steps,
parameters_from_post,
sanitize_name,
+ _dataframe_as_datagrid_rows,
)
from backend.protzilla.all_steps import get_all_possible_steps
@@ -683,51 +682,6 @@ def get_step_plots(request):
)
-# TODO: Move somewhere else
-def _step_output_as_serialised_table(
- label: str, _data: pd.DataFrame | Any, index_delims: tuple[int, int] = (None, None)
-) -> list[dict]:
- """
- Returns the output data of a step as a list of dicts in "records" orientaion, like this:
- [{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}]
- Also delimits the return according to index_delims.
- If the output could not be serialised, None is returned
-
- :param label: The label of the step output to serialise
- :param _data: The data associated with the output
- :param index_delims: tuple used as slice begin and end indices to delimit the output
- """
- start_index = index_delims[0]
- end_index = index_delims[1]
-
- # Note: using [None:None] as a slice returns the entire collection
- if isinstance(_data, pd.DataFrame):
- data = _data.iloc[start_index:end_index].copy()
-
- # Safer than just adding the new column. We assume __id_col is not
- # a column name anyone would use
- if "id" in data.columns:
- data.rename(columns={"id": "__id_col"}, inplace=True)
-
- data["id"] = data.index
- cleaned_data = data.replace(np.nan, None)
- return cleaned_data.to_dict(orient="records")
-
- # Serialise compatible lists
- # TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds)
- # Potential fix: Just do not use lists bro???
- elif (
- ("_df" not in label) and (label not in hidden_outputs) and (type(_data) == list)
- ):
- data = pd.DataFrame({label: _data[start_index:end_index]})
- data["id"] = data.index
- cleaned_data = data.replace(np.nan, None)
- return cleaned_data.to_dict(orient="records")
-
- else:
- return None
-
-
def get_png_from_step(request: HttpRequest):
"""
API call. Returns a base64-encoded PNG of a step output to the front-end
@@ -760,8 +714,8 @@ def get_png_from_step(request: HttpRequest):
def get_current_step_table_data(request):
"""
- API call. Returns a specific delimited slice of data from a specified table
- of the current step's outputs.
+ API call. Returns a specific delimited and optionally filtered and/or sorted slice of data
+ from a specified table of the current step's outputs.
"""
if request.method != "POST":
return JsonResponse(
@@ -774,9 +728,17 @@ def get_current_step_table_data(request):
table_label = data.get("table_label")
start_index = data.get("start_index")
end_index = data.get("end_index")
- index_delims = (start_index, end_index)
+ sort_field = data.get("sort_field")
+ sort_direction = data.get("sort_direction", "asc")
+ filters_raw = data.get("filters", "[]")
+ filters = json.loads(filters_raw)
- response = {"success": False, "message": None, "rows": None, "total_row_count": 0}
+ response: dict[str, object | None] = {
+ "success": False,
+ "message": None,
+ "rows": None,
+ "total_row_count": 0,
+ }
run = Run(run_name)
@@ -789,9 +751,55 @@ def get_current_step_table_data(request):
response["message"] = "Requested step output not found"
return JsonResponse(response, status=404)
- serialised_output = _step_output_as_serialised_table(
- table_label, step_output, index_delims
- )
+ # Serialise compatible lists
+ # TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds)
+ # Potential fix: Do not use lists?
+ if (
+ ("_df" not in table_label)
+ and (table_label not in hidden_outputs)
+ and (type(step_output) == list)
+ ):
+ step_output = pd.DataFrame({table_label: step_output})
+
+ if isinstance(step_output, pd.DataFrame):
+ for f in filters:
+ field = f.get("field")
+ operator = f.get("operator")
+ value = f.get("value")
+
+ if not field or value is None:
+ continue
+
+ col = step_output[field]
+ if operator == "contains":
+ step_output = step_output[
+ col.astype(str).str.contains(str(value), case=False, na=False)
+ ]
+ elif operator == "equals":
+ step_output = step_output[
+ col.astype(str).str.lower() == str(value).lower()
+ ]
+ elif operator == "=":
+ step_output = step_output[col == float(value)]
+ elif operator == ">":
+ step_output = step_output[col > float(value)]
+ elif operator == "<":
+ step_output = step_output[col < float(value)]
+
+ if sort_field:
+ step_output = step_output.sort_values(
+ by=sort_field,
+ ascending=(sort_direction == "asc"),
+ na_position="last",
+ )
+
+ response["total_row_count"] = len(step_output)
+
+ paginated_output = step_output.iloc[start_index:end_index]
+
+ serialised_output = _dataframe_as_datagrid_rows(paginated_output)
+ else:
+ serialised_output = None
if serialised_output is None:
response["rows"] = [
@@ -803,8 +811,6 @@ def get_current_step_table_data(request):
response["success"] = True
response["rows"] = serialised_output
- response["total_row_count"] = len(step_output)
-
return JsonResponse(response)
diff --git a/backend/main/views_helper.py b/backend/main/views_helper.py
index d8fa417c1..d33202cb3 100644
--- a/backend/main/views_helper.py
+++ b/backend/main/views_helper.py
@@ -1,7 +1,9 @@
import re
from pathlib import Path
+from typing import Any
import numpy as np
+import pandas as pd
from backend.protzilla.constants.paths import SETTINGS_PATH
from backend.protzilla.disk_operator import YamlOperator
@@ -176,3 +178,28 @@ def load_yaml_from_file(path: Path) -> str:
raise FileNotFoundError(f"File {path} does not exist.")
with path.open("r") as f:
return f.read()
+
+
+def _dataframe_as_datagrid_rows(_data: pd.DataFrame) -> list[dict] | None:
+ """
+ Converts dataframes from step outputs into a DataGrid-compatible row format for the frontend.
+ Returns the output data of a step as a list of dicts in "records" orientaion, like this:
+ [{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}]
+ If the output could not be serialised, None is returned.
+ An id column based on index will be added.
+
+ :param _data: The data associated with the output
+ """
+ if isinstance(_data, pd.DataFrame):
+ data = _data.copy()
+
+ # Safer than just adding the new column. We assume __id_col is not
+ # a column name anyone would use
+ if "id" in data.columns:
+ data.rename(columns={"id": "__id_col"}, inplace=True)
+
+ data["id"] = data.index
+ cleaned_data = data.replace(np.nan, None)
+ return cleaned_data.to_dict(orient="records")
+ else:
+ return None
diff --git a/frontend/src/components/app/run-screen/run-screen.tsx b/frontend/src/components/app/run-screen/run-screen.tsx
index 3d747c1c9..649f947d8 100644
--- a/frontend/src/components/app/run-screen/run-screen.tsx
+++ b/frontend/src/components/app/run-screen/run-screen.tsx
@@ -1,6 +1,5 @@
import { Navbar, NodeEditor, PlotDownloadSettings } from "@protzilla/app";
import {
- CSVButton,
DataTable,
FlexColumn,
FlexRow,
@@ -67,12 +66,6 @@ const StyledContentDiv = styled.div`
flex-direction: column;
`;
-const StyledCSVButton = styled(CSVButton)`
- width: auto;
- align-self: flex-end;
- margin-top: ${spacing("buttonGap")};
-`;
-
const FooterText = styled.div`
text-align: center;
padding: ${spacing("small")};
@@ -281,7 +274,6 @@ export const RunScreen: React.FC = () => {
const singleTableComponent = (tableLabel: string) => (
-
);
diff --git a/frontend/src/components/core/data-table/data-table.tsx b/frontend/src/components/core/data-table/data-table.tsx
index 4e70d523f..619188862 100644
--- a/frontend/src/components/core/data-table/data-table.tsx
+++ b/frontend/src/components/core/data-table/data-table.tsx
@@ -2,18 +2,30 @@ import { Box } from "@mui/material";
import { ThemeProvider } from "@mui/material/styles";
import {
DataGrid,
+ getGridNumericOperators,
+ getGridStringOperators,
GridColDef,
GridColumnVisibilityModel,
+ GridFilterModel,
GridFooterContainer,
GridFooterContainerProps,
GridPagination,
GridPaginationModel,
+ GridSortModel,
} from "@mui/x-data-grid";
-import { baseTheme, getMuiTheme } from "@protzilla/theme";
+import { baseTheme, getMuiTheme, spacing } from "@protzilla/theme";
import { callApiWithParameters, TableRecord } from "@protzilla/utils";
-import React, { useEffect, useMemo, useState } from "react";
+import React, { useEffect, useMemo, useRef, useState } from "react";
+import { styled } from "styled-components";
import { DataTableProps } from "./data-table.props";
+import { CSVButton } from "../shared";
+
+const StyledCSVButton = styled(CSVButton)`
+ width: auto;
+ align-self: flex-end;
+ margin-top: ${spacing("buttonGap")};
+`;
export const CustomFooter: React.FC = () => {
return (
@@ -37,6 +49,14 @@ const FALLBACK_TOO_MANY_COLUMNS = [
];
const MAX_COLUMNS = 25;
+const stringOperators = getGridStringOperators().filter(
+ (op) => op.value === "contains" || op.value === "equals",
+);
+
+const numericOperators = getGridNumericOperators().filter(
+ (op) => op.value === "=" || op.value === ">" || op.value === "<",
+);
+
export const DataTable: React.FC = ({
runName,
tableLabel,
@@ -53,6 +73,20 @@ export const DataTable: React.FC = ({
const [currentRows, setCurrentRows] = useState([]);
const [totalRowCount, setTotalRowCount] = useState(0);
const [isLoading, setLoading] = useState(false);
+ const [sortModel, setSortModel] = useState([]);
+ const [filterModel, setFilterModel] = useState({
+ items: [],
+ });
+ const [columns, setColumns] = useState([]);
+ const columnsInitializedRef = useRef(false);
+
+ // necessary for updating which columns exist when switching between tables
+ useEffect(() => {
+ setColumns([]);
+ setFilterModel({ items: [] });
+ setSortModel([]);
+ columnsInitializedRef.current = false;
+ }, [tableLabel]);
// Fetch data when pagination changes
useEffect(() => {
@@ -68,8 +102,34 @@ export const DataTable: React.FC = ({
table_label: tableLabel,
start_index: startIndex,
end_index: endIndex,
+ sort_field: sortModel[0]?.field,
+ sort_direction: sortModel[0]?.sort ?? "asc",
+ filters: JSON.stringify(filterModel.items),
});
+ if (response.rows.length > 0 && !columnsInitializedRef.current) {
+ const generatedColumns = Object.keys(response.rows[0]).map((key) => {
+ const isNumeric = response.rows.every(
+ (row: TableRecord) => typeof row[key] === "number" || row[key] === null,
+ );
+
+ return {
+ field: key,
+ headerName: key,
+ flex: 1,
+ type: isNumeric ? "number" : "string",
+ align: "left",
+ headerAlign: "left",
+ filterable: true,
+ filterOperators: isNumeric ? numericOperators : stringOperators,
+ valueFormatter: (value: unknown) => value ?? "NaN",
+ } as GridColDef;
+ });
+
+ setColumns(generatedColumns);
+ columnsInitializedRef.current = true;
+ }
+
if (response.rows.length > 0 && Object.keys(response.rows[0]).length > MAX_COLUMNS) {
setCurrentRows(FALLBACK_TOO_MANY_COLUMNS);
setTotalRowCount(FALLBACK_TOO_MANY_COLUMNS.length);
@@ -85,27 +145,7 @@ export const DataTable: React.FC = ({
};
void fetchData();
- }, [paginationModel, tableLabel, runName]);
-
- const columns = useMemo(() => {
- if (currentRows.length === 0) return [];
-
- return Object.keys(currentRows[0]).map((key) => {
- const isNumeric = currentRows.every(
- (row) => typeof row[key] === "number" || row[key] === null,
- );
- return {
- field: key,
- headerName: key,
- flex: 1,
- type: isNumeric ? "number" : "string",
- align: "left",
- headerAlign: "left",
- // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
- valueFormatter: (value) => value ?? "NaN",
- } as GridColDef;
- });
- }, [currentRows]);
+ }, [paginationModel, sortModel, filterModel, tableLabel, runName]);
const theme = useMemo(() => getMuiTheme(), []);
const height = parseInt(baseTheme.sizes.tableRow, 10);
@@ -123,6 +163,12 @@ export const DataTable: React.FC = ({
paginationModel={paginationModel}
onPaginationModelChange={setPaginationModel}
pageSizeOptions={pageSizeOptions}
+ sortingMode="server"
+ sortModel={sortModel}
+ onSortModelChange={setSortModel}
+ filterMode="server"
+ filterModel={filterModel}
+ onFilterModelChange={setFilterModel}
sx={{
width: "100%",
height: "100%",
@@ -132,8 +178,13 @@ export const DataTable: React.FC = ({
slots={{
footer: CustomFooter,
}}
- disableColumnSorting
- disableColumnFilter
+ />
+
);
diff --git a/frontend/src/components/core/shared/button/button.props.ts b/frontend/src/components/core/shared/button/button.props.ts
index 33876bd95..b343f9f6c 100644
--- a/frontend/src/components/core/shared/button/button.props.ts
+++ b/frontend/src/components/core/shared/button/button.props.ts
@@ -1,3 +1,4 @@
+import { GridFilterModel, GridSortModel } from "@mui/x-data-grid";
import type { Color } from "@protzilla/theme";
import type { UIStateProps } from "@protzilla/utils";
import type React from "react";
@@ -74,4 +75,14 @@ export interface CSVButtonProps extends ButtonProps {
runName: string;
tableLabel: string;
fileName?: string;
+ sortModel: GridSortModel;
+ filterModel: GridFilterModel;
+}
+
+type TableValue = string | number | boolean | null | undefined | object;
+
+type TableRow = Record;
+
+export interface TableDataResponse {
+ rows: TableRow[];
}
diff --git a/frontend/src/components/core/shared/button/button.tsx b/frontend/src/components/core/shared/button/button.tsx
index 1442fd77a..3cd8945c4 100644
--- a/frontend/src/components/core/shared/button/button.tsx
+++ b/frontend/src/components/core/shared/button/button.tsx
@@ -14,6 +14,7 @@ import {
ButtonRef,
CSVButtonProps,
StatusButtonProps,
+ TableDataResponse,
ToggleableButtonProps,
} from "./button.props";
@@ -557,19 +558,27 @@ export const CSVButton: React.FC = ({
runName,
tableLabel,
fileName = "data.csv",
+ sortModel,
+ filterModel,
...params
}) => {
const [isLoading, setLoading] = useState(false);
const downloadCSV = async () => {
setLoading(true);
- let fetchedRows: any[] = [];
+ let fetchedRows: TableDataResponse["rows"] = [];
try {
- const response = await callApiWithParameters("get_current_step_table_data/", {
- run_name: runName,
- table_label: tableLabel,
- });
+ const response: TableDataResponse = await callApiWithParameters(
+ "get_current_step_table_data/",
+ {
+ run_name: runName,
+ table_label: tableLabel,
+ sort_field: sortModel[0]?.field,
+ sort_direction: sortModel[0]?.sort ?? "asc",
+ filters: JSON.stringify(filterModel.items),
+ },
+ );
fetchedRows = response.rows;
} catch (error) {
console.error("Failed to fetch table data:", error);
@@ -584,7 +593,6 @@ export const CSVButton: React.FC = ({
header
.map((key) => {
const value = row[key];
- // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
if (value == null) return "NaN";
// Value will be explicitly converted via String()
const stringified = typeof value === "object" ? JSON.stringify(value) : String(value);