Skip to content

Commit 5f45f61

Browse files
Query result enum normalization: apply convert_input_enums_for_model to node and relationship properties during rehydration, test coverage for extended enum round-trip persistence
1 parent 784cb87 commit 5f45f61

3 files changed

Lines changed: 75 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "kuzualchemy"
7-
version = "0.3.27"
7+
version = "0.3.28"
88
description = "SQLAlchemy-like ORM for Kuzu graph database"
99
readme = "README.md"
1010
license = { file = "LICENSE" }

src/kuzualchemy/kuzu_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from .kuzu_query_builder import QueryState, JoinClause, CypherQueryBuilder
2929
from .kuzu_query_fields import QueryField, ModelFieldAccessor
30+
from .enum_normalization import convert_input_enums_for_model
3031
from .uuid_normalization import _NULL_UUID
3132

3233
logger = logging.getLogger(__name__)
@@ -1102,6 +1103,7 @@ def _assert_uuid_seq(seq_value: Any, field_name: str) -> None:
11021103
_assert_uuid_seq(v, k)
11031104

11041105
clean_props[k] = v
1106+
clean_props = convert_input_enums_for_model(model_class=node_cls, values=clean_props)
11051107

11061108
# 6) Use model_construct for fast instantiation (skip validation - data from DB is valid)
11071109
inst = node_cls.model_construct(**clean_props)
@@ -1178,6 +1180,7 @@ def _assert_uuid_seq(seq_value: Any, field_name: str) -> None:
11781180
rel_props.pop(DDLConstants.REL_FROM_NODE_FIELD)
11791181
if DDLConstants.REL_TO_NODE_FIELD in rel_props:
11801182
rel_props.pop(DDLConstants.REL_TO_NODE_FIELD)
1183+
rel_props = convert_input_enums_for_model(model_class=result_model_class, values=rel_props)
11811184

11821185
# Construct relationship instance using model_construct for speed (data from DB is valid)
11831186
instance = result_model_class.model_construct(from_node=from_node, to_node=to_node, **rel_props)

tests/test_basemodel_enum_conversion.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,21 @@
99
from __future__ import annotations
1010

1111
import pytest
12+
from pathlib import Path
1213
from enum import Enum, IntEnum, StrEnum
1314
from typing import Optional, Union, List, Tuple
14-
from pydantic import ValidationError
1515

16-
from kuzualchemy import BaseModel, kuzu_node, kuzu_field, KuzuDataType
16+
from kuzualchemy import (
17+
BaseModel,
18+
KuzuDataType,
19+
KuzuSession,
20+
KuzuRelationshipBase,
21+
kuzu_field,
22+
kuzu_int8enum,
23+
kuzu_node,
24+
kuzu_relationship,
25+
)
26+
from kuzualchemy.test_utilities import initialize_schema
1727

1828

1929
class StatusEnum(Enum):
@@ -52,6 +62,15 @@ class StatusStrEnum(StrEnum):
5262
PENDING = "pending"
5363

5464

65+
class CanonicalStageEnum(IntEnum):
66+
BASELINE = 1
67+
68+
69+
@kuzu_int8enum(base_enum=CanonicalStageEnum)
70+
class ExtendedStageEnum:
71+
EXPERIMENTAL = 9
72+
73+
5574
class PrimaryUnionEnum(StrEnum):
5675
"""First enum candidate in a multi-enum Union field."""
5776
FIRST = "first"
@@ -411,6 +430,56 @@ def test_validator_error_message_formatting(self):
411430
assert "Valid names:" in error_message
412431
assert "valid values:" in error_message
413432

433+
def test_node_and_relationship_query_round_trip_rehydrates_extended_enums(self, tmp_path: Path):
434+
@kuzu_node("EnumRoundTripNode")
435+
class EnumRoundTripNode(BaseModel):
436+
id: int = kuzu_field(kuzu_type=KuzuDataType.INT32, primary_key=True)
437+
stage: ExtendedStageEnum = kuzu_field(kuzu_type=KuzuDataType.INT8)
438+
439+
@kuzu_relationship("ENUM_ROUND_TRIP", pairs=[(EnumRoundTripNode, EnumRoundTripNode)])
440+
class EnumRoundTripRelationship(KuzuRelationshipBase):
441+
stage: ExtendedStageEnum = kuzu_field(kuzu_type=KuzuDataType.INT8)
442+
443+
db_path = tmp_path / "enum_round_trip.kuzu"
444+
session = KuzuSession(db_path=db_path)
445+
446+
from kuzualchemy.kuzu_orm import get_ddl_for_node, get_ddl_for_relationship
447+
448+
ddl = "\n".join(
449+
[
450+
get_ddl_for_node(EnumRoundTripNode),
451+
get_ddl_for_relationship(EnumRoundTripRelationship),
452+
]
453+
)
454+
initialize_schema(session, ddl=ddl)
455+
456+
source = EnumRoundTripNode(id=1, stage=CanonicalStageEnum.BASELINE)
457+
target = EnumRoundTripNode(id=2, stage=9)
458+
session.add(source)
459+
session.add(target)
460+
session.commit()
461+
462+
source_db = session.query(EnumRoundTripNode).filter_by(id=1).first()
463+
target_db = session.query(EnumRoundTripNode).filter_by(id=2).first()
464+
465+
assert source_db is not None
466+
assert target_db is not None
467+
assert source_db.stage is ExtendedStageEnum.BASELINE
468+
assert target_db.stage is ExtendedStageEnum.EXPERIMENTAL
469+
470+
session.create_relationship(
471+
EnumRoundTripRelationship,
472+
source_db.id,
473+
target_db.id,
474+
stage=9,
475+
)
476+
session.commit()
477+
478+
relationships = session.query(EnumRoundTripRelationship).all()
479+
480+
assert len(relationships) == 1
481+
assert relationships[0].stage is ExtendedStageEnum.EXPERIMENTAL
482+
414483

415484
# Additional tests for list/tuple of Enum conversions
416485
from typing import List, Tuple

0 commit comments

Comments
 (0)