Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 62 additions & 56 deletions backend/main/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from zipfile import ZipFile
import re
import logging
from typing import Any

import numpy as np
from plotly.io import to_json

import pandas as pd
Expand Down Expand Up @@ -42,6 +40,7 @@
get_displayed_steps,
parameters_from_post,
sanitize_name,
_dataframe_as_datagrid_rows,
)
from backend.protzilla.all_steps import get_all_possible_steps

Expand Down Expand Up @@ -683,51 +682,6 @@ def get_step_plots(request):
)


# TODO: Move somewhere else
def _step_output_as_serialised_table(
label: str, _data: pd.DataFrame | Any, index_delims: tuple[int, int] = (None, None)
) -> list[dict]:
"""
Returns the output data of a step as a list of dicts in "records" orientaion, like this:
[{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}]
Also delimits the return according to index_delims.
If the output could not be serialised, None is returned

:param label: The label of the step output to serialise
:param _data: The data associated with the output
:param index_delims: tuple used as slice begin and end indices to delimit the output
"""
start_index = index_delims[0]
end_index = index_delims[1]

# Note: using [None:None] as a slice returns the entire collection
if isinstance(_data, pd.DataFrame):
data = _data.iloc[start_index:end_index].copy()

# Safer than just adding the new column. We assume __id_col is not
# a column name anyone would use
if "id" in data.columns:
data.rename(columns={"id": "__id_col"}, inplace=True)

data["id"] = data.index
cleaned_data = data.replace(np.nan, None)
return cleaned_data.to_dict(orient="records")

# Serialise compatible lists
# TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds)
# Potential fix: Just do not use lists bro???
elif (
("_df" not in label) and (label not in hidden_outputs) and (type(_data) == list)
):
data = pd.DataFrame({label: _data[start_index:end_index]})
data["id"] = data.index
cleaned_data = data.replace(np.nan, None)
return cleaned_data.to_dict(orient="records")

else:
return None


def get_png_from_step(request: HttpRequest):
"""
API call. Returns a base64-encoded PNG of a step output to the front-end
Expand Down Expand Up @@ -760,8 +714,8 @@ def get_png_from_step(request: HttpRequest):

def get_current_step_table_data(request):
"""
API call. Returns a specific delimited slice of data from a specified table
of the current step's outputs.
API call. Returns a specific delimited and optionally filtered and/or sorted slice of data
from a specified table of the current step's outputs.
"""
if request.method != "POST":
return JsonResponse(
Expand All @@ -774,9 +728,17 @@ def get_current_step_table_data(request):
table_label = data.get("table_label")
start_index = data.get("start_index")
end_index = data.get("end_index")
index_delims = (start_index, end_index)
sort_field = data.get("sort_field")
sort_direction = data.get("sort_direction", "asc")
filters_raw = data.get("filters", "[]")
filters = json.loads(filters_raw)

response = {"success": False, "message": None, "rows": None, "total_row_count": 0}
response: dict[str, object | None] = {
"success": False,
"message": None,
"rows": None,
"total_row_count": 0,
}

run = Run(run_name)

Expand All @@ -789,9 +751,55 @@ def get_current_step_table_data(request):
response["message"] = "Requested step output not found"
return JsonResponse(response, status=404)

serialised_output = _step_output_as_serialised_table(
table_label, step_output, index_delims
)
# Serialise compatible lists
# TODO #49 this should be refactored to be stored somewhere and not be calculated on every call (can take a few seconds)
# Potential fix: Do not use lists?
if (
("_df" not in table_label)
and (table_label not in hidden_outputs)
and (type(step_output) == list)
):
step_output = pd.DataFrame({table_label: step_output})

if isinstance(step_output, pd.DataFrame):
for f in filters:
field = f.get("field")
operator = f.get("operator")
value = f.get("value")

if not field or value is None:
continue

col = step_output[field]
if operator == "contains":
step_output = step_output[
col.astype(str).str.contains(str(value), case=False, na=False)
]
elif operator == "equals":
step_output = step_output[
col.astype(str).str.lower() == str(value).lower()
]
elif operator == "=":
step_output = step_output[col == float(value)]
elif operator == ">":
step_output = step_output[col > float(value)]
elif operator == "<":
step_output = step_output[col < float(value)]

if sort_field:
step_output = step_output.sort_values(
by=sort_field,
ascending=(sort_direction == "asc"),
na_position="last",
)

response["total_row_count"] = len(step_output)

paginated_output = step_output.iloc[start_index:end_index]

serialised_output = _dataframe_as_datagrid_rows(paginated_output)
else:
serialised_output = None

if serialised_output is None:
response["rows"] = [
Expand All @@ -803,8 +811,6 @@ def get_current_step_table_data(request):
response["success"] = True
response["rows"] = serialised_output

response["total_row_count"] = len(step_output)

return JsonResponse(response)


Expand Down
27 changes: 27 additions & 0 deletions backend/main/views_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from backend.protzilla.constants.paths import SETTINGS_PATH
from backend.protzilla.disk_operator import YamlOperator
Expand Down Expand Up @@ -176,3 +178,28 @@ def load_yaml_from_file(path: Path) -> str:
raise FileNotFoundError(f"File {path} does not exist.")
with path.open("r") as f:
return f.read()


def _dataframe_as_datagrid_rows(_data: pd.DataFrame) -> list[dict] | None:
"""
Converts dataframes from step outputs into a DataGrid-compatible row format for the frontend.
Returns the output data of a step as a list of dicts in "records" orientaion, like this:
[{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}]
If the output could not be serialised, None is returned.
An id column based on index will be added.

:param _data: The data associated with the output
"""
if isinstance(_data, pd.DataFrame):
data = _data.copy()

# Safer than just adding the new column. We assume __id_col is not
# a column name anyone would use
if "id" in data.columns:
data.rename(columns={"id": "__id_col"}, inplace=True)

data["id"] = data.index
cleaned_data = data.replace(np.nan, None)
return cleaned_data.to_dict(orient="records")
else:
return None
8 changes: 0 additions & 8 deletions frontend/src/components/app/run-screen/run-screen.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Navbar, NodeEditor, PlotDownloadSettings } from "@protzilla/app";
import {
CSVButton,
DataTable,
FlexColumn,
FlexRow,
Expand Down Expand Up @@ -67,12 +66,6 @@ const StyledContentDiv = styled.div`
flex-direction: column;
`;

const StyledCSVButton = styled(CSVButton)`
width: auto;
align-self: flex-end;
margin-top: ${spacing("buttonGap")};
`;

const FooterText = styled.div`
text-align: center;
padding: ${spacing("small")};
Expand Down Expand Up @@ -281,7 +274,6 @@ export const RunScreen: React.FC = () => {
const singleTableComponent = (tableLabel: string) => (
<StyledContentDiv>
<DataTable runName={runName} tableLabel={tableLabel} />
<StyledCSVButton runName={runName} tableLabel={tableLabel} fileName={tableLabel} />
</StyledContentDiv>
);

Expand Down
101 changes: 76 additions & 25 deletions frontend/src/components/core/data-table/data-table.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@ import { Box } from "@mui/material";
import { ThemeProvider } from "@mui/material/styles";
import {
DataGrid,
getGridNumericOperators,
getGridStringOperators,
GridColDef,
GridColumnVisibilityModel,
GridFilterModel,
GridFooterContainer,
GridFooterContainerProps,
GridPagination,
GridPaginationModel,
GridSortModel,
} from "@mui/x-data-grid";
import { baseTheme, getMuiTheme } from "@protzilla/theme";
import { baseTheme, getMuiTheme, spacing } from "@protzilla/theme";
import { callApiWithParameters, TableRecord } from "@protzilla/utils";
import React, { useEffect, useMemo, useState } from "react";
import React, { useEffect, useMemo, useRef, useState } from "react";
import { styled } from "styled-components";

import { DataTableProps } from "./data-table.props";
import { CSVButton } from "../shared";

const StyledCSVButton = styled(CSVButton)`
width: auto;
align-self: flex-end;
margin-top: ${spacing("buttonGap")};
`;

export const CustomFooter: React.FC<GridFooterContainerProps> = () => {
return (
Expand All @@ -37,6 +49,14 @@ const FALLBACK_TOO_MANY_COLUMNS = [
];
const MAX_COLUMNS = 25;

// String filter operators offered in the grid's filter panel.
const stringOperators = getGridStringOperators().filter((operator) =>
  ["contains", "equals"].includes(operator.value),
);

// Numeric filter operators offered in the grid's filter panel.
const numericOperators = getGridNumericOperators().filter((operator) =>
  ["=", ">", "<"].includes(operator.value),
);

export const DataTable: React.FC<DataTableProps> = ({
runName,
tableLabel,
Expand All @@ -53,6 +73,20 @@ export const DataTable: React.FC<DataTableProps> = ({
const [currentRows, setCurrentRows] = useState<TableRecord[]>([]);
const [totalRowCount, setTotalRowCount] = useState(0);
const [isLoading, setLoading] = useState(false);
const [sortModel, setSortModel] = useState<GridSortModel>([]);
const [filterModel, setFilterModel] = useState<GridFilterModel>({
items: [],
});
const [columns, setColumns] = useState<GridColDef[]>([]);
const columnsInitializedRef = useRef(false);

// necessary for updating which columns exist when switching between tables
useEffect(() => {
setColumns([]);
setFilterModel({ items: [] });
setSortModel([]);
columnsInitializedRef.current = false;
}, [tableLabel]);

// Fetch data when pagination changes
useEffect(() => {
Expand All @@ -68,8 +102,34 @@ export const DataTable: React.FC<DataTableProps> = ({
table_label: tableLabel,
start_index: startIndex,
end_index: endIndex,
sort_field: sortModel[0]?.field,
sort_direction: sortModel[0]?.sort ?? "asc",
filters: JSON.stringify(filterModel.items),
});

if (response.rows.length > 0 && !columnsInitializedRef.current) {
const generatedColumns = Object.keys(response.rows[0]).map((key) => {
const isNumeric = response.rows.every(
(row: TableRecord) => typeof row[key] === "number" || row[key] === null,
);

return {
field: key,
headerName: key,
flex: 1,
type: isNumeric ? "number" : "string",
align: "left",
headerAlign: "left",
filterable: true,
filterOperators: isNumeric ? numericOperators : stringOperators,
valueFormatter: (value: unknown) => value ?? "NaN",
} as GridColDef;
});

setColumns(generatedColumns);
columnsInitializedRef.current = true;
}

if (response.rows.length > 0 && Object.keys(response.rows[0]).length > MAX_COLUMNS) {
setCurrentRows(FALLBACK_TOO_MANY_COLUMNS);
setTotalRowCount(FALLBACK_TOO_MANY_COLUMNS.length);
Expand All @@ -85,27 +145,7 @@ export const DataTable: React.FC<DataTableProps> = ({
};

void fetchData();
}, [paginationModel, tableLabel, runName]);

const columns = useMemo(() => {
if (currentRows.length === 0) return [];

return Object.keys(currentRows[0]).map((key) => {
const isNumeric = currentRows.every(
(row) => typeof row[key] === "number" || row[key] === null,
);
return {
field: key,
headerName: key,
flex: 1,
type: isNumeric ? "number" : "string",
align: "left",
headerAlign: "left",
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
valueFormatter: (value) => value ?? "NaN",
} as GridColDef;
});
}, [currentRows]);
}, [paginationModel, sortModel, filterModel, tableLabel, runName]);

const theme = useMemo(() => getMuiTheme(), []);
const height = parseInt(baseTheme.sizes.tableRow, 10);
Expand All @@ -123,6 +163,12 @@ export const DataTable: React.FC<DataTableProps> = ({
paginationModel={paginationModel}
onPaginationModelChange={setPaginationModel}
pageSizeOptions={pageSizeOptions}
sortingMode="server"
sortModel={sortModel}
onSortModelChange={setSortModel}
filterMode="server"
filterModel={filterModel}
onFilterModelChange={setFilterModel}
sx={{
width: "100%",
height: "100%",
Expand All @@ -132,8 +178,13 @@ export const DataTable: React.FC<DataTableProps> = ({
slots={{
footer: CustomFooter,
}}
disableColumnSorting
disableColumnFilter
/>
<StyledCSVButton
runName={runName}
tableLabel={tableLabel}
fileName={tableLabel}
sortModel={sortModel}
filterModel={filterModel}
/>
</ThemeProvider>
);
Expand Down
Loading
Loading