Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions server/forms/upload_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ class UploadBaseForm(BaseModel):
title="Use bibcode",
description="If enabled, provide a NASA ADS bibcode; otherwise provide manual source metadata.",
)
bibcode: str = Field(default="", title="Bibcode")
pub_name: str = Field(default="", title="Source name")
bibcode: str = Field(default="", title="Bibcode", json_schema_extra={"visible_when": {"has_bibcode": True}})
pub_name: str = Field(default="", title="Source name", json_schema_extra={"visible_when": {"has_bibcode": False}})
pub_authors: list[str] = Field(
default_factory=list,
title="Authors",
description="One author per entry when not using bibcode.",
json_schema_extra={"visible_when": {"has_bibcode": False}},
)
pub_year: int = Field(
default=0, title="Publication year", json_schema_extra={"visible_when": {"has_bibcode": False}}
)
pub_year: int = Field(default=0, title="Publication year")
table_type: str = Field(
default="regular",
title="Table type",
Expand Down
2 changes: 2 additions & 0 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import ValidationError

from server.history import load_history
from server.schema import process_schema
from server.task_registry import register_all_tasks
from server.tasks import TASKS, get_run, start_task

Expand Down Expand Up @@ -50,6 +51,7 @@ def task_schema(task_id: str) -> dict[str, object]:
task = TASKS[task_id]
schema = task.form_model.model_json_schema()
schema.pop("title", None)
schema = process_schema(schema)
return {"title": task.title, "schema": schema}


Expand Down
104 changes: 104 additions & 0 deletions server/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import copy
from collections import OrderedDict
from typing import Any


def _condition_key(condition: dict[str, Any]) -> tuple[tuple[str, Any], ...]:
return tuple(sorted(condition.items()))


def _extract_const_condition(if_schema: Any) -> dict[str, Any] | None:
if not isinstance(if_schema, dict):
return None
properties = if_schema.get("properties")
if not isinstance(properties, dict) or not properties:
return None

condition: dict[str, Any] = {}
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict) or "const" not in field_schema:
return None
condition[field_name] = field_schema["const"]
return condition


def _extract_required(schema: Any) -> set[str]:
if not isinstance(schema, dict):
return set()
required = schema.get("required")
if not isinstance(required, list):
return set()
return {item for item in required if isinstance(item, str)}


def process_schema(schema: dict[str, Any]) -> dict[str, Any]:
processed = copy.deepcopy(schema)
properties = processed.get("properties")
if not isinstance(properties, dict):
return processed

required = processed.get("required")
if not isinstance(required, list):
required = []
processed["required"] = required

grouped_fields: OrderedDict[tuple[tuple[str, Any], ...], dict[str, Any]] = OrderedDict()
grouped_conditions: dict[tuple[tuple[str, Any], ...], dict[str, Any]] = {}

for field_name in list(properties.keys()):
field_schema = properties[field_name]
if not isinstance(field_schema, dict):
continue
visible_when = field_schema.pop("visible_when", None)
if not isinstance(visible_when, dict) or not visible_when:
continue

key = _condition_key(visible_when)
grouped_fields.setdefault(key, {})[field_name] = field_schema
grouped_conditions[key] = visible_when
properties.pop(field_name, None)
required[:] = [item for item in required if item != field_name]

if not grouped_fields:
if not required:
processed.pop("required", None)
return processed

conditional_required: dict[tuple[tuple[str, Any], ...], set[str]] = {}
root_if = processed.get("if")
root_then = processed.get("then")
root_else = processed.get("else")
root_condition = _extract_const_condition(root_if)

if root_condition:
conditional_required[_condition_key(root_condition)] = _extract_required(root_then)
if len(root_condition) == 1:
((field_name, value),) = root_condition.items()
if isinstance(value, bool):
else_condition = {field_name: not value}
conditional_required[_condition_key(else_condition)] = _extract_required(root_else)

all_of = processed.get("allOf")
all_of_entries: list[dict[str, Any]] = list(all_of) if isinstance(all_of, list) else []

for key, branch_properties in grouped_fields.items():
condition = grouped_conditions[key]
then_schema: dict[str, Any] = {"properties": branch_properties}
branch_required = conditional_required.get(key, set())
required_in_branch = [name for name in branch_properties if name in branch_required]
if required_in_branch:
then_schema["required"] = required_in_branch

if_schema = {
"properties": {name: {"const": value} for name, value in condition.items()},
"required": list(condition.keys()),
}
all_of_entries.append({"if": if_schema, "then": then_schema})

processed["allOf"] = all_of_entries
processed.pop("if", None)
processed.pop("then", None)
processed.pop("else", None)
if not required:
processed.pop("required", None)
return processed
67 changes: 67 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any

from server.schema import process_schema


def test_process_schema_moves_visible_when_fields_to_all_of() -> None:
schema: dict[str, Any] = {
"type": "object",
"properties": {
"has_bibcode": {"type": "boolean", "default": True},
"table_name": {"type": "string"},
"bibcode": {"type": "string", "visible_when": {"has_bibcode": True}},
"pub_name": {"type": "string", "visible_when": {"has_bibcode": False}},
"pub_authors": {
"type": "array",
"items": {"type": "string"},
"visible_when": {"has_bibcode": False},
},
"pub_year": {"type": "integer", "visible_when": {"has_bibcode": False}},
},
"required": ["table_name"],
"if": {"properties": {"has_bibcode": {"const": True}}},
"then": {"required": ["bibcode"]},
"else": {"required": ["pub_name", "pub_authors", "pub_year"]},
}

processed = process_schema(schema)

assert "bibcode" not in processed["properties"]
assert "pub_name" not in processed["properties"]
assert "pub_authors" not in processed["properties"]
assert "pub_year" not in processed["properties"]
assert processed["required"] == ["table_name"]
assert "if" not in processed
assert "then" not in processed
assert "else" not in processed

all_of = processed["allOf"]
assert len(all_of) == 2

true_branch = next(branch for branch in all_of if branch["if"]["properties"]["has_bibcode"]["const"] is True)
false_branch = next(branch for branch in all_of if branch["if"]["properties"]["has_bibcode"]["const"] is False)

assert set(true_branch["then"]["properties"].keys()) == {"bibcode"}
assert true_branch["then"]["required"] == ["bibcode"]

assert set(false_branch["then"]["properties"].keys()) == {"pub_name", "pub_authors", "pub_year"}
assert false_branch["then"]["required"] == ["pub_name", "pub_authors", "pub_year"]

for branch in all_of:
for field_schema in branch["then"]["properties"].values():
assert "visible_when" not in field_schema


def test_process_schema_preserves_existing_all_of_and_noop_without_visible_when() -> None:
schema: dict[str, Any] = {
"type": "object",
"properties": {
"flag": {"type": "boolean"},
"name": {"type": "string"},
},
"allOf": [{"if": {"properties": {"flag": {"const": True}}}, "then": {"required": ["name"]}}],
}

processed = process_schema(schema)

assert processed == schema
Loading