diff --git a/adapta/dataclass_validation/validation/validation_abstract.py b/adapta/dataclass_validation/validation/validation_abstract.py index 537ac850..2719e6fe 100644 --- a/adapta/dataclass_validation/validation/validation_abstract.py +++ b/adapta/dataclass_validation/validation/validation_abstract.py @@ -72,6 +72,8 @@ def _get_expected_dtypes(self, dtype: type) -> Any: origin_dtype = get_origin(dtype) if origin_dtype in self._dtype_recursive_dtypes and origin_dtype == list: return self._dtype_recursive_dtypes[origin_dtype](self._get_expected_dtypes(dtype=get_args(dtype)[0])) + if origin_dtype == dict: + return self._dtype_mapping[origin_dtype] raise TypeError( f"Unsupported data type: {dtype}. Supported types are: " diff --git a/adapta/dataclass_validation/validation/validation_polars.py b/adapta/dataclass_validation/validation/validation_polars.py index f19c5088..cad0c365 100644 --- a/adapta/dataclass_validation/validation/validation_polars.py +++ b/adapta/dataclass_validation/validation/validation_polars.py @@ -23,6 +23,7 @@ def _dtype_mapping(self): bool: pl.Boolean, datetime.date: pl.Date, datetime.datetime: pl.Datetime, + dict: pl.Struct, } @property @@ -39,6 +40,7 @@ def _allowed_casts(self): pl.Float32: [pl.Float64, pl.String], pl.Float64: [pl.String], pl.Boolean: [pl.Int64, pl.String], + pl.List(pl.Null): [pl.List(pl.String), pl.List(pl.Int64), pl.List(pl.Float64), pl.List(pl.Boolean)], } def _validate_primary_keys(self, **kwargs) -> None: diff --git a/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_and_select_columns.py b/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_and_select_columns.py index e83728fd..5cbb1ada 100644 --- a/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_and_select_columns.py +++ b/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_and_select_columns.py @@ -161,6 +161,39 @@ class TestOutput: 
TestOutput(expected_dtype=pl.String, expected_values=["true"]), id="Boolean to String", ), + # --- Empty Lists (Target: List(String), List(Int64), List(Float64), List(Boolean)) --- + pytest.param( + TestInput( + target_field=Field(display_name="v", description="d", dtype=list[str]), + dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}), + ), + TestOutput(expected_dtype=pl.List(pl.String), expected_values=[[]]), + id="List(Null) to List(String)", + ), + pytest.param( + TestInput( + target_field=Field(display_name="v", description="d", dtype=list[int]), + dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}), + ), + TestOutput(expected_dtype=pl.List(pl.Int64), expected_values=[[]]), + id="List(Null) to List(Int64)", + ), + pytest.param( + TestInput( + target_field=Field(display_name="v", description="d", dtype=list[float]), + dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}), + ), + TestOutput(expected_dtype=pl.List(pl.Float64), expected_values=[[]]), + id="List(Null) to List(Float64)", + ), + pytest.param( + TestInput( + target_field=Field(display_name="v", description="d", dtype=list[bool]), + dataframe=pl.DataFrame({"v": [[]]}, schema={"v": pl.List(pl.Null)}), + ), + TestOutput(expected_dtype=pl.List(pl.Boolean), expected_values=[[]]), + id="List(Null) to List(Boolean)", + ), ], ) def test__coerce_and_select_columns__casting_rules__unit_test(inputs: TestInput, expected: TestOutput): @@ -208,7 +241,7 @@ def test__allowed_casts__convention_coverage(): s_base = source_dtype if isinstance(source_dtype, type) else source_dtype.__class__ target_py_type = test_input.target_field.dtype - target_dtype = validator._dtype_mapping.get(target_py_type, target_py_type) + target_dtype = validator._get_expected_dtypes(target_py_type) # Normalize target t_base = target_dtype if isinstance(target_dtype, type) else target_dtype.__class__ diff --git a/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_data_types.py 
b/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_data_types.py index 5b56c445..ba457aed 100644 --- a/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_data_types.py +++ b/tests/dataclass_validation/validation/abstract_or_validation_class/test_coerce_data_types.py @@ -82,6 +82,23 @@ def create_schema(fields_dict: dict): TestOutput(expect_failure=True), id="Log failure instead of raising when should_raise is False", ), + pytest.param( + TestInput( + target_schema=create_schema( + { + "c1": Field(display_name="v1", description="d", dtype=list[str], coerce=True), + "c2": Field(display_name="v2", description="d", dtype=list[int], coerce=True), + "c3": Field(display_name="v3", description="d", dtype=list[float], coerce=True), + "c4": Field(display_name="v4", description="d", dtype=list[bool], coerce=True), + } + ), + dataframe=pl.DataFrame({"c1": [[]], "c2": [[]], "c3": [[]], "c4": [[]]}), + ), + TestOutput( + expected_dtypes=[pl.List(pl.String), pl.List(pl.Int64), pl.List(pl.Float64), pl.List(pl.Boolean)] + ), + id="Coerces empty List(Null) columns to List(String|Int64|Float64|Boolean)", + ), ], ) def test__coerce_data_types__unit_test(inputs: TestInput, expected: TestOutput): diff --git a/tests/dataclass_validation/validation/abstract_or_validation_class/test_validate_data_types.py b/tests/dataclass_validation/validation/abstract_or_validation_class/test_validate_data_types.py index 8a1b7379..b64a1838 100644 --- a/tests/dataclass_validation/validation/abstract_or_validation_class/test_validate_data_types.py +++ b/tests/dataclass_validation/validation/abstract_or_validation_class/test_validate_data_types.py @@ -27,6 +27,16 @@ class TestDataClass(AbstractDataClass): description="Description for column 4.", dtype=bool, ) + column_5 = Field( + display_name="Column 5", + description="Description for column 5.", + dtype=dict, + ) + column_6 = Field( + display_name="Column 6", + description="Description for column 
6.", + dtype=dict[str, str | int], + ) TEST_SCHEMA = TestDataClass() @@ -37,6 +47,14 @@ class TestDataClass(AbstractDataClass): TEST_SCHEMA.column_2: pl.Series([1, 2], dtype=pl.Int64), TEST_SCHEMA.column_3: pl.Series([1.0, 2.0], dtype=pl.Float64), TEST_SCHEMA.column_4: [True, False], + TEST_SCHEMA.column_5: [ + {"name": "alice", "age": "30"}, + {"city": "Seattle"}, + ], + TEST_SCHEMA.column_6: [ + {"key1": "value1", "key2": 1}, + {"key1": "value1", "key2": 1}, + ], } ), schema=TEST_SCHEMA, @@ -72,6 +90,11 @@ class TestDataClass(AbstractDataClass): description="Description for column 4.", dtype=bool, ) + column_5 = Field( + display_name="Column 5", + description="Description for column 5.", + dtype=dict, + ) TEST_SCHEMA = TestDataClass() @@ -82,6 +105,7 @@ class TestDataClass(AbstractDataClass): TEST_SCHEMA.column_2: ["value"], TEST_SCHEMA.column_3: ["value"], TEST_SCHEMA.column_4: ["value"], + TEST_SCHEMA.column_5: ["value"], } ), schema=TEST_SCHEMA, @@ -95,4 +119,5 @@ class TestDataClass(AbstractDataClass): "Column 'column_2' has incorrect type. Expected Int64, got String", "Column 'column_3' has incorrect type. Expected Float64, got String", "Column 'column_4' has incorrect type. Expected Boolean, got String", + "Column 'column_5' has incorrect type. 
Expected Struct, got String", ] diff --git a/tests/dataclass_validation/validation/test_get_expected_dtypes.py b/tests/dataclass_validation/validation/test_get_expected_dtypes.py index b19557dc..9a794f5f 100644 --- a/tests/dataclass_validation/validation/test_get_expected_dtypes.py +++ b/tests/dataclass_validation/validation/test_get_expected_dtypes.py @@ -13,6 +13,7 @@ (int, pl.Int64), (float, pl.Float64), (bool, pl.Boolean), + (dict, pl.Struct), (list[str], pl.List(pl.String)), (list[list[float]], pl.List(pl.List(pl.Float64))), (datetime.date, pl.Date), @@ -45,7 +46,7 @@ class DummyDataClass(AbstractDataClass): @pytest.mark.parametrize( "dtype", [ - dict, + tuple, ], ) def test__polars_get_expected_types__expected_errors(