diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index ad21090..abec206 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -214,6 +214,7 @@ def schema( namespace: Optional[str] = None, names: Optional[NamesType] = None, options: Option = Option(0), + processing: set[type] | None = None, ) -> JSONType: """ Generate and return an Avro schema for a given Python type @@ -228,12 +229,30 @@ def schema( """ if names is None: names = [] - schema_obj = _schema_obj(py_type, namespace=namespace, options=options) + schema_obj = _schema_obj(py_type, namespace=namespace, options=options, processing=processing) schema_data = schema_obj.data(names=names) return schema_data -def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)) -> "Schema": +def _fullname_for_forward_ref(py_type: Type, namespace: Optional[str], options: Option) -> str: + """Computes the fully-qualified name to be used in a ForwardRef ot break cycles.""" + name = py_type.__name__ + if namespace is None and Option.NO_AUTO_NAMESPACE not in options: + module = inspect.getmodule(py_type) + if module and module.__name__ != "builtin": + if Option.AUTO_NAMESPACE_MODULE in options: + namespace = module.__name__ + else: + namespace = module.__name__.split(".", 1)[0] + return f"{namespace}.{name}" if namespace else name + + +def _schema_obj( + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, +) -> "Schema": """ Dispatch to relevant schema classes @@ -241,10 +260,15 @@ def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ + processing = processing or set() + # If py_type is currently being processed further up the stack, emit a ForwardRef to break the cycle + unwrapped = _type_from_annotated(py_type) + if unwrapped in processing and hasattr(unwrapped, "__name__"): + py_type = ForwardRef(_fullname_for_forward_ref(unwrapped, namespace, options)) # type: ignore # Find concrete Schema subclasses defined in the current module for schema_class in sorted(_SCHEMA_CLASSES, key=lambda c: getattr(c, "__py_avro_priority", 0)): # Find the first schema class that handles py_type - schema_obj = schema_class(py_type, namespace=namespace, options=options) # type: ignore + schema_obj = schema_class(py_type, namespace=namespace, options=options, processing=processing) # type: ignore if schema_obj: return schema_obj raise TypeNotSupportedError(f"Cannot generate Avro schema for Python type {py_type}") @@ -274,7 +298,13 @@ def validate_name(value: str) -> str: class Schema(abc.ABC): """Schema base""" - def __new__(cls, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __new__( + cls, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ Create an instance of this schema class if it handles py_type @@ -287,17 +317,25 @@ def __new__(cls, py_type: Type, namespace: Optional[str] = None, options: Option else: return None - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ A schema base :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ self.py_type = py_type self.options = options self._namespace = namespace # Namespace override + self.processing = processing or set() @property def namespace_override(self) -> Optional[str]: @@ -334,22 +372,27 @@ def make_default(self, py_default: Any) -> Any: """ return py_default - def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType: + def _wrap_as_record( + self, + names: NamesType, + build_inner: collections.abc.Callable[[NamesType], JSONObj], + ) -> JSONType: """ - Wrap a container schema (array or map) into an Avro record with ``__id`` and ``__data`` fields. - Handles deduplication via ``names``. + Wrap a container schema into an Avro record with ``__id`` and ``__data`` fields. The wrapper's + fullname is reserved in ``names`` before internal data is computed. This is to avoid a recursive inner type + to be expanded again (as the wrapper is). """ record_name = _avro_name_for_type(_type_from_annotated(self.py_type)) fullname = f"{self.namespace}.{record_name}" if self.namespace else record_name if fullname in names: return fullname names.append(fullname) - record_schema = { + record_schema: JSONObj = { "type": "record", "name": record_name, "fields": [ {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}, - {"name": REF_DATA_KEY, "type": inner_schema}, + {"name": REF_DATA_KEY, "type": build_inner(names)}, ], } if self.namespace: @@ -428,7 +471,13 @@ def data(self, names: NamesType) -> JSONObj: class LiteralSchema(Schema): """An Avro schema of any type for a Python Literal type, e.g. ``Literal[""]``""" - def __init__(self, py_type: Type[Any], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[Any], + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """ An Avro schema of any type for a Python Literal type, e.g. ``Literal[""]`` @@ -462,7 +511,13 @@ def data(self, names: NamesType) -> JSONType: class FinalSchema(Schema): """An Avro schema for Python ``typing.Final``""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """An Avro schema for Python ``typing.Final``""" super().__init__(py_type, namespace, options) py_type = _type_from_annotated(py_type) @@ -757,6 +812,7 @@ def __init__( py_type: Type[collections.abc.MutableSequence], namespace: Optional[str] = None, options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro array schema for a given Python sequence @@ -765,17 +821,19 @@ def __init__( :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) args = get_args(py_type) # TODO: validate if args has exactly 1 item? - self.items_schema = _schema_obj(args[0], namespace=namespace, options=options) + self.items_schema = _schema_obj(args[0], namespace=namespace, options=options, processing=self.processing) def data(self, names: NamesType) -> JSONType: """Return the schema data""" - array_schema = {"type": "array", "items": self.items_schema.data(names=names)} if Option.WRAP_INTO_RECORDS not in self.options: - return array_schema - return self._wrap_as_record(array_schema, names) + return {"type": "array", "items": self.items_schema.data(names=names)} + return self._wrap_as_record( + names, + lambda n: {"type": "array", "items": self.items_schema.data(names=n)}, + ) def make_default(self, py_default: collections.abc.Sequence) -> JSONType: """Return an Avro schema compliant default value for a given Python Sequence @@ -804,6 +862,7 @@ def __init__( py_type: type[collections.abc.MutableSet], namespace: str | None = None, options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro array schema for a given Python sequence @@ -812,7 +871,7 @@ def __init__( :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) # type: ignore + super().__init__(py_type, namespace=namespace, options=options, processing=processing) # type: ignore @register_schema @@ -833,6 +892,7 @@ def __init__( py_type: Type[collections.abc.MutableMapping], namespace: Optional[str] = None, options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro map schema for a given Python mapping @@ -840,20 +900,23 @@ def __init__( :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) args = get_args(py_type) if args[0] != str and not issubclass(args[0], StrEnum): raise TypeError(f"Cannot generate Avro mapping schema for Python dictionary {py_type} with non-string keys") - self.values_schema = _schema_obj(args[1], namespace=namespace, options=options) + self.values_schema = _schema_obj(args[1], namespace=namespace, options=options, processing=self.processing) def data(self, names: NamesType) -> JSONType: """Return the schema data""" - map_schema = {"type": "map", "values": self.values_schema.data(names=names)} if Option.WRAP_INTO_RECORDS not in self.options: - return map_schema - return self._wrap_as_record(map_schema, names) + return {"type": "map", "values": self.values_schema.data(names=names)} + return self._wrap_as_record( + names, + lambda n: {"type": "map", "values": self.values_schema.data(names=n)}, + ) def make_default(self, py_default: Any) -> JSONType: """Return an Avro schema compliant default value for a given Python value""" @@ -879,7 +942,13 @@ def handles_type(cls, py_type: Type) -> bool: return origin == Union or origin == union_type return origin == Union - def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[Union[Any]], + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro union schema for a given Python union type @@ -887,11 +956,13 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) args = get_args(py_type) self._validate_union(args) - self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args] + self.item_schemas = [ + _schema_obj(arg, namespace=namespace, options=options, processing=self.processing) for arg in args + ] @staticmethod def _validate_union(args: tuple[Any, ...]) -> None: @@ -976,7 +1047,13 @@ def make_default(self, py_default: Any) -> JSONType: class NamedSchema(Schema): """A named Avro schema base class""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ A named Avro schema base class @@ -984,7 +1061,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.name = py_type.__name__ @@ -1032,7 +1109,13 @@ def handles_type(cls, py_type: Type) -> bool: """Whether this schema class can represent a given Python class""" return _is_class(py_type, enum.Enum) - def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[enum.Enum], + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """ An Avro enum schema for a Python enum with string values @@ -1098,15 +1181,24 @@ def data_before_deduplication(self, names: NamesType) -> JSONObj: class RecordSchema(NamedSchema): """An Avro record schema base class""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema base class :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) + # Per each record we copy the set, to separete executions between siblings. + self.processing = self.processing | {py_type} self.record_fields: collections.abc.Sequence[RecordField] = [] def data_before_deduplication(self, names: NamesType) -> JSONObj: @@ -1142,6 +1234,7 @@ def __init__( default: Any = dataclasses.MISSING, docs: str = "", options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro record field @@ -1154,6 +1247,8 @@ def __init__( :param docs: Field documentation or description :param options: Schema generation options """ + if processing is None: + processing = set() if aliases is None: aliases = [] self.py_type = py_type @@ -1163,7 +1258,8 @@ def __init__( self.default = default self.docs = docs self.options = options - self.schema = _schema_obj(self.py_type, namespace=self._namespace, options=options) + + self.schema = _schema_obj(self.py_type, namespace=self._namespace, options=options, processing=processing) if self.default != dataclasses.MISSING: if isinstance(self.schema, UnionSchema): @@ -1214,7 +1310,13 @@ def handles_type(cls, py_type: Type) -> bool: py_type = _type_from_annotated(py_type) return dataclasses.is_dataclass(py_type) - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Python dataclass @@ -1222,7 +1324,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.py_fields = dataclasses.fields(py_type) self.record_fields = [self._record_field(field) for field in self.py_fields] @@ -1240,6 +1342,7 @@ def _record_field(self, py_field: dataclasses.Field) -> RecordField: default=default, aliases=aliases, options=self.options, + processing=self.processing, ) return field_obj @@ -1262,7 +1365,13 @@ def handles_type(cls, py_type: Type) -> bool: py_type = _type_from_annotated(py_type) return hasattr(py_type, "__pydantic_private__") - def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[pydantic.BaseModel], + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Pydantic model class @@ -1270,7 +1379,7 @@ def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] = :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) if Option.USE_CLASS_ALIAS in self.options: self.name = py_type.model_config.get("title") or self.name self.py_fields = py_type.model_fields @@ -1290,6 +1399,7 @@ def _record_field(self, name: str, py_field: pydantic.fields.FieldInfo) -> Recor aliases=aliases, docs=py_field.description or "", options=self.options, + processing=self.processing, ) return field_obj @@ -1336,10 +1446,16 @@ def handles_type(cls, py_type: Type) -> bool: # If we are subclassing a string, used the "named string" approach and (inspect.isclass(py_type) and not issubclass(py_type, str)) # and any other class with typed annotations - and bool(get_type_hints(py_type)) + and has_annotations(py_type) ) - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a plain Python class with type hints @@ -1347,7 +1463,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) # Try to get resolved type hints, but fall back to raw annotations if there are unresolved forward refs @@ -1372,6 +1488,7 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: default=default, aliases=aliases, options=self.options, + processing=self.processing, ) return field_obj @@ -1392,7 +1509,13 @@ def handles_type(cls, py_type: Type) -> bool: """Whether this schema can represent a TypedDict""" return is_typeddict(py_type) - def __init__(self, py_type: Type, namespace: str | None = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: str | None = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Python TypedDict @@ -1400,7 +1523,7 @@ def __init__(self, py_type: Type, namespace: str | None = None, options: Option :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.is_total = py_type.__dict__.get("__total__", True) self.py_fields: dict[str, Type] = get_type_hints(py_type, include_extras=True) @@ -1437,6 +1560,7 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: aliases=aliases, default=default, options=self.options, + processing=self.processing, ) return field_obj @@ -1505,7 +1629,11 @@ def is_logically_json(py_type: Type) -> bool: return _is_list_any(py_type) or _is_list_dict_str_any(py_type) or _is_dict_str_any(py_type) -def _is_class(py_type: Any, of_types: Union[Type, Tuple[Type, ...]], include_subclasses: bool = True) -> bool: +def _is_class( + py_type: Any, + of_types: Union[Type, Tuple[Type, ...]], + include_subclasses: bool = True, +) -> bool: """Return whether the given type is a (sub) class of a type or types""" py_type = _type_from_annotated(py_type) if include_subclasses: @@ -1530,6 +1658,16 @@ def _type_from_annotated(py_type: Type) -> Type: return py_type +def has_annotations(py_type: Type) -> bool: + """Checks if a type has annotations""" + py_type = _type_from_annotated(py_type) + try: + return bool(get_type_hints(py_type)) + except Exception: + pass + return hasattr(py_type, "__annotations__") + + def _avro_name_for_type(py_type: Type) -> str: """ Generate an Avro-compatible name for a given Python type. It is used when wrapping container types (mostly lists diff --git a/tests/test_plain_class.py b/tests/test_plain_class.py index b814547..20ca8e8 100644 --- a/tests/test_plain_class.py +++ b/tests/test_plain_class.py @@ -10,7 +10,7 @@ # specific language governing permissions and limitations under the License. import re -from typing import Annotated, Final +from typing import Annotated, Final, ForwardRef import pytest @@ -201,3 +201,30 @@ class PyType: ], } assert_schema(PyType, expected, options=Option.ADD_REFERENCE_ID) + + +class PyType: + backend: ForwardRef("Backend") + value: str + + +class Backend: + py_type: PyType + + +def test_circular_references(): + expected = { + "fields": [ + { + "name": "py_type", + "type": { + "fields": [{"name": "backend", "type": "Backend"}, {"name": "value", "type": "string"}], + "name": "PyType", + "type": "record", + }, + } + ], + "name": "Backend", + "type": "record", + } + assert_schema(Backend, expected) diff --git a/tests/test_typed_dict.py b/tests/test_typed_dict.py index 3ce9b73..8c78493 100644 --- a/tests/test_typed_dict.py +++ b/tests/test_typed_dict.py @@ -211,3 +211,204 @@ class PyType2(TypedDict): py_type = Union[PyType, PyType2] with pytest.raises(TypeError): py_avro_schema._schemas.schema(py_type) + + +class SiblingInner(TypedDict): + x: str + + +class SiblingOuter(TypedDict): + a: SiblingInner + b: list[SiblingInner] + + +def test_sibling_fields_references(): + """Check sibling attributes in a record won't all get a bare reference.""" + expected = { + "type": "record", + "name": "SiblingOuter", + "namespace": "test_typed_dict", + "fields": [ + { + "name": "a", + "type": { + "type": "record", + "name": "SiblingInner", + "namespace": "test_typed_dict", + "fields": [{"name": "x", "type": "string"}], + }, + }, + { + "name": "b", + "type": { + "type": "record", + "name": "TestTypedDictSiblingInnerList", + "namespace": "builtins", + "fields": [ + {"name": "__id", "type": ["null", "long"], "default": None}, + { + "name": "__data", + "type": { + "type": "array", + "items": "test_typed_dict.SiblingInner", + }, + }, + ], + }, + }, + ], + } + assert_schema( + SiblingOuter, + expected, + options=pas.Option.WRAP_INTO_RECORDS, + do_auto_namespace=True, + ) + + +ConfigurationList = list["Configuration"] + + +class Configuration(TypedDict): + Configurations: ConfigurationList | None + + +def test_recursive_reference(): + """Test simple recursive reference with no wrapped records.""" + + class PyType(TypedDict): + Configurations: ConfigurationList | None + + expected = { + "type": "record", + "name": "PyType", + "fields": [ + { + "name": "Configurations", + "type": [ + { + "type": "array", + "items": { + "type": "record", + "name": "Configuration", + "fields": [ + { + "name": "Configurations", + "type": [ + {"type": "array", "items": "Configuration"}, + "null", + ], + }, + ], + }, + }, + "null", + ], + }, + ], + } + assert_schema(PyType, expected) + + +def test_recursive_reference_with_wrap_into_records(): + """Checks that a self-referential record combined with ``WRAP_INTO_RECORDS`` must define the wrapper once.""" + + class PyType(TypedDict): + Configurations: ConfigurationList | None + + expected = { + "type": "record", + "name": "PyType", + "fields": [ + { + "name": "Configurations", + "type": [ + { + "type": "record", + "name": "TestTypedDictConfigurationList", + "fields": [ + {"name": "__id", "type": ["null", "long"], "default": None}, + { + "name": "__data", + "type": { + "type": "array", + "items": { + "type": "record", + "name": "Configuration", + "fields": [ + { + "name": "Configurations", + "type": [ + "TestTypedDictConfigurationList", + "null", + ], + }, + ], + }, + }, + }, + ], + }, + "null", + ], + }, + ], + } + assert_schema(PyType, expected, options=pas.Option.WRAP_INTO_RECORDS) + + +RecExpressions = list["RecExpression"] + + +class RecExpression(TypedDict, total=False): + Or: RecExpressions | None + And: RecExpressions | None + Not: "RecExpression | None" + + +def test_recursive_reference_with_wrap_into_records_and_namespaces(): + """Checks that with WRAP_INTO_RECORDS and AUTO_NAMESPACE_MODULE a self-recursive record is referenced by + its fully-qualified name from inside the list wrapper. + """ + expected = { + "type": "record", + "name": "RecExpression", + "namespace": "test_typed_dict", + "fields": [ + { + "name": "Or", + "type": [ + { + "type": "record", + "name": "TestTypedDictRecExpressionList", + "namespace": "builtins", + "fields": [ + {"name": "__id", "type": ["null", "long"], "default": None}, + { + "name": "__data", + "type": { + "type": "array", + "items": "test_typed_dict.RecExpression", + }, + }, + ], + }, + "null", + ], + }, + { + "name": "And", + "type": ["builtins.TestTypedDictRecExpressionList", "null"], + }, + { + "name": "Not", + "type": ["test_typed_dict.RecExpression", "null"], + }, + ], + } + assert_schema( + RecExpression, + expected, + options=pas.Option.WRAP_INTO_RECORDS, + do_auto_namespace=True, + )