diff --git a/src/kuzualchemy/kuzu_session.py b/src/kuzualchemy/kuzu_session.py index 3caf91b..32c0ccd 100755 --- a/src/kuzualchemy/kuzu_session.py +++ b/src/kuzualchemy/kuzu_session.py @@ -11,7 +11,7 @@ """ from __future__ import annotations -from typing import Any, Dict, List, Optional, Type, TypeVar, Union, Iterator, cast +from typing import Any, Dict, List, Optional, Type, TypeVar, Union, Iterator, Sequence, cast from contextlib import contextmanager from threading import RLock from pathlib import Path @@ -33,11 +33,18 @@ KuzuDataType, LoggingConstants, DatabaseConstants, + KuzuDefaultFunction, ) from .connection_pool import get_shared_connection_pool from .constants import PerformanceConstants from .constants import DDLConstants -from .kuzu_orm import get_node_by_name, KuzuRelationshipBase, get_registered_nodes +from .kuzu_orm import ( + get_node_by_name, + KuzuRelationshipBase, + get_registered_nodes, + KuzuFieldMetadata, + ArrayTypeSpecification, +) # TODO: TESTS ARE FALING SINCE I MADE THE BATCH SIZES DYNAMIC ACCORDING TO KUZU, EVEN THOUGH MY IMPLEM ENTATION FIXED A SHIT LOAD OF ISSUES IN OTHER PRODUCTION PROJECTS. FIX THE ISSUES RIGHT NOW HERE AND INT HE CONNECTION POOL OR WHEREEVER. @@ -682,6 +689,287 @@ def is_buffer_exhaustion_error(err: BaseException) -> bool: gc.collect() # Do not advance 'start'; retry with smaller batch + def _maybe_generate_default_function(self, value: Any) -> Any: + """Generate values for default function markers when present.""" + if isinstance(value, KuzuDefaultFunction): + return self._generate_default_function_value(value) + if hasattr(value, '__class__') and 'KuzuDefaultFunction' in str(value.__class__): + return self._generate_default_function_value(value) + return value + + def _normalize_value_for_type( + self, + value: Any, + kuzu_type: Union[KuzuDataType, ArrayTypeSpecification, str], + ) -> Any: + """Normalize Python values according to the declared Kùzu type.""" + if value is None: + return None + + value = self._maybe_generate_default_function(value) + + # Array/list types recurse on the element type + if isinstance(kuzu_type, ArrayTypeSpecification): + if value is None: + return None + if isinstance(value, (list, tuple, set)): + iterable = value + else: + iterable = [value] + return [ + self._normalize_value_for_type(elem, kuzu_type.element_type) + for elem in iterable + ] + + # String-based metadata: attempt to resolve to canonical enum + if isinstance(kuzu_type, str): + try: + canonical = KuzuDataType[kuzu_type.upper()] + except KeyError: + canonical = None + if canonical is not None: + return self._normalize_value_for_type(value, canonical) + # Fallback to generic handling for unknown strings + if isinstance(value, Enum): + return getattr(value, 'value', value) + if isinstance(value, uuid.UUID): + return str(value) + if hasattr(value, 'isoformat') and not isinstance(value, (str, bytes)): + return value.isoformat() + return value + + if isinstance(value, Enum): + value = getattr(value, 'value', value) + + if isinstance(kuzu_type, KuzuDataType): + if kuzu_type == KuzuDataType.UUID: + if isinstance(value, uuid.UUID): + return str(value) + return str(value) + + if kuzu_type in {KuzuDataType.STRING}: + if isinstance(value, str): + return value + return str(value) + + if kuzu_type in {KuzuDataType.BOOL, KuzuDataType.BOOLEAN}: + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "t", "1", "yes", "y"}: + return True + if normalized in {"false", "f", "0", "no", "n"}: + return False + return bool(value) + + numeric_types = { + KuzuDataType.INT8, + KuzuDataType.INT16, + KuzuDataType.INT32, + KuzuDataType.INT64, + KuzuDataType.UINT8, + KuzuDataType.UINT16, + KuzuDataType.UINT32, + KuzuDataType.UINT64, + KuzuDataType.SERIAL, + } + + if kuzu_type in numeric_types: + return value + + if kuzu_type in {KuzuDataType.FLOAT, KuzuDataType.DOUBLE}: + return value + + if kuzu_type in {KuzuDataType.INT128, KuzuDataType.DECIMAL}: + return value + + if kuzu_type in { + KuzuDataType.DATE, + KuzuDataType.TIMESTAMP, + KuzuDataType.TIMESTAMP_NS, + KuzuDataType.TIMESTAMP_MS, + KuzuDataType.TIMESTAMP_SEC, + KuzuDataType.TIMESTAMP_TZ, + KuzuDataType.INTERVAL, + }: + if isinstance(value, str): + return value + if hasattr(value, 'isoformat'): + return value.isoformat() + return str(value) + + if kuzu_type == KuzuDataType.BLOB: + if isinstance(value, (bytes, bytearray, memoryview)): + return bytes(value) + if isinstance(value, str): + return value.encode('utf-8') + return bytes(value) + + # Default fallback for other types + if isinstance(value, uuid.UUID): + return str(value) + if hasattr(value, 'isoformat') and not isinstance(value, (str, bytes)): + return value.isoformat() + return value + + # Unknown metadata type fallback + if isinstance(value, uuid.UUID): + return str(value) + if isinstance(value, Enum): + return getattr(value, 'value', value) + if hasattr(value, 'isoformat') and not isinstance(value, (str, bytes)): + return value.isoformat() + return value + + def _normalize_bulk_value(self, value: Any, field_meta: Optional[KuzuFieldMetadata]) -> Any: + """Normalize a bulk insert value using available metadata.""" + if value is None: + return None + + if field_meta is not None: + return self._normalize_value_for_type(value, field_meta.kuzu_type) + + value = self._maybe_generate_default_function(value) + + if isinstance(value, Enum): + value = getattr(value, 'value', value) + + if isinstance(value, uuid.UUID): + return str(value) + + if isinstance(value, (list, tuple, set)): + return [self._normalize_bulk_value(v, None) for v in value] + + if hasattr(value, 'isoformat') and not isinstance(value, (str, bytes)): + return value.isoformat() + + return value + + def _get_pyarrow_type_for_field(self, field_meta: Optional[KuzuFieldMetadata]) -> Optional[pa.DataType]: + """Return the PyArrow type for the provided field metadata, if known.""" + if field_meta is None: + return None + return self._map_kuzu_type_to_arrow(field_meta.kuzu_type) + + def _map_kuzu_type_to_arrow( + self, kuzu_type: Union[KuzuDataType, ArrayTypeSpecification, str] + ) -> Optional[pa.DataType]: + """Map a Kùzu data type to the corresponding PyArrow type.""" + if isinstance(kuzu_type, ArrayTypeSpecification): + element_type = self._map_kuzu_type_to_arrow(kuzu_type.element_type) + if element_type is None: + return None + return pa.list_(element_type) + + if isinstance(kuzu_type, KuzuDataType): + if kuzu_type == KuzuDataType.INT8: + return pa.int8() + if kuzu_type == KuzuDataType.INT16: + return pa.int16() + if kuzu_type == KuzuDataType.INT32: + return pa.int32() + if kuzu_type == KuzuDataType.INT64: + return pa.int64() + if kuzu_type == KuzuDataType.UINT8: + return pa.uint8() + if kuzu_type == KuzuDataType.UINT16: + return pa.uint16() + if kuzu_type == KuzuDataType.UINT32: + return pa.uint32() + if kuzu_type == KuzuDataType.UINT64: + return pa.uint64() + if kuzu_type in {KuzuDataType.INT128, KuzuDataType.DECIMAL}: + return pa.decimal128(38, 0) + if kuzu_type == KuzuDataType.FLOAT: + return pa.float32() + if kuzu_type == KuzuDataType.DOUBLE: + return pa.float64() + if kuzu_type == KuzuDataType.SERIAL: + return pa.int64() + if kuzu_type in {KuzuDataType.STRING, KuzuDataType.UUID}: + return pa.string() + if kuzu_type in { + KuzuDataType.DATE, + KuzuDataType.TIMESTAMP, + KuzuDataType.TIMESTAMP_NS, + KuzuDataType.TIMESTAMP_MS, + KuzuDataType.TIMESTAMP_SEC, + KuzuDataType.TIMESTAMP_TZ, + KuzuDataType.INTERVAL, + }: + return pa.string() + if kuzu_type in {KuzuDataType.BOOL, KuzuDataType.BOOLEAN}: + return pa.bool_() + if kuzu_type == KuzuDataType.BLOB: + return pa.binary() + + if isinstance(kuzu_type, str): + try: + canonical = KuzuDataType[kuzu_type.upper()] + except KeyError: + return None + return self._map_kuzu_type_to_arrow(canonical) + + return None + + def _prepare_arrow_array( + self, + field_name: str, + values: Sequence[Any], + field_meta: Optional[KuzuFieldMetadata], + ) -> pa.Array: + """Create a PyArrow array for the given values with graceful fallback.""" + arrow_type = self._get_pyarrow_type_for_field(field_meta) + try: + if arrow_type is not None: + return pa.array(values, type=arrow_type) + return pa.array(values) + except (pa.ArrowInvalid, pa.ArrowTypeError): + coerced: List[Any] = [] + for item in values: + if item is None: + coerced.append(None) + elif isinstance(item, uuid.UUID): + coerced.append(str(item)) + elif isinstance(item, list): + coerced.append([ + str(elem) if isinstance(elem, uuid.UUID) else elem for elem in item + ]) + else: + coerced.append(str(item)) + return pa.array(coerced, type=pa.string()) + + def _build_arrow_table( + self, + data_dict: Dict[str, List[Any]], + schema_fields: List[pa.Field], + ordered_field_names: Sequence[str], + field_metadata: Dict[str, KuzuFieldMetadata], + ) -> pa.Table: + """Construct a PyArrow table honoring explicit schemas and field order.""" + arrays: List[pa.Array] = [] + names: List[str] = [] + schema_field_names = {field.name for field in schema_fields} + + for field in schema_fields: + try: + arrays.append(pa.array(data_dict[field.name], type=field.type)) + except (pa.ArrowInvalid, pa.ArrowTypeError): + arrays.append( + self._prepare_arrow_array(field.name, data_dict[field.name], field_metadata.get(field.name)) + ) + names.append(field.name) + + for field_name in ordered_field_names: + if field_name in schema_field_names: + continue + values = data_dict.get(field_name, []) + arrays.append(self._prepare_arrow_array(field_name, values, field_metadata.get(field_name))) + names.append(field_name) + + return pa.table(arrays, names=names) + def _process_batch_with_pyarrow(self, model_class: Type[Any], instances: List[Any]) -> None: """ Process a batch of instances using PyArrow table with explicit UUID schema. @@ -690,8 +978,9 @@ def _process_batch_with_pyarrow(self, model_class: Type[Any], instances: List[An return # Convert instances to dictionary format - data_dict = {} - schema_fields = [] + data_dict: Dict[str, List[Any]] = {} + schema_fields: List[pa.Field] = [] + field_metadata = model_class.get_all_kuzu_metadata() from .kuzu_orm import KuzuRelationshipBase is_relationship = issubclass(model_class, KuzuRelationshipBase) @@ -712,9 +1001,6 @@ def _process_batch_with_pyarrow(self, model_class: Type[Any], instances: List[An } field_names = [f for f in all_field_names if f not in internal_fields] - # Cached metadata for type conversions - auto_increment_metadata = model_class.get_auto_increment_metadata() - # Helper to resolve node label from node instance/reference def _resolve_node_label(node_obj: Any) -> str: # Prefer explicit kuzu node name on class or instance @@ -767,9 +1053,10 @@ def _resolve_node_label(node_obj: Any) -> str: # Initialize property columns for fname in field_names: g_data[fname] = [] - fmeta = auto_increment_metadata.get(fname) - if fmeta and fmeta.kuzu_type == KuzuDataType.UUID: - g_schema.append(pa.field(fname, pa.string())) + fmeta = field_metadata.get(fname) + arrow_type = self._get_pyarrow_type_for_field(fmeta) + if arrow_type is not None: + g_schema.append(pa.field(fname, arrow_type)) # Fill rows for inst in group_instances: @@ -780,58 +1067,12 @@ def _resolve_node_label(node_obj: Any) -> str: inst_dict = inst.model_dump() for fname in field_names: - val = inst_dict.get(fname) - fmeta = auto_increment_metadata.get(fname) - if fmeta and fmeta.kuzu_type == KuzuDataType.UUID and isinstance(val, uuid.UUID): - val = str(val) - elif isinstance(val, uuid.UUID): - val = str(val) - elif hasattr(val, 'isoformat'): - val = val.isoformat() - elif (hasattr(val, '__class__') and 'KuzuDefaultFunction' in str(val.__class__)): - val = self._generate_default_function_value(val) - g_data[fname].append(val) - - # Build PyArrow table - arrays = [] - names = [] - schema_names = {f.name for f in g_schema} - - # from_node_pk/to_node_pk remain strings - for f in g_schema: - arrays.append(pa.array(g_data[f.name], type=f.type)) - names.append(f.name) - - # For property fields without explicit schema, coerce to string arrays. - # This avoids the prepared-statement INT128 binder path for numeric types - # while keeping the COPY in-memory via $dataframe. - for fname in field_names: - if fname not in schema_names: - col = g_data[fname] - conv: list[str | list[str] | None] = [] - for v in col: - if isinstance(v, uuid.UUID): - conv.append(str(v)) - elif isinstance(v, list) and v and isinstance(v[0], uuid.UUID): - conv.append([str(u) for u in v]) - else: - # Coerce numerics, enums, bools to strings; leave None as-is - if v is None: - conv.append(None) - else: - try: - # Enum -> underlying value, then to str - from enum import Enum as _E - if isinstance(v, _E): - conv.append(str(getattr(v, 'value', v))) - else: - conv.append(str(v)) - except Exception: - conv.append(str(v)) - arrays.append(pa.array(conv, type=pa.string())) - names.append(fname) - - g_df = pa.table(arrays, names=names) + raw_value = inst_dict.get(fname) + normalized = self._normalize_bulk_value(raw_value, field_metadata.get(fname)) + g_data[fname].append(normalized) + + ordered_fields = ['from_node_pk', 'to_node_pk'] + field_names + g_df = self._build_arrow_table(g_data, g_schema, ordered_fields, field_metadata) try: self._execute_with_connection_reuse( @@ -867,15 +1108,11 @@ def _resolve_node_label(node_obj: Any) -> str: ]) # Initialize property columns with proper types - auto_increment_metadata = model_class.get_auto_increment_metadata() for field_name in field_names: data_dict[field_name] = [] - - # Determine PyArrow type for this field - field_meta = auto_increment_metadata.get(field_name) - if field_meta and field_meta.kuzu_type == KuzuDataType.UUID: - # Use string type for UUID fields to avoid BLOB conversion issues - schema_fields.append(pa.field(field_name, pa.string())) + arrow_type = self._get_pyarrow_type_for_field(field_metadata.get(field_name)) + if arrow_type is not None: + schema_fields.append(pa.field(field_name, arrow_type)) # Extract data from relationship instances for instance in instances: @@ -889,19 +1126,9 @@ def _resolve_node_label(node_obj: Any) -> str: instance_data = instance.model_dump() for field_name in field_names: - value = instance_data.get(field_name) - - # Convert based on field type - field_meta = auto_increment_metadata.get(field_name) - if field_meta and field_meta.kuzu_type == KuzuDataType.UUID and isinstance(value, uuid.UUID): - # Convert UUID objects to strings for KuzuDB UUID parsing - value = str(value) - elif hasattr(value, 'isoformat'): - value = value.isoformat() - elif (hasattr(value, '__class__') and 'KuzuDefaultFunction' in str(value.__class__)): - value = self._generate_default_function_value(value) - - data_dict[field_name].append(value) + raw_value = instance_data.get(field_name) + normalized = self._normalize_bulk_value(raw_value, field_metadata.get(field_name)) + data_dict[field_name].append(normalized) else: # Handle nodes @@ -919,82 +1146,26 @@ def _resolve_node_label(node_obj: Any) -> str: field_names.append(field_name) # Initialize columns with proper schema - auto_increment_metadata = model_class.get_auto_increment_metadata() for field_name in field_names: data_dict[field_name] = [] - - # Determine PyArrow type for this field - field_meta = auto_increment_metadata.get(field_name) - if field_meta and field_meta.kuzu_type == KuzuDataType.UUID: - # Use string type for UUID fields to avoid BLOB conversion issues - schema_fields.append(pa.field(field_name, pa.string())) - # Note: For other types, we'll let PyArrow infer from the data later + arrow_type = self._get_pyarrow_type_for_field(field_metadata.get(field_name)) + if arrow_type is not None: + schema_fields.append(pa.field(field_name, arrow_type)) # Extract data from instances for instance in instances: instance_data = instance.model_dump() for field_name in field_names: - value = instance_data.get(field_name) - - # Convert based on field type - field_meta = auto_increment_metadata.get(field_name) - if field_meta and field_meta.kuzu_type == KuzuDataType.UUID and isinstance(value, uuid.UUID): - # Convert UUID objects to strings for KuzuDB UUID parsing - value = str(value) - elif isinstance(value, uuid.UUID): - # Convert any UUID objects to strings for KuzuDB UUID parsing - value = str(value) - elif hasattr(value, 'isoformat'): - value = value.isoformat() - elif (hasattr(value, '__class__') and 'KuzuDefaultFunction' in str(value.__class__)): - value = self._generate_default_function_value(value) - - data_dict[field_name].append(value) - - # Create PyArrow table with mixed explicit and inferred schema - if schema_fields: - # Build arrays with proper types for explicit fields, let PyArrow infer others - arrays = [] - field_names_ordered = [] - schema_field_names = {f.name for f in schema_fields} - - # First, add arrays for fields with explicit types - for field in schema_fields: - arrays.append(pa.array(data_dict[field.name], type=field.type)) - field_names_ordered.append(field.name) - - # Then, add arrays for fields without explicit types (let PyArrow infer) - for field_name in field_names: - if field_name not in schema_field_names: - # Convert any remaining UUID objects to strings before PyArrow inference - field_data = data_dict[field_name] - converted_data = [] - for value in field_data: - if isinstance(value, uuid.UUID): - converted_data.append(str(value)) - elif isinstance(value, list) and value and isinstance(value[0], uuid.UUID): - converted_data.append([str(uuid_val) for uuid_val in value]) - else: - converted_data.append(value) - arrays.append(pa.array(converted_data)) - field_names_ordered.append(field_name) + raw_value = instance_data.get(field_name) + normalized = self._normalize_bulk_value(raw_value, field_metadata.get(field_name)) + data_dict[field_name].append(normalized) - # Create table with field names only, let PyArrow infer types for non-explicit fields - df = pa.table(arrays, names=field_names_ordered) + if is_relationship: + ordered_fields = ['from_node_pk', 'to_node_pk'] + field_names else: - # Fallback to full inference - convert all UUID objects to strings first - converted_dict = {} - for field_name, field_data in data_dict.items(): - converted_data = [] - for value in field_data: - if isinstance(value, uuid.UUID): - converted_data.append(str(value)) - elif isinstance(value, list) and value and isinstance(value[0], uuid.UUID): - converted_data.append([str(uuid_val) for uuid_val in value]) - else: - converted_data.append(value) - converted_dict[field_name] = converted_data - df = pa.table(converted_dict) + ordered_fields = field_names + + df = self._build_arrow_table(data_dict, schema_fields, ordered_fields, field_metadata) data_dict.clear() del data_dict