Skip to content
Merged
2 changes: 2 additions & 0 deletions adapta/dataclass_validation/validation/validation_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def _get_expected_dtypes(self, dtype: type) -> Any:
origin_dtype = get_origin(dtype)
if origin_dtype in self._dtype_recursive_dtypes and origin_dtype == list:
return self._dtype_recursive_dtypes[origin_dtype](self._get_expected_dtypes(dtype=get_args(dtype)[0]))
if origin_dtype == dict:
return self._dtype_mapping[origin_dtype]

raise TypeError(
f"Unsupported data type: {dtype}. Supported types are: "
Expand Down
2 changes: 2 additions & 0 deletions adapta/dataclass_validation/validation/validation_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _dtype_mapping(self):
bool: pl.Boolean,
datetime.date: pl.Date,
datetime.datetime: pl.Datetime,
dict: pl.Struct,
}

@property
Expand All @@ -39,6 +40,7 @@ def _allowed_casts(self):
pl.Float32: [pl.Float64, pl.String],
pl.Float64: [pl.String],
pl.Boolean: [pl.Int64, pl.String],
pl.List(pl.Null): [pl.List(pl.String), pl.List(pl.Int64), pl.List(pl.Float64), pl.List(pl.Boolean)],
}

def _validate_primary_keys(self, **kwargs) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,39 @@ class TestOutput:
TestOutput(expected_dtype=pl.String, expected_values=["true"]),
id="Boolean to String",
),
# --- Empty Lists (Target: List(String), List(Int64), List(Float64)) ---
pytest.param(
TestInput(
target_field=Field(display_name="v", description="d", dtype=list[str]),
dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}),
),
TestOutput(expected_dtype=pl.List(pl.String), expected_values=[[]]),
id="List(Null) to List(String)",
),
pytest.param(
TestInput(
target_field=Field(display_name="v", description="d", dtype=list[int]),
dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}),
),
TestOutput(expected_dtype=pl.List(pl.Int64), expected_values=[[]]),
id="List(Null) to List(Int64)",
),
pytest.param(
TestInput(
target_field=Field(display_name="v", description="d", dtype=list[float]),
dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}),
),
TestOutput(expected_dtype=pl.List(pl.Float64), expected_values=[[]]),
id="List(Null) to List(Float64)",
),
pytest.param(
TestInput(
target_field=Field(display_name="v", description="d", dtype=list[bool]),
dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}),
),
TestOutput(expected_dtype=pl.List(pl.Boolean), expected_values=[[]]),
id="List(Null) to List(Boolean)",
),
],
)
def test__coerce_and_select_columns__casting_rules__unit_test(inputs: TestInput, expected: TestOutput):
Expand Down Expand Up @@ -208,7 +241,7 @@ def test__allowed_casts__convention_coverage():
s_base = source_dtype if isinstance(source_dtype, type) else source_dtype.__class__

target_py_type = test_input.target_field.dtype
target_dtype = validator._dtype_mapping.get(target_py_type, target_py_type)
target_dtype = validator._get_expected_dtypes(target_py_type)
# Normalize target
t_base = target_dtype if isinstance(target_dtype, type) else target_dtype.__class__

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def create_schema(fields_dict: dict):
TestOutput(expect_failure=True),
id="Log failure instead of raising when should_raise is False",
),
pytest.param(
TestInput(
target_schema=create_schema(
{
"c1": Field(display_name="v1", description="d", dtype=list[str], coerce=True),
"c2": Field(display_name="v2", description="d", dtype=list[int], coerce=True),
"c3": Field(display_name="v3", description="d", dtype=list[float], coerce=True),
"c4": Field(display_name="v4", description="d", dtype=list[bool], coerce=True),
}
),
dataframe=pl.DataFrame({"c1": [[]], "c2": [[]], "c3": [[]], "c4": [[]]}),
),
TestOutput(
expected_dtypes=[pl.List(pl.String), pl.List(pl.Int64), pl.List(pl.Float64), pl.List(pl.Boolean)]
),
id="Coerces empty List(Null) columns to List(String|Int64|Float64|Boolean)",
),
],
)
def test__coerce_data_types__unit_test(inputs: TestInput, expected: TestOutput):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ class TestDataClass(AbstractDataClass):
description="Description for column 4.",
dtype=bool,
)
column_5 = Field(
display_name="Column 5",
description="Description for column 5.",
dtype=dict,
)
column_6 = Field(
display_name="Column 6",
description="Description for column 6.",
dtype=dict[str, str | int],
)

TEST_SCHEMA = TestDataClass()

Expand All @@ -37,6 +47,14 @@ class TestDataClass(AbstractDataClass):
TEST_SCHEMA.column_2: pl.Series([1, 2], dtype=pl.Int64),
TEST_SCHEMA.column_3: pl.Series([1.0, 2.0], dtype=pl.Float64),
TEST_SCHEMA.column_4: [True, False],
TEST_SCHEMA.column_5: [
{"name": "alice", "age": "30"},
{"city": "Seattle"},
],
TEST_SCHEMA.column_6: [
{"key1": "value1", "key2": 1},
{"key1": "value1", "key2": 1},
],
}
),
schema=TEST_SCHEMA,
Expand Down Expand Up @@ -72,6 +90,11 @@ class TestDataClass(AbstractDataClass):
description="Description for column 4.",
dtype=bool,
)
column_5 = Field(
display_name="Column 5",
description="Description for column 5.",
dtype=dict,
Comment thread
henrikfoss marked this conversation as resolved.
)

TEST_SCHEMA = TestDataClass()

Expand All @@ -82,6 +105,7 @@ class TestDataClass(AbstractDataClass):
TEST_SCHEMA.column_2: ["value"],
TEST_SCHEMA.column_3: ["value"],
TEST_SCHEMA.column_4: ["value"],
TEST_SCHEMA.column_5: ["value"],
}
),
schema=TEST_SCHEMA,
Expand All @@ -95,4 +119,5 @@ class TestDataClass(AbstractDataClass):
"Column 'column_2' has incorrect type. Expected Int64, got String",
"Column 'column_3' has incorrect type. Expected Float64, got String",
"Column 'column_4' has incorrect type. Expected Boolean, got String",
"Column 'column_5' has incorrect type. Expected Struct, got String",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
(int, pl.Int64),
(float, pl.Float64),
(bool, pl.Boolean),
(dict, pl.Struct),
(list[str], pl.List(pl.String)),
(list[list[float]], pl.List(pl.List(pl.Float64))),
(datetime.date, pl.Date),
Expand Down Expand Up @@ -45,7 +46,7 @@ class DummyDataClass(AbstractDataClass):
@pytest.mark.parametrize(
"dtype",
[
dict,
tuple,
],
)
def test__polars_get_expected_types__expected_errors(
Expand Down
Loading