|
9 | 9 | from __future__ import annotations |
10 | 10 |
|
11 | 11 | import pytest |
| 12 | +from pathlib import Path |
12 | 13 | from enum import Enum, IntEnum, StrEnum |
13 | 14 | from typing import Optional, Union, List, Tuple |
14 | | -from pydantic import ValidationError |
15 | 15 |
|
16 | | -from kuzualchemy import BaseModel, kuzu_node, kuzu_field, KuzuDataType |
| 16 | +from kuzualchemy import ( |
| 17 | + BaseModel, |
| 18 | + KuzuDataType, |
| 19 | + KuzuSession, |
| 20 | + KuzuRelationshipBase, |
| 21 | + kuzu_field, |
| 22 | + kuzu_int8enum, |
| 23 | + kuzu_node, |
| 24 | + kuzu_relationship, |
| 25 | +) |
| 26 | +from kuzualchemy.test_utilities import initialize_schema |
17 | 27 |
|
18 | 28 |
|
19 | 29 | class StatusEnum(Enum): |
@@ -52,6 +62,15 @@ class StatusStrEnum(StrEnum): |
52 | 62 | PENDING = "pending" |
53 | 63 |
|
54 | 64 |
|
| 65 | +class CanonicalStageEnum(IntEnum): |
| 66 | + BASELINE = 1 |
| 67 | + |
| 68 | + |
| 69 | +@kuzu_int8enum(base_enum=CanonicalStageEnum) |
| 70 | +class ExtendedStageEnum: |
| 71 | + EXPERIMENTAL = 9 |
| 72 | + |
| 73 | + |
55 | 74 | class PrimaryUnionEnum(StrEnum): |
56 | 75 | """First enum candidate in a multi-enum Union field.""" |
57 | 76 | FIRST = "first" |
@@ -411,6 +430,56 @@ def test_validator_error_message_formatting(self): |
411 | 430 | assert "Valid names:" in error_message |
412 | 431 | assert "valid values:" in error_message |
413 | 432 |
|
| 433 | + def test_node_and_relationship_query_round_trip_rehydrates_extended_enums(self, tmp_path: Path): |
| 434 | + @kuzu_node("EnumRoundTripNode") |
| 435 | + class EnumRoundTripNode(BaseModel): |
| 436 | + id: int = kuzu_field(kuzu_type=KuzuDataType.INT32, primary_key=True) |
| 437 | + stage: ExtendedStageEnum = kuzu_field(kuzu_type=KuzuDataType.INT8) |
| 438 | + |
| 439 | + @kuzu_relationship("ENUM_ROUND_TRIP", pairs=[(EnumRoundTripNode, EnumRoundTripNode)]) |
| 440 | + class EnumRoundTripRelationship(KuzuRelationshipBase): |
| 441 | + stage: ExtendedStageEnum = kuzu_field(kuzu_type=KuzuDataType.INT8) |
| 442 | + |
| 443 | + db_path = tmp_path / "enum_round_trip.kuzu" |
| 444 | + session = KuzuSession(db_path=db_path) |
| 445 | + |
| 446 | + from kuzualchemy.kuzu_orm import get_ddl_for_node, get_ddl_for_relationship |
| 447 | + |
| 448 | + ddl = "\n".join( |
| 449 | + [ |
| 450 | + get_ddl_for_node(EnumRoundTripNode), |
| 451 | + get_ddl_for_relationship(EnumRoundTripRelationship), |
| 452 | + ] |
| 453 | + ) |
| 454 | + initialize_schema(session, ddl=ddl) |
| 455 | + |
| 456 | + source = EnumRoundTripNode(id=1, stage=CanonicalStageEnum.BASELINE) |
| 457 | + target = EnumRoundTripNode(id=2, stage=9) |
| 458 | + session.add(source) |
| 459 | + session.add(target) |
| 460 | + session.commit() |
| 461 | + |
| 462 | + source_db = session.query(EnumRoundTripNode).filter_by(id=1).first() |
| 463 | + target_db = session.query(EnumRoundTripNode).filter_by(id=2).first() |
| 464 | + |
| 465 | + assert source_db is not None |
| 466 | + assert target_db is not None |
| 467 | + assert source_db.stage is ExtendedStageEnum.BASELINE |
| 468 | + assert target_db.stage is ExtendedStageEnum.EXPERIMENTAL |
| 469 | + |
| 470 | + session.create_relationship( |
| 471 | + EnumRoundTripRelationship, |
| 472 | + source_db.id, |
| 473 | + target_db.id, |
| 474 | + stage=9, |
| 475 | + ) |
| 476 | + session.commit() |
| 477 | + |
| 478 | + relationships = session.query(EnumRoundTripRelationship).all() |
| 479 | + |
| 480 | + assert len(relationships) == 1 |
| 481 | + assert relationships[0].stage is ExtendedStageEnum.EXPERIMENTAL |
| 482 | + |
414 | 483 |
|
415 | 484 | # Additional tests for list/tuple of Enum conversions |
416 | 485 | from typing import List, Tuple |
|
0 commit comments