From 5ca1d260d3871399d9b88134162b71c0a234ba11 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Tue, 26 Nov 2019 15:33:54 +0000 Subject: [PATCH 01/12] Make fields type hint compatible --- jsonmodels/errors.py | 8 +-- jsonmodels/fields.py | 115 ++++++++++++++++++++++++---------------- jsonmodels/utilities.py | 4 +- 3 files changed, 75 insertions(+), 52 deletions(-) diff --git a/jsonmodels/errors.py b/jsonmodels/errors.py index 200584b..bcd8d3e 100644 --- a/jsonmodels/errors.py +++ b/jsonmodels/errors.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Type +from typing import Any, List, Tuple, Type class ValidationError(RuntimeError): @@ -38,7 +38,7 @@ class FieldValidationError(ValidationError): Enriches a validator error with the name of the field that caused it. """ def __init__(self, model_name: str, field_name: str, - given_value: any, error: ValidatorError): + given_value: Any, error: ValidatorError): """ :param model_name: The name of the model. :param field_name: The name of the field. @@ -78,7 +78,7 @@ class BadTypeError(ValidatorError): expected one """ - def __init__(self, value: any, types: Tuple, is_list: bool): + def __init__(self, value: Any, types: Tuple, is_list: bool): """ :param value: The given value. :param types: The accepted types. @@ -186,7 +186,7 @@ def __init__(self, value, maximum_value, exclusive: bool): class EnumError(ValidatorError): """ Error raised by the Enum validator """ - def __init__(self, value: any, choices: List[any]): + def __init__(self, value: Any, choices: List[Any]): """ :param value: The given value. :param choices: The allowed choices. diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 169830f..31f327f 100755 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -5,11 +5,15 @@ import re import six from dateutil.parser import parse -from typing import List, Optional, Dict, Set, Union, Pattern +from typing import Any, List, Generic, Optional, Dict, Sequence, Set, Tuple, TypeVar, Union, Pattern from .collections import ModelCollection from .errors import RequiredFieldError, BadTypeError, AmbiguousTypeError +MYPY = False +if MYPY: + from .models import Base + # unique marker for "no default value specified". None is not good enough since # it is a completely valid default value. NotSet = object() @@ -21,18 +25,24 @@ ] -class BaseField(object): +T = TypeVar("T") + + +class BaseField(Generic[T]): """Base class for all fields.""" - types = None + types: Tuple[Any, ...] = () + + validators: List[Any] = [] + memory: WeakKeyDictionary def __init__( self, required=False, nullable=False, help_text=None, - validators=None, + validators: Optional[List[Any]]=None, default=NotSet, name=None): self.memory = WeakKeyDictionary() @@ -47,31 +57,31 @@ def __init__( self._default = default @property - def has_default(self): + def has_default(self) -> bool: return self._default is not NotSet - def _assign_validators(self, validators): + def _assign_validators(self, validators) -> None: if validators and not isinstance(validators, list): validators = [validators] self.validators = validators or [] - def __set__(self, instance, value): + def __set__(self, instance: "Base", value: Optional[T]) -> None: self._finish_initialization(type(instance)) value = self.parse_value(value) self.validate(value) self.memory[instance._cache_key] = value - def __get__(self, instance, owner=None): + def __get__(self, instance: "Base", owner=None) -> T: if instance is None: self._finish_initialization(owner) - return self + return self # type: ignore self._finish_initialization(type(instance)) self._check_value(instance) return self.memory[instance._cache_key] - def _finish_initialization(self, owner): + def _finish_initialization(self, owner) -> None: pass def _check_value(self, obj): @@ -82,21 +92,21 @@ def validate_for_object(self, obj): value = self.__get__(obj) self.validate(value) - def validate(self, value): + def validate(self, value: Optional[T]) -> None: self._check_types() self._validate_against_types(value) self._check_against_required(value) self._validate_with_custom_validators(value) - def _check_against_required(self, value): + def _check_against_required(self, value) -> None: if value is None and self.required: raise RequiredFieldError() - def _validate_against_types(self, value): + def _validate_against_types(self, value) -> None: if value is not None and not isinstance(value, self.types): raise BadTypeError(value, self.types, is_list=False) - def _check_types(self): + def _check_types(self) -> None: if self.types is None: tpl = 'Field "{type}" is not usable, try different field type.' raise ValueError(tpl.format(type=type(self).__name__)) @@ -131,7 +141,7 @@ def _get_embed_type(value, models): return matching_models[0] return models[0] - def toBsonEncodable(self, value: types) -> BsonEncodable: + def toBsonEncodable(self, value) -> BsonEncodable: """Optionally return a bson encodable python object. Returned object should be BSON compatible. By default uses the @@ -153,7 +163,7 @@ def to_struct(self, value): """Cast value to Python dict.""" return value - def parse_value(self, value): + def parse_value(self, value: Optional[Any]) -> Optional[T]: """Parse value from primitive to desired format. Each field can parse value to form it wants it to be (like string or @@ -195,18 +205,18 @@ def structue_name(self, default): return self.structure_name(default) -class StringField(BaseField): +class StringField(BaseField[str]): """String field.""" - types = six.string_types + types: Tuple[Any, ...] = six.string_types -class IntField(BaseField): +class IntField(BaseField[int]): """Integer field.""" - types = (int,) + types: Tuple[Any, ...] = (int,) def parse_value(self, value): """Cast value to `int`, e.g. from string or long""" @@ -219,18 +229,18 @@ def parse_value(self, value): raise BadTypeError(value, types=(int,), is_list=False) -class FloatField(BaseField): +class FloatField(BaseField[float]): """Float field.""" - types = (float, int) + types: Tuple[Any, ...] = (float, int) -class BoolField(BaseField): +class BoolField(BaseField[bool]): """Bool field.""" - types = (bool,) + types: Tuple[Any, ...] = (bool,) def parse_value(self, value): """Cast value to `bool`.""" @@ -238,13 +248,18 @@ def parse_value(self, value): return bool(parsed) if parsed is not None else None -class ListField(BaseField): +I = TypeVar("I") + + +class ListField(BaseField[List[I]]): """List field.""" - types = (list, tuple) + types: Tuple[Any, ...] = (list, tuple) + items_types: Tuple[Any, ...] + item_validators: List[Any] - def __init__(self, items_types=None, item_validators=(), omit_empty=False, + def __init__(self, items_types: Optional[List[Any]]=None, item_validators: Union[Any, List[Any]]=[], omit_empty=False, *args, **kwargs): """Init. @@ -359,13 +374,15 @@ def __init__(self, field: BaseField, *args, **kwargs): :param validators: The validators for the list field. """ self._field = field + + fixed_kwargs = kwargs.copy() + fixed_kwargs["items_types"] = field.types + fixed_kwargs["item_validators"] = field.validators super(DerivedListField, self).__init__( - items_types=field.types, - item_validators=field.validators, - *args, **kwargs, + *args, **fixed_kwargs, ) - def to_struct(self, values: List[any]) -> List[any]: + def to_struct(self, values: List[Any]) -> Optional[List[Any]]: """ Converts the list to its output format. :param values: The values in the list. @@ -374,18 +391,22 @@ def to_struct(self, values: List[any]) -> List[any]: return [self._field.to_struct(value) for value in values] \ if values or not self._omit_empty else None - def parse_value(self, values: List[any]) -> List[any]: + def parse_value(self, values: Optional[Any]) -> Optional[List[Any]]: """ Converts the list to its internal format. :param values: The values in the list. :return: The converted values. """ + if values is None: + return None + try: return [self._field.parse_value(value) for value in values] except TypeError: raise BadTypeError(values, self._field.types, is_list=True) + return None - def validate_single_value(self, value: any) -> None: + def validate_single_value(self, value: Any) -> None: """ Validates a single value in the list. :param value: One of the values in the list. @@ -443,7 +464,7 @@ def to_struct(self, value): return value.to_struct() -class MapField(BaseField): +class MapField(BaseField[Dict[Any, Any]]): """ Model field that keeps a mapping between two other fields. It is basically a dictionary with key and values being separate fields. @@ -453,7 +474,7 @@ class MapField(BaseField): included in the to_struct method. """ - types = (dict,) + types: Tuple[Any, ...] = (dict,) def __init__(self, key_field: BaseField, value_field: BaseField, **kwargs): @@ -476,18 +497,18 @@ def _finish_initialization(self, owner): self._key_field._finish_initialization(owner) self._value_field._finish_initialization(owner) - def get_default_value(self) -> any: + def get_default_value(self) -> Any: """ Gets the default value for this field """ default = super(MapField, self).get_default_value() if default is None and self.required: return dict() return default - def parse_value(self, values: Optional[dict]) -> Optional[dict]: + def parse_value(self, values: Optional[Any]) -> Optional[Dict[Any, Any]]: """ Parses the given values into a new dict. """ values = super().parse_value(values) if values is None: - return + return None items = [ (self._key_field.parse_value(key), self._value_field.parse_value(value)) @@ -495,7 +516,7 @@ def parse_value(self, values: Optional[dict]) -> Optional[dict]: ] return type(values)(items) # Preserves OrderedDict - def to_struct(self, values: Optional[dict]) -> Optional[dict]: + def to_struct(self, values: Dict[Any, Any]) -> Dict[Any, Any]: """ Casts the field values into a dict. """ items = [ (self._key_field.to_struct(key), @@ -504,7 +525,7 @@ def to_struct(self, values: Optional[dict]) -> Optional[dict]: ] return type(values)(items) # Preserves OrderedDict - def validate(self, values: Optional[dict]) -> Optional[dict]: + def validate(self, values: Optional[Dict[Any, Any]]) -> None: """ Validates all keys and values in the map field. :param values: The values in the mapping. @@ -567,7 +588,7 @@ class TimeField(StringField): """Time field.""" - types = (datetime.time,) + types: Tuple[Any, ...] = (datetime.time,) def __init__(self, str_format=None, *args, **kwargs): """Init. @@ -598,7 +619,7 @@ class DateField(StringField): """Date field.""" - types = (datetime.date,) + types: Tuple[Any, ...] = (datetime.date,) default_format = '%Y-%m-%d' def __init__(self, str_format=None, *args, **kwargs): @@ -630,7 +651,7 @@ class DateTimeField(StringField): """Datetime field.""" - types = (datetime.datetime,) + types: Tuple[Any, ...] = (datetime.datetime,) def __init__(self, str_format=None, *args, **kwargs): """Init. @@ -648,7 +669,7 @@ def to_struct(self, value): return value.strftime(self.str_format) return value.isoformat() - def toBsonEncodable(self, value: datetime) -> datetime: + def toBsonEncodable(self, value: datetime.datetime) -> datetime.datetime: """ Keep datetime object a datetime object, since pymongo supports that. """ @@ -666,17 +687,17 @@ def parse_value(self, value): return None -class GenericField(BaseField): +class GenericField(BaseField[Any]): """ Field that supports any kind of value, converting models to their correct struct, keeping ordered dictionaries in their original order. """ - types = (any,) + types: Tuple[Any, ...] = (any,) def _validate_against_types(self, value) -> None: pass - def to_struct(self, values: any) -> any: + def to_struct(self, values: Any) -> Any: """ Casts value to Python structure. """ from .models import Base if isinstance(values, Base): diff --git a/jsonmodels/utilities.py b/jsonmodels/utilities.py index 5e864ed..6c9102b 100644 --- a/jsonmodels/utilities.py +++ b/jsonmodels/utilities.py @@ -5,8 +5,10 @@ import six import re from collections import namedtuple +from typing import cast, Any, List, Tuple -SCALAR_TYPES = tuple(list(six.string_types) + [int, float, bool]) +six_string_types: List[Any] = list(six.string_types) +SCALAR_TYPES = cast(Tuple[Any], tuple(six_string_types + [int, float, bool])) ECMA_TO_PYTHON_FLAGS = { 'i': re.I, From 65b4ae32939fa10583140e56647f800946dc8f08 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Thu, 16 Apr 2020 08:56:13 +0000 Subject: [PATCH 02/12] Add script to run Mypy over the code and some tests --- requirements.txt | 1 + run_mypy.sh | 4 ++++ tests/__init__.py | 4 +++- tests/test_fields.py | 27 ++++++++++++++------------- 4 files changed, 22 insertions(+), 14 deletions(-) create mode 100755 run_mypy.sh diff --git a/requirements.txt b/requirements.txt index 9925baf..d455999 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,6 @@ pytest pytest-cov sphinxcontrib-spelling tox +typing virtualenv wheel diff --git a/run_mypy.sh b/run_mypy.sh new file mode 100755 index 0000000..da90413 --- /dev/null +++ b/run_mypy.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +mypy -p jsonmodels +mypy tests/test_fields.py diff --git a/tests/__init__.py b/tests/__init__.py index 14df325..398ac48 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,8 @@ def _have_flake8(): try: - import flake8 # noqa: F401 + MYPY = False + if not MYPY: + import flake8 # noqa: F401 return True except ImportError: return False diff --git a/tests/test_fields.py b/tests/test_fields.py index dc740d0..c5bdc02 100755 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,8 +1,9 @@ from collections import OrderedDict import datetime from datetime import timezone +from typing import Dict, Union -import pytest +import pytest # type: ignore from jsonmodels import models, fields, validators, errors @@ -14,7 +15,7 @@ def test_deprecated_structue_name(): assert field.structue_name('default') == 'default' -def test_bool_field(): +def test_bool_field() -> None: field = fields.BoolField() @@ -43,7 +44,7 @@ class Person(models.Base): assert field.parse_value([]) is False -def test_datetime_field(): +def test_datetime_field() -> None: field = fields.DateTimeField() class Event(models.Base): @@ -59,7 +60,7 @@ class Event(models.Base): datetime.datetime(2019, 10, 30, 1, 2, 3, tzinfo=timezone.utc) -def test_custom_field(): +def test_custom_field() -> None: class NameField(fields.StringField): def __init__(self): super(NameField, self).__init__(required=True) @@ -75,7 +76,7 @@ class Person(models.Base): assert person.to_struct() == expected -def test_custom_field_validation(): +def test_custom_field_validation() -> None: class NameField(fields.StringField): def __init__(self): super(NameField, self).__init__( @@ -102,7 +103,7 @@ class Person(models.Base): person.validate() -def test_map_field(): +def test_map_field() -> None: class Model(models.Base): str_to_int = fields.MapField(fields.StringField(), fields.IntField()) int_to_str = fields.MapField(fields.IntField(), fields.StringField()) @@ -131,13 +132,13 @@ class CircularMapModel(models.Base): ) -def test_map_field_circular(): +def test_map_field_circular() -> None: model = CircularMapModel(mapping={1: {}, 2: CircularMapModel()}) - expected = {'mapping': {1: {}, 2: {}}} + expected: Dict[str, Dict[int, Dict]] = {'mapping': {1: {}, 2: {}}} assert expected == model.to_struct() -def test_map_field_validation(): +def test_map_field_validation() -> None: class Model(models.Base): str_to_int = fields.MapField(fields.StringField(), fields.IntField()) int_to_str = fields.MapField(fields.IntField(), fields.StringField(), @@ -162,7 +163,7 @@ class Model(models.Base): model.validate() -def test_generic_field(): +def test_generic_field() -> None: class Model(models.Base): field = fields.GenericField() @@ -178,7 +179,7 @@ class Model(models.Base): assert expected == model_ordered.to_struct() -def test_derived_list_omit_empty(): +def test_derived_list_omit_empty() -> None: class Car(models.Base): wheels = fields.DerivedListField(fields.StringField(), @@ -190,7 +191,7 @@ class Car(models.Base): assert viper.to_struct() == {"doors": []} -def test_automatic_model_detection(): +def test_automatic_model_detection() -> None: class FullName(models.Base): first_name = fields.StringField() @@ -204,7 +205,7 @@ class Car(models.Base): class Person(models.Base): - names = fields.ListField( + names = fields.ListField[Union[str, int, float, bool, FullName, Car]]( [str, int, float, bool, FullName, Car], help_text='A list of names.', ) From 4d109d6b328c7be59164df386af25d9b123d2cfb Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Tue, 26 May 2020 14:44:52 +0000 Subject: [PATCH 03/12] `FieldValidationError` is fine here in this test --- tests/test_jsonmodels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_jsonmodels.py b/tests/test_jsonmodels.py index 3e0e162..db07251 100644 --- a/tests/test_jsonmodels.py +++ b/tests/test_jsonmodels.py @@ -54,10 +54,10 @@ class Person(models.Base): alan = Person() - with pytest.raises(ValueError): + with pytest.raises(errors.FieldValidationError): alan.name = 'some name' - with pytest.raises(ValueError): + with pytest.raises(errors.FieldValidationError): alan.name = 2345 From 41603e2881a5baaf50771dd52f54aef05a75f1e0 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Tue, 26 May 2020 14:45:11 +0000 Subject: [PATCH 04/12] Use Python 3.8; `typing` is included in 3.8 --- Dockerfile | 2 +- requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 81e7019..0bddeff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Install pytest python library as well as add all files in current directory -FROM python:3.7 AS base +FROM python:3.8 AS base WORKDIR /usr/src/app RUN apt-get update \ && apt-get install -y enchant \ diff --git a/requirements.txt b/requirements.txt index d455999..9925baf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,5 @@ pytest pytest-cov sphinxcontrib-spelling tox -typing virtualenv wheel From 2c216024222a48593ebdeb8f25ca64438ed38bce Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Fri, 29 May 2020 10:18:34 +0000 Subject: [PATCH 05/12] Add missing `python-dateutil` dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9925baf..c7b171b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ py pyflakes pytest pytest-cov +python-dateutil sphinxcontrib-spelling tox virtualenv From 26a516e8cae3fba66b09afe066d663c7d7af408b Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Fri, 26 Apr 2024 17:40:58 +0000 Subject: [PATCH 06/12] Update the Python version in Docker --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 0bddeff..4b81054 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Install pytest python library as well as add all files in current directory -FROM python:3.8 AS base +FROM python:3.11 AS base WORKDIR /usr/src/app RUN apt-get update \ && apt-get install -y enchant \ From d03e793612533c5221b10738c907f5d7abfe1b86 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Fri, 26 Apr 2024 17:41:57 +0000 Subject: [PATCH 07/12] Typing for the core JSON models code --- jsonmodels/builders.py | 154 ++++++++++++---------- jsonmodels/collections.py | 14 +- jsonmodels/errors.py | 22 ++-- jsonmodels/fields.py | 265 +++++++++++++++++++------------------- jsonmodels/models.py | 50 +++---- jsonmodels/parsers.py | 46 +++---- jsonmodels/utilities.py | 26 ++-- jsonmodels/validators.py | 47 ++++--- run_mypy.sh | 1 - 9 files changed, 324 insertions(+), 301 deletions(-) diff --git a/jsonmodels/builders.py b/jsonmodels/builders.py index a19efdc..4437d62 100644 --- a/jsonmodels/builders.py +++ b/jsonmodels/builders.py @@ -1,79 +1,95 @@ """Builders to generate in memory representation of model and fields tree.""" -from __future__ import absolute_import from collections import defaultdict - +from typing import Any, Dict, List, Optional, Set import six from . import errors -from .fields import NotSet - +from .fields import NotSet, Value +from .types import Builder, Field, JSONSchemaProperty, JSONSchemaTypeName, Model -class Builder(object): - def __init__(self, parent=None, nullable=False, default=NotSet): +class BaseBuilder: + def __init__( + self, + parent: Optional[Builder] = None, + nullable: bool = False, + default: Any = NotSet, + ) -> None: self.parent = parent - self.types_builders = {} - self.types_count = defaultdict(int) - self.definitions = set() + self.types_builders: Dict[type[Model], Builder] = {} + self.types_count: Dict[type[Model], int] = defaultdict(int) + self.definitions: Set[Builder] = set() self.nullable = nullable self.default = default @property - def has_default(self): + def has_default(self) -> bool: return self.default is not NotSet - def register_type(self, type, builder): + def register_type(self, model_type: type[Model], builder: Builder) -> None: if self.parent: - return self.parent.register_type(type, builder) + self.parent.register_type(model_type, builder) + return - self.types_count[type] += 1 - if type not in self.types_builders: - self.types_builders[type] = builder + self.types_count[model_type] += 1 + if model_type not in self.types_builders: + self.types_builders[model_type] = builder - def get_builder(self, type): + def get_builder(self, model_type: type[Model]) -> Builder: if self.parent: - return self.parent.get_builder(type) + return self.parent.get_builder(model_type) - return self.types_builders[type] + return self.types_builders[model_type] - def count_type(self, type): + def count_type(self, model_type: type[Model]) -> int: if self.parent: - return self.parent.count_type(type) + return self.parent.count_type(model_type) - return self.types_count[type] + return self.types_count[model_type] @staticmethod - def maybe_build(value): + def maybe_build(value: Value) -> JSONSchemaProperty | Value: return value.build() if isinstance(value, Builder) else value - def add_definition(self, builder): + def add_definition(self, builder: Builder) -> None: if self.parent: return self.parent.add_definition(builder) self.definitions.add(builder) + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + raise NotImplementedError() + + @property + def is_definition(self) -> bool: + raise NotImplementedError() + + @property + def type_name(self) -> str: + raise NotImplementedError() -class ObjectBuilder(Builder): - def __init__(self, model_type, *args, **kwargs): - super(ObjectBuilder, self).__init__(*args, **kwargs) - self.properties = {} - self.required = [] +class ObjectBuilder(BaseBuilder): + def __init__(self, model_type: type[Model], *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.properties: Dict[str, str | JSONSchemaProperty] = {} + self.required: List[str] = [] self.type = model_type self.register_type(self.type, self) - def add_field(self, name, field, schema): - _apply_validators_modifications(schema, field) + def add_field(self, name: str, field: Field, schema: str | JSONSchemaProperty) -> None: + if not isinstance(schema, str): + _apply_validators_modifications(schema, field) if isinstance(schema, dict) and field.help_text: schema["description"] = field.help_text self.properties[name] = schema if field.required: self.required.append(name) - def build(self): + def build(self) -> str | JSONSchemaProperty: builder = self.get_builder(self.type) if self.is_definition and not self.is_root: self.add_definition(builder) @@ -83,27 +99,27 @@ def build(self): return builder.build_definition() @property - def type_name(self): + def type_name(self) -> str: module_name = '{module}.{name}'.format( module=self.type.__module__, name=self.type.__name__, ) return module_name.replace('.', '_').lower() - def build_definition(self, add_definitions=True): - properties = dict( + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + properties: Dict[str, str | JSONSchemaProperty] = dict( (name, self.maybe_build(value)) for name, value in self.properties.items() ) - schema = { + schema: JSONSchemaProperty = { 'type': 'object', 'additionalProperties': False, 'properties': properties, } if self.required: - schema['required'] = self.required + schema['required'] = list(self.required) if self.definitions and add_definitions: schema['definitions'] = dict( @@ -114,7 +130,7 @@ def build_definition(self, add_definitions=True): return schema @property - def is_definition(self): + def is_definition(self) -> bool: if self.count_type(self.type) > 1: return True elif self.parent: @@ -123,35 +139,30 @@ def is_definition(self): return False @property - def is_root(self): + def is_root(self) -> bool: return not bool(self.parent) -def _apply_validators_modifications(field_schema, field): +def _apply_validators_modifications(field_schema: JSONSchemaProperty, field: Field) -> None: for validator in field.validators: - try: + if hasattr(validator, "modify_schema"): validator.modify_schema(field_schema) - except AttributeError: - pass # arrays may have separate validators for each item. # we should also add those validators to the schema. if "items" in field_schema: for validator in field.item_validators: - try: + if hasattr(validator, "modify_schema"): validator.modify_schema(field_schema["items"]) - except AttributeError: - pass - - -class PrimitiveBuilder(Builder): - def __init__(self, type, *args, **kwargs): - super(PrimitiveBuilder, self).__init__(*args, **kwargs) - self.type = type +class PrimitiveBuilder(BaseBuilder): + def __init__(self, value_type: type, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.type = value_type - def build(self): - schema = {} + def build(self) -> JSONSchemaProperty: + obj_type: JSONSchemaTypeName + schema: JSONSchemaProperty = {} if issubclass(self.type, six.string_types): obj_type = 'string' elif issubclass(self.type, bool): @@ -174,19 +185,21 @@ def build(self): return schema -class ListBuilder(Builder): +class ListBuilder(BaseBuilder): - def __init__(self, *args, **kwargs): - super(ListBuilder, self).__init__(*args, **kwargs) - self.schemas = [] + parent: Builder - def add_type_schema(self, schema): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schemas: list[Builder | JSONSchemaProperty] = [] + + def add_type_schema(self, schema: Builder | JSONSchemaProperty) -> None: self.schemas.append(schema) - def build(self): - schema = {'type': 'array'} + def build(self) -> str | JSONSchemaProperty: + schema: JSONSchemaProperty = {'type': 'array'} if self.nullable: - self.add_type_schema({'type': 'null'}) + self.add_type_schema({'type': 'null'}) # <- probably a bug if self.has_default: schema["default"] = [self.to_struct(i) for i in self.default] @@ -201,27 +214,28 @@ def build(self): return schema @property - def is_definition(self): + def is_definition(self) -> bool: return self.parent.is_definition @staticmethod - def to_struct(item): + def to_struct(item: Value) -> Value: from .models import Base if isinstance(item, Base): return item.to_struct() return item -class EmbeddedBuilder(Builder): +class EmbeddedBuilder(BaseBuilder): + parent: Builder - def __init__(self, *args, **kwargs): - super(EmbeddedBuilder, self).__init__(*args, **kwargs) - self.schemas = [] + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schemas: list[Builder | JSONSchemaProperty] = [] - def add_type_schema(self, schema): + def add_type_schema(self, schema: Builder | JSONSchemaProperty) -> None: self.schemas.append(schema) - def build(self): + def build(self) -> JSONSchemaProperty: if self.nullable: self.add_type_schema({'type': 'null'}) @@ -239,5 +253,5 @@ def build(self): return schema @property - def is_definition(self): + def is_definition(self) -> bool: return self.parent.is_definition diff --git a/jsonmodels/collections.py b/jsonmodels/collections.py index 6c756eb..4dc34ee 100644 --- a/jsonmodels/collections.py +++ b/jsonmodels/collections.py @@ -1,4 +1,6 @@ - +from typing import Any, Iterable +from .types import CollectionField +from typing_extensions import override class ModelCollection(list): @@ -9,14 +11,16 @@ class ModelCollection(list): """ - def __init__(self, field): + def __init__(self, field: CollectionField) -> None: super(ModelCollection, self).__init__() self.field = field - def append(self, value): + @override + def append(self, value: Any) -> None: self.field.validate_single_value(value) super(ModelCollection, self).append(value) - def __setitem__(self, key, value): + @override + def __setitem__(self, index: Any, value: Any, /) -> None: self.field.validate_single_value(value) - super(ModelCollection, self).__setitem__(key, value) + super(ModelCollection, self).__setitem__(index, value) diff --git a/jsonmodels/errors.py b/jsonmodels/errors.py index bcd8d3e..2e64413 100644 --- a/jsonmodels/errors.py +++ b/jsonmodels/errors.py @@ -1,4 +1,6 @@ -from typing import Any, List, Tuple, Type +from typing import Any, List, Sized, Tuple, Type + +from .types import EmbedType, Model class ValidationError(RuntimeError): @@ -56,14 +58,14 @@ def __init__(self, model_name: str, field_name: str, class RequiredFieldError(ValidatorError): """ Error raised when a required field has no value """ - def __init__(self): + def __init__(self) -> None: super(RequiredFieldError, self).__init__('Field is required!') class RegexError(ValidatorError): """ Error raised by the Regex validator """ - def __init__(self, value: str, pattern: str): + def __init__(self, value: str, pattern: str) -> None: tpl = 'Value "{value}" did not match pattern "{pattern}".' super(RegexError, self).__init__(tpl.format( value=value, pattern=pattern @@ -78,7 +80,7 @@ class BadTypeError(ValidatorError): expected one """ - def __init__(self, value: Any, types: Tuple, is_list: bool): + def __init__(self, value: Any, types: Tuple, is_list: bool) -> None: """ :param value: The given value. :param types: The accepted types. @@ -104,7 +106,7 @@ class AmbiguousTypeError(ValidatorError): that supports multiple types """ - def __init__(self, types: Tuple): + def __init__(self, types: tuple[EmbedType, ...]) -> None: """ The types that are allowed """ tpl = 'Cannot decide which type to choose from "{types}".' super(AmbiguousTypeError, self).__init__(tpl.format( @@ -116,7 +118,7 @@ def __init__(self, types: Tuple): class MinLengthError(ValidatorError): """ Error raised by the Length validator when too few items are present """ - def __init__(self, value: list, minimum_length: int): + def __init__(self, value: Sized, minimum_length: int) -> None: """ :param value: The given value. :param minimum_length: The minimum length expected. @@ -132,7 +134,7 @@ def __init__(self, value: list, minimum_length: int): class MaxLengthError(ValidatorError): """ Error raised by the Length validator when receiving too many items """ - def __init__(self, value: list, maximum_length: int): + def __init__(self, value: Sized, maximum_length: int) -> None: """ :param value: The given value. :param maximum_length: The maximum length expected. @@ -148,7 +150,7 @@ def __init__(self, value: list, maximum_length: int): class MinValidationError(ValidatorError): """ Error raised by the Min validator """ - def __init__(self, value, minimum_value, exclusive: bool): + def __init__(self, value: int | float, minimum_value: int | float, exclusive: bool) -> None: """ :param value: The given value. :param minimum_value: The minimum value allowed. @@ -167,7 +169,7 @@ def __init__(self, value, minimum_value, exclusive: bool): class MaxValidationError(ValidatorError): """ Error raised by the Max validator """ - def __init__(self, value, maximum_value, exclusive: bool): + def __init__(self, value: int | float, maximum_value: int | float, exclusive: bool) -> None: """ :param value: The given value. :param maximum_value: The maximum value allowed. @@ -186,7 +188,7 @@ def __init__(self, value, maximum_value, exclusive: bool): class EnumError(ValidatorError): """ Error raised by the Enum validator """ - def __init__(self, value: Any, choices: List[Any]): + def __init__(self, value: Any, choices: List[Any]) -> None: """ :param value: The given value. :param choices: The allowed choices. diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 31f327f..3d004f4 100755 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -5,47 +5,38 @@ import re import six from dateutil.parser import parse -from typing import Any, List, Generic, Optional, Dict, Sequence, Set, Tuple, TypeVar, Union, Pattern +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast +from typing_extensions import Self + from .collections import ModelCollection -from .errors import RequiredFieldError, BadTypeError, AmbiguousTypeError +from .errors import AmbiguousTypeError, BadTypeError, RequiredFieldError +from .types import BsonEncodable, EmbedType, Field, JSONValue, Model, Validator, ValidatorFunction, ValidatorObject, Value -MYPY = False -if MYPY: - from .models import Base +T = TypeVar("T") # unique marker for "no default value specified". None is not good enough since # it is a completely valid default value. NotSet = object() -# BSON compatible types, which can be returned by toBsonEncodable. -BsonEncodable = Union[ - float, str, object, Dict, List, bytes, bool, datetime.datetime, None, - Pattern, int, bytes -] - - -T = TypeVar("T") - -class BaseField(Generic[T]): +class BaseField: """Base class for all fields.""" - types: Tuple[Any, ...] = () - - validators: List[Any] = [] - memory: WeakKeyDictionary + types: Tuple[Any, ...] = tuple() + validators: List[Validator] = [] def __init__( - self, - required=False, - nullable=False, - help_text=None, - validators: Optional[List[Any]]=None, - default=NotSet, - name=None): - self.memory = WeakKeyDictionary() + self, + required: bool = False, + nullable: bool = False, + help_text: Optional[str] = None, + validators: Optional[List[Validator]] = None, + default: Value = NotSet, + name: Optional[str] = None, + ) -> None: + self.memory: WeakKeyDictionary = WeakKeyDictionary() self.required = required self.help_text = help_text self.nullable = nullable @@ -60,49 +51,52 @@ def __init__( def has_default(self) -> bool: return self._default is not NotSet - def _assign_validators(self, validators) -> None: - if validators and not isinstance(validators, list): - validators = [validators] - self.validators = validators or [] + def _assign_validators(self, validators: Validator | List[Validator] | None) -> None: + if isinstance(validators, list): + self.validators = validators + elif validators is not None: + self.validators = [validators] + else: + self.validators = [] - def __set__(self, instance: "Base", value: Optional[T]) -> None: + def __set__(self, instance: Model, value: Any) -> None: self._finish_initialization(type(instance)) value = self.parse_value(value) self.validate(value) self.memory[instance._cache_key] = value - def __get__(self, instance: "Base", owner=None) -> T: + def __get__(self, instance: Model, owner: Model | None = None) -> Any: if instance is None: self._finish_initialization(owner) - return self # type: ignore + return self self._finish_initialization(type(instance)) self._check_value(instance) return self.memory[instance._cache_key] - def _finish_initialization(self, owner) -> None: + def _finish_initialization(self, owner: type[Model]) -> None: pass - def _check_value(self, obj): + def _check_value(self, obj: Model) -> None: if obj._cache_key not in self.memory: self.__set__(obj, self.get_default_value()) - def validate_for_object(self, obj): + def validate_for_object(self, obj: Model) -> None: value = self.__get__(obj) self.validate(value) - def validate(self, value: Optional[T]) -> None: + def validate(self, value: Any) -> None: self._check_types() self._validate_against_types(value) self._check_against_required(value) self._validate_with_custom_validators(value) - def _check_against_required(self, value) -> None: + def _check_against_required(self, value: Any) -> None: if value is None and self.required: raise RequiredFieldError() - def _validate_against_types(self, value) -> None: + def _validate_against_types(self, value: Any) -> None: if value is not None and not isinstance(value, self.types): raise BadTypeError(value, self.types, is_list=False) @@ -112,7 +106,7 @@ def _check_types(self) -> None: raise ValueError(tpl.format(type=type(self).__name__)) @staticmethod - def _get_embed_type(value, models): + def _get_embed_type(value: Value, models: tuple[EmbedType, ...]) -> EmbedType: """ Tries to guess which of the given models is applicable to the dict. :param value: The dict to check. @@ -130,7 +124,7 @@ def _get_embed_type(value, models): in model.iterate_with_name() } for model in models if hasattr(model, "iterate_with_name") - } # type: Dict[type, Set[str]] + } matching_models = [model for model, fields in model_fields.items() if fields.issuperset(value)] @@ -141,7 +135,7 @@ def _get_embed_type(value, models): return matching_models[0] return models[0] - def toBsonEncodable(self, value) -> BsonEncodable: + def toBsonEncodable(self, value: Any) -> BsonEncodable: """Optionally return a bson encodable python object. Returned object should be BSON compatible. By default uses the @@ -159,30 +153,30 @@ def toBsonEncodable(self, value) -> BsonEncodable: """ return self.to_struct(value=value) - def to_struct(self, value): - """Cast value to Python dict.""" - return value + def to_struct(self, value: Any) -> JSONValue: + """Cast value to Python structure.""" + return cast(JSONValue, value) - def parse_value(self, value: Optional[Any]) -> Optional[T]: + def parse_value(self, value: Any) -> T | None: """Parse value from primitive to desired format. Each field can parse value to form it wants it to be (like string or int). """ - return value + return cast( T | None, value) - def _validate_with_custom_validators(self, value): + def _validate_with_custom_validators(self, value: Any) -> None: if value is None and self.nullable: return for validator in self.validators: try: - validator.validate(value) + cast(ValidatorObject, validator).validate(value) except AttributeError: - validator(value) + cast(ValidatorFunction, validator)(value) - def get_default_value(self): + def get_default_value(self) -> Any: """Get default value for field. Each field can specify its default. @@ -190,35 +184,35 @@ def get_default_value(self): """ return self._default if self.has_default else None - def _validate_name(self): + def _validate_name(self) -> None: if self.name is None: return if not re.match(r'^[A-Za-z_](([\w\-]*)?\w+)?$', self.name): raise ValueError('Wrong name', self.name) - def structure_name(self, default): + def structure_name(self, default: str) -> str: return self.name if self.name is not None else default - def structue_name(self, default): + def structue_name(self, default: str) -> str: warnings.warn("`structue_name` is deprecated, please use " "`structure_name`") return self.structure_name(default) -class StringField(BaseField[str]): +class StringField(BaseField): """String field.""" types: Tuple[Any, ...] = six.string_types -class IntField(BaseField[int]): +class IntField(BaseField): """Integer field.""" types: Tuple[Any, ...] = (int,) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Cast value to `int`, e.g. from string or long""" parsed = super(IntField, self).parse_value(value) if parsed is None: @@ -229,20 +223,20 @@ def parse_value(self, value): raise BadTypeError(value, types=(int,), is_list=False) -class FloatField(BaseField[float]): +class FloatField(BaseField): """Float field.""" types: Tuple[Any, ...] = (float, int) -class BoolField(BaseField[bool]): +class BoolField(BaseField): """Bool field.""" types: Tuple[Any, ...] = (bool,) - def parse_value(self, value): + def parse_value(self, value: Value) -> Any: """Cast value to `bool`.""" parsed = super(BoolField, self).parse_value(value) return bool(parsed) if parsed is not None else None @@ -251,16 +245,16 @@ def parse_value(self, value): I = TypeVar("I") -class ListField(BaseField[List[I]]): +class ListField(BaseField): """List field.""" types: Tuple[Any, ...] = (list, tuple) - items_types: Tuple[Any, ...] + items_types: tuple[EmbedType, ...] item_validators: List[Any] - def __init__(self, items_types: Optional[List[Any]]=None, item_validators: Union[Any, List[Any]]=[], omit_empty=False, - *args, **kwargs): + def __init__(self, items_types: Optional[tuple[EmbedType, ...]]=None, item_validators: Union[Any, List[Any]]=[], + omit_empty: bool=False, *args: Any, **kwargs: Any): """Init. `ListField` is **always not required**. If you want to control number @@ -277,20 +271,17 @@ def __init__(self, items_types: Optional[List[Any]]=None, item_validators: Union self.required = False self._omit_empty = omit_empty - def get_default_value(self): + def get_default_value(self) -> Any: default = super(ListField, self).get_default_value() if default is None: return ModelCollection(self) return default - def _assign_types(self, items_types): + def _assign_types(self, items_types: tuple[EmbedType, ...] | None) -> None: if items_types: - try: - self.items_types = tuple(items_types) - except TypeError: - self.items_types = items_types, + self.items_types = tuple(items_types) else: - self.items_types = tuple() + self.items_types = () types = [] for type_ in self.items_types: @@ -300,13 +291,13 @@ def _assign_types(self, items_types): types.append(type_) self.items_types = tuple(types) - def validate(self, value): + def validate(self, value: Any) -> None: super(ListField, self).validate(value) for item in value: self.validate_single_value(item) - def validate_single_value(self, value): + def validate_single_value(self, value: Any) -> None: for validator in self.item_validators: try: validator.validate(value) @@ -317,9 +308,9 @@ def validate_single_value(self, value): return if not isinstance(value, self.items_types): - raise BadTypeError(value, self.items_types, is_list=True) + raise BadTypeError(value, tuple(self.items_types), is_list=True) - def parse_value(self, values): + def parse_value(self, values: Any) -> Any: """Cast value to proper collection.""" result = self.get_default_value() @@ -331,16 +322,16 @@ def parse_value(self, values): return [self._cast_value(value) for value in values] - def _cast_value(self, value): + def _cast_value(self, value: Any) -> Any: if isinstance(value, self.items_types): return value elif isinstance(value, dict): model_type = self._get_embed_type(value, self.items_types) return model_type(**value) else: - raise BadTypeError(value, self.items_types, is_list=True) + raise BadTypeError(value, tuple(self.items_types), is_list=True) - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: super(ListField, self)._finish_initialization(owner) types = [] @@ -351,14 +342,14 @@ def _finish_initialization(self, owner): types.append(item_type) self.items_types = tuple(types) - def _elem_to_struct(self, value): + def _elem_to_struct(self, value: Value) -> Value | dict[str, Value]: try: return value.to_struct() except AttributeError: return value - def to_struct(self, values): - return [self._elem_to_struct(v) for v in values] \ + def to_struct(self, values: Any) -> JSONValue: + return [self._elem_to_struct(v) for v in cast(List, values)] \ if values or not self._omit_empty else None @@ -367,7 +358,7 @@ class DerivedListField(ListField): A list field that has another field for its items. """ - def __init__(self, field: BaseField, *args, **kwargs): + def __init__(self, field: BaseField, *args: Any, **kwargs: Any): """ :param field: The field that will be in each of the items of the list. :param help_text: The help text of the list field. @@ -382,16 +373,16 @@ def __init__(self, field: BaseField, *args, **kwargs): *args, **fixed_kwargs, ) - def to_struct(self, values: List[Any]) -> Optional[List[Any]]: + def to_struct(self, values: Any) -> JSONValue: """ Converts the list to its output format. :param values: The values in the list. :return: The converted values. """ - return [self._field.to_struct(value) for value in values] \ + return [self._field.to_struct(value) for value in cast(List, values)] \ if values or not self._omit_empty else None - def parse_value(self, values: Optional[Any]) -> Optional[List[Any]]: + def parse_value(self, values: Any) -> Any: """ Converts the list to its internal format. :param values: The values in the list. @@ -418,23 +409,23 @@ class EmbeddedField(BaseField): """Field for embedded models.""" - def __init__(self, model_types, *args, **kwargs): + def __init__(self, model_types: tuple[EmbedType | str, ...], *args: Any, **kwargs: Any) -> None: self._assign_model_types(model_types) super(EmbeddedField, self).__init__(*args, **kwargs) - def _assign_model_types(self, model_types): + def _assign_model_types(self, model_types: tuple[EmbedType | str, ...]) -> None: if not isinstance(model_types, (list, tuple)): model_types = (model_types,) - types = [] + types: List[EmbedType | _LazyType] = [] for type_ in model_types: if isinstance(type_, six.string_types): types.append(_LazyType(type_)) else: - types.append(type_) + types.append(cast(EmbedType, type_)) self.types = tuple(types) - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: super(EmbeddedField, self)._finish_initialization(owner) types = [] for model_type in self.types: @@ -445,26 +436,26 @@ def _finish_initialization(self, owner): self.types = tuple(types) - def validate(self, value): + def validate(self, value: Any) -> None: super(EmbeddedField, self).validate(value) try: value.validate() except AttributeError: pass - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse value to proper model type.""" if not isinstance(value, dict): - return value + return cast(EmbedType, value) embed_type = self._get_embed_type(value, self.types) return embed_type(**value) - def to_struct(self, value): - return value.to_struct() + def to_struct(self, value: Any) -> JSONValue: + return cast(Model, value).to_struct() -class MapField(BaseField[Dict[Any, Any]]): +class MapField(BaseField): """ Model field that keeps a mapping between two other fields. It is basically a dictionary with key and values being separate fields. @@ -476,8 +467,8 @@ class MapField(BaseField[Dict[Any, Any]]): """ types: Tuple[Any, ...] = (dict,) - def __init__(self, key_field: BaseField, value_field: BaseField, - **kwargs): + def __init__(self, key_field: Field, value_field: Field, + **kwargs: Any): """ :param key_field: The field that is responsible for converting and validating the keys in this mapping. @@ -489,7 +480,7 @@ def __init__(self, key_field: BaseField, value_field: BaseField, self._key_field = key_field self._value_field = value_field - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: """ Completes the initialization of the fields, allowing for lazy refs. """ @@ -504,7 +495,7 @@ def get_default_value(self) -> Any: return dict() return default - def parse_value(self, values: Optional[Any]) -> Optional[Dict[Any, Any]]: + def parse_value(self, values: Any) -> Any: """ Parses the given values into a new dict. """ values = super().parse_value(values) if values is None: @@ -516,16 +507,16 @@ def parse_value(self, values: Optional[Any]) -> Optional[Dict[Any, Any]]: ] return type(values)(items) # Preserves OrderedDict - def to_struct(self, values: Dict[Any, Any]) -> Dict[Any, Any]: + def to_struct(self, values: Any) -> JSONValue: """ Casts the field values into a dict. """ items = [ (self._key_field.to_struct(key), self._value_field.to_struct(value)) - for key, value in values.items() + for key, value in cast(Dict, values).items() ] - return type(values)(items) # Preserves OrderedDict + return cast(JSONValue, type(values)(items)) # Preserves OrderedDict - def validate(self, values: Optional[Dict[Any, Any]]) -> None: + def validate(self, values: Any) -> None: """ Validates all keys and values in the map field. :param values: The values in the mapping. @@ -538,17 +529,16 @@ def validate(self, values: Optional[Dict[Any, Any]]) -> None: self._value_field.validate(value) -class _LazyType(object): - - def __init__(self, path): +class _LazyType: + def __init__(self, path: str) -> None: self.path = path - def evaluate(self, base_cls): + def evaluate(self, base_cls: type[Model]) -> Any: module, type_name = _evaluate_path(self.path, base_cls) return _import(module, type_name) -def _evaluate_path(relative_path, base_cls): +def _evaluate_path(relative_path: str, base_cls: type[Model]) -> tuple[Any, str]: base_module = base_cls.__module__ modules = _get_modules(relative_path, base_module) @@ -560,7 +550,7 @@ def _evaluate_path(relative_path, base_cls): return module, type_name -def _get_modules(relative_path, base_module): +def _get_modules(relative_path: str, base_module: str) -> Any: canonical_path = relative_path.lstrip('.') canonical_modules = canonical_path.split('.') @@ -575,7 +565,7 @@ def _get_modules(relative_path, base_module): return parent_modules[:parents_amount * -1] + canonical_modules -def _import(module_name, type_name): +def _import(module_name: str, type_name: str) -> Any: module = __import__(module_name, fromlist=[type_name]) try: return getattr(module, type_name) @@ -590,7 +580,9 @@ class TimeField(StringField): types: Tuple[Any, ...] = (datetime.time,) - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. :param str str_format: Format to cast time to (if `None` - casting to @@ -600,13 +592,14 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(TimeField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `time` object to string.""" + datetime_value = cast(datetime.time, value) if self.str_format: - return value.strftime(self.str_format) - return value.isoformat() + return datetime_value.strftime(self.str_format) + return datetime_value.isoformat() - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `time`.""" if value is None: return value @@ -622,7 +615,9 @@ class DateField(StringField): types: Tuple[Any, ...] = (datetime.date,) default_format = '%Y-%m-%d' - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. :param str str_format: Format to cast date to (if `None` - casting to @@ -632,13 +627,14 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(DateField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `date` object to string.""" + date_value = cast(datetime.date, value) if self.str_format: - return value.strftime(self.str_format) - return value.strftime(self.default_format) + return date_value.strftime(self.str_format) + return date_value.strftime(self.default_format) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `date`.""" if value is None: return value @@ -653,7 +649,9 @@ class DateTimeField(StringField): types: Tuple[Any, ...] = (datetime.datetime,) - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. :param str str_format: Format to cast datetime to (if `None` - casting @@ -663,21 +661,22 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(DateTimeField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `datetime` object to string.""" + datetime_value = cast(datetime.datetime, value) if self.str_format: - return value.strftime(self.str_format) - return value.isoformat() + return datetime_value.strftime(self.str_format) + return datetime_value.isoformat() - def toBsonEncodable(self, value: datetime.datetime) -> datetime.datetime: + def toBsonEncodable(self, value: Any) -> BsonEncodable: """ Keep datetime object a datetime object, since pymongo supports that. """ if not isinstance(value, self.types): raise BadTypeError(value, self.types, is_list=False) - return value + return cast(BsonEncodable, value) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `datetime`.""" if isinstance(value, datetime.datetime): return value @@ -687,17 +686,17 @@ def parse_value(self, value): return None -class GenericField(BaseField[Any]): +class GenericField(BaseField): """ Field that supports any kind of value, converting models to their correct struct, keeping ordered dictionaries in their original order. """ types: Tuple[Any, ...] = (any,) - def _validate_against_types(self, value) -> None: + def _validate_against_types(self, value: Value) -> None: pass - def to_struct(self, values: Any) -> Any: + def to_struct(self, values: Any) -> JSONValue: """ Casts value to Python structure. """ from .models import Base if isinstance(values, Base): @@ -711,4 +710,4 @@ def to_struct(self, values: Any) -> Any: for key, value in values.items()] return type(values)(items) # preserves OrderedDict - return values + return cast(JSONValue, values) diff --git a/jsonmodels/models.py b/jsonmodels/models.py index 350393e..3d1ebe4 100644 --- a/jsonmodels/models.py +++ b/jsonmodels/models.py @@ -1,18 +1,24 @@ -import six +from typing import Any, Dict, Generator, Tuple, Type, cast + +from jsonmodels.types import JSONSchemaProperty from . import parsers, errors from .fields import BaseField from .errors import FieldValidationError, ValidatorError, ValidationError +from .types import Field, JSONSchemaProperty, JSONValue +Values = Dict[str, Any] +Fields = Tuple[str, Field] +FieldsWithNames = Tuple[str, str, Field] -class JsonmodelMeta(type): - def __new__(cls, name, bases, attributes): +class JsonmodelMeta(type): + def __new__(cls: Type[JsonmodelMeta], name: str, bases: tuple, attributes: dict) -> type: cls.validate_fields(attributes) return super(cls, cls).__new__(cls, name, bases, attributes) @staticmethod - def validate_fields(attributes): + def validate_fields(attributes: dict[str, Any]) -> None: fields = { key: value for key, value in attributes.items() if isinstance(value, BaseField) @@ -25,15 +31,15 @@ def validate_fields(attributes): taken_names.add(structure_name) -class Base(six.with_metaclass(JsonmodelMeta, object)): +class Base(metaclass=JsonmodelMeta): """Base class for all models.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Values) -> None: self._cache_key = _CacheKey() self.populate(**kwargs) - def populate(self, **values): + def populate(self, **values: Values) -> None: """Populate values to fields. Skip non-existing.""" values = values.copy() fields = list(self.iterate_with_name()) @@ -45,7 +51,7 @@ def populate(self, **values): if name in values: self.set_field(field, name, values.pop(name)) - def get_field(self, field_name): + def get_field(self, field_name: str) -> Field: """Get field associated with given attribute.""" for attr_name, field in self: if field_name == attr_name: @@ -53,7 +59,7 @@ def get_field(self, field_name): raise errors.FieldNotFound(field_name) - def set_field(self, field, field_name, value): + def set_field(self, field: Field, field_name: str, value: Any) -> None: """ Sets the value of a field. """ try: field.__set__(self, value) @@ -61,12 +67,12 @@ def set_field(self, field, field_name, value): raise FieldValidationError(type(self).__name__, field_name, value, error) - def __iter__(self): + def __iter__(self) -> Generator[Fields, None, None]: """Iterate through fields and values.""" for name, field in self.iterate_over_fields(): yield name, field - def validate(self): + def validate(self) -> None: """Explicitly validate all the fields.""" for name, field in self: try: @@ -77,15 +83,15 @@ def validate(self): value, error) @classmethod - def iterate_over_fields(cls): + def iterate_over_fields(cls) -> Generator[Fields, None, None]: """Iterate through fields as `(attribute_name, field_instance)`.""" for attr in dir(cls): class_attribute = getattr(cls, attr) if isinstance(class_attribute, BaseField): - yield attr, class_attribute + yield attr, cast(Field, class_attribute) @classmethod - def iterate_with_name(cls): + def iterate_with_name(cls) -> Generator[FieldsWithNames, None, None]: """Iterate over fields, but also give `structure_name`. Format is `(attribute_name, structure_name, field_instance)`. @@ -96,16 +102,16 @@ def iterate_with_name(cls): structure_name = field.structure_name(attr_name) yield attr_name, structure_name, field - def to_struct(self): + def to_struct(self) -> JSONValue: """Cast model to Python structure.""" return parsers.to_struct(self) @classmethod - def to_json_schema(cls): + def to_json_schema(cls) -> JSONSchemaProperty: """Generate JSON schema for model.""" return parsers.to_json_schema(cls) - def __repr__(self): + def __repr__(self) -> str: attrs = {} for name, _ in self: try: @@ -122,17 +128,17 @@ def __repr__(self): ), ) - def __str__(self): + def __str__(self) -> str: return '{name} object'.format(name=self.__class__.__name__) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: try: return super(Base, self).__setattr__(name, value) except ValidatorError as error: raise FieldValidationError(type(self).__name__, name, value, error) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if type(other) is not type(self): return False @@ -152,9 +158,9 @@ def __eq__(self, other): return True - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not (self == other) -class _CacheKey(object): +class _CacheKey: """Object to identify model in memory.""" diff --git a/jsonmodels/parsers.py b/jsonmodels/parsers.py index 3f0f60d..250441a 100644 --- a/jsonmodels/parsers.py +++ b/jsonmodels/parsers.py @@ -1,16 +1,13 @@ """Parsers to change model structure into different ones.""" import inspect +from typing import Any, cast -from . import fields, builders, errors +from . import builders, errors, fields +from .types import Builder, CollectionField, Field, JSONSchemaProperty, JSONSchemaTypeName, JSONValue, Model -def to_struct(model): - """ - Cast instance of model to python structure. - :param model: Model to be casted. - :rtype: ``dict`` - - """ +def to_struct(model: Model) -> JSONValue: + """Cast instance of model to python structure.""" model.validate() resp = {} @@ -25,18 +22,13 @@ def to_struct(model): return resp -def to_json_schema(cls): - """Generate JSON schema for given class. - - :param cls: Class to be casted. - :rtype: ``dict`` - - """ +def to_json_schema(cls: Any) -> JSONSchemaProperty: + """Generate JSON schema for given class.""" builder = build_json_schema(cls) - return builder.build() + return cast(JSONSchemaProperty, builder.build()) -def build_json_schema(value, parent_builder=None): +def build_json_schema(value: Any, parent_builder: Builder | None = None) -> Builder: from .models import Base cls = value if inspect.isclass(value) else value.__class__ @@ -46,7 +38,7 @@ def build_json_schema(value, parent_builder=None): return build_json_schema_primitive(cls, parent_builder) -def build_json_schema_object(cls, parent_builder=None): +def build_json_schema_object(cls: type[Model], parent_builder: Builder | None = None) -> builders.ObjectBuilder: builder = builders.ObjectBuilder(cls, parent_builder) if builder.count_type(builder.type) > 1: return builder @@ -56,12 +48,11 @@ def build_json_schema_object(cls, parent_builder=None): elif isinstance(field, fields.ListField): builder.add_field(name, field, _parse_list(field, builder)) else: - builder.add_field( - name, field, _create_primitive_field_schema(field)) + builder.add_field(name, field, _create_primitive_field_schema(field)) return builder -def _parse_list(field, parent_builder): +def _parse_list(field: fields.ListField, parent_builder: Builder | None) -> str | JSONSchemaProperty: builder = builders.ListBuilder( parent_builder, field.nullable, default=field._default) for type in field.items_types: @@ -69,7 +60,7 @@ def _parse_list(field, parent_builder): return builder.build() -def _parse_embedded(field, parent_builder): +def _parse_embedded(field: fields.EmbeddedField, parent_builder: Builder | None) -> str | JSONSchemaProperty: builder = builders.EmbeddedBuilder( parent_builder, field.nullable, default=field._default) for type in field.types: @@ -77,13 +68,13 @@ def _parse_embedded(field, parent_builder): return builder.build() -def build_json_schema_primitive(cls, parent_builder): +def build_json_schema_primitive(cls: type, parent_builder: Builder | None) -> Builder: builder = builders.PrimitiveBuilder(cls, parent_builder) return builder -def _create_primitive_field_schema(field): - schema = {'type': _get_schema_type(field)} +def _create_primitive_field_schema(field: Field) -> JSONSchemaProperty: + schema: JSONSchemaProperty = {'type': _get_schema_type(field)} if isinstance(field, fields.FloatField): schema['format'] = 'float' @@ -98,7 +89,8 @@ def _create_primitive_field_schema(field): return schema -def _get_schema_type(field): +def _get_schema_type(field: Field) -> JSONSchemaTypeName: + obj_type: JSONSchemaTypeName if isinstance(field, fields.StringField): obj_type = 'string' elif isinstance(field, fields.IntField): @@ -112,5 +104,5 @@ def _get_schema_type(field): else: raise errors.FieldNotSupported(type(field)) if field.nullable: - obj_type = [obj_type, 'null'] + return [obj_type, 'null'] return obj_type diff --git a/jsonmodels/utilities.py b/jsonmodels/utilities.py index 6c9102b..51c576e 100644 --- a/jsonmodels/utilities.py +++ b/jsonmodels/utilities.py @@ -5,7 +5,9 @@ import six import re from collections import namedtuple -from typing import cast, Any, List, Tuple +from typing import Dict, cast, Any, List, Tuple + +from jsonmodels.types import JSONSchemaProperty six_string_types: List[Any] = list(six.string_types) SCALAR_TYPES = cast(Tuple[Any], tuple(six_string_types + [int, float, bool])) @@ -22,14 +24,14 @@ PythonRegex = namedtuple('PythonRegex', ['regex', 'flags']) -def _normalize_string_type(value): +def _normalize_string_type(value: Any) -> Any: if isinstance(value, six.string_types): return six.text_type(value) else: return value -def _compare_dicts(one, two): +def _compare_dicts(one: Dict, two: Dict) -> bool: if len(one) != len(two): return False @@ -42,7 +44,7 @@ def _compare_dicts(one, two): return True -def _compare_lists(one, two): +def _compare_lists(one: List, two: List) -> bool: if len(one) != len(two): return False @@ -56,13 +58,13 @@ def _compare_lists(one, two): return they_match -def _assert_same_types(one, two): +def _assert_same_types(one: Any, two: Any) -> None: if not isinstance(one, type(two)) or not isinstance(two, type(one)): raise RuntimeError('Types mismatch! "{type1}" and "{type2}".'.format( type1=type(one).__name__, type2=type(two).__name__)) -def compare_schemas(one, two): +def compare_schemas(one: JSONSchemaProperty, two: JSONSchemaProperty) -> bool: """Compare two structures that represents JSON schemas. For comparison you can't use normal comparison, because in JSON schema @@ -85,7 +87,7 @@ def compare_schemas(one, two): if isinstance(one, list): return _compare_lists(one, two) elif isinstance(one, dict): - return _compare_dicts(one, two) + return _compare_dicts(cast(dict, one), cast(dict, two)) elif isinstance(one, SCALAR_TYPES): return one == two elif one is None: @@ -95,7 +97,7 @@ def compare_schemas(one, two): type=type(one).__name__)) -def is_ecma_regex(regex): +def is_ecma_regex(regex: str) -> bool: """Check if given regex is of type ECMA 262 or not. :rtype: bool @@ -112,7 +114,7 @@ def is_ecma_regex(regex): return False -def convert_ecma_regex_to_python(value): +def convert_ecma_regex_to_python(value: str) -> PythonRegex: """Convert ECMA 262 regex to Python tuple with regex and flags. If given value is already Python regex it will be returned unchanged. @@ -136,7 +138,7 @@ def convert_ecma_regex_to_python(value): return PythonRegex('/'.join(parts[1:]), result_flags) -def convert_python_regex_to_ecma(value, flags=()): +def convert_python_regex_to_ecma(value: str, flags: List[re.RegexFlag]=[]) -> str: """Convert Python regex to ECMA 262 regex. If given value is already ECMA regex it will be returned unchanged. @@ -151,6 +153,6 @@ def convert_python_regex_to_ecma(value, flags=()): return value result_flags = [PYTHON_TO_ECMA_FLAGS[f] for f in flags] - result_flags = ''.join(result_flags) + result_flags_str = ''.join(result_flags) - return '/{value}/{flags}'.format(value=value, flags=result_flags) + return '/{value}/{flags}'.format(value=value, flags=result_flags_str) diff --git a/jsonmodels/validators.py b/jsonmodels/validators.py index 480141a..fa9935f 100644 --- a/jsonmodels/validators.py +++ b/jsonmodels/validators.py @@ -1,18 +1,19 @@ """Predefined validators.""" import re +from typing import Sized, cast from six.moves import reduce from .errors import MinValidationError, MaxValidationError, BadTypeError, \ RegexError, MinLengthError, MaxLengthError, EnumError from . import utilities - +from .types import JSONSchemaProperty class Min(object): """Validator for minimum value.""" - def __init__(self, minimum_value, exclusive=False): + def __init__(self, minimum_value: int | float, exclusive: bool = False) -> None: """Init. :param minimum_value: Minimum value for validator. @@ -23,13 +24,13 @@ def __init__(self, minimum_value, exclusive=False): self.minimum_value = minimum_value self.exclusive = exclusive - def validate(self, value): + def validate(self, value: int | float) -> None: """Validate value.""" if value < self.minimum_value \ or (self.exclusive and value == self.minimum_value): raise MinValidationError(value, self.minimum_value, self.exclusive) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['minimum'] = self.minimum_value if self.exclusive: @@ -40,7 +41,7 @@ class Max(object): """Validator for maximum value.""" - def __init__(self, maximum_value, exclusive=False): + def __init__(self, maximum_value: int | float, exclusive: bool = False) -> None: """Init. :param maximum_value: Maximum value for validator. @@ -51,13 +52,13 @@ def __init__(self, maximum_value, exclusive=False): self.maximum_value = maximum_value self.exclusive = exclusive - def validate(self, value): + def validate(self, value: int | float) -> None: """Validate value.""" if value > self.maximum_value \ or (self.exclusive and value == self.maximum_value): raise MaxValidationError(value, self.maximum_value, self.exclusive) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['maximum'] = self.maximum_value if self.exclusive: @@ -73,7 +74,7 @@ class Regex(object): 'multiline': re.M, } - def __init__(self, pattern, custom_error=None, **flags): + def __init__(self, pattern: str, custom_error: Exception | None=None, **flags: re._FlagsType) -> None: """Init. Note, that if given pattern is ECMA regex, given flags will be @@ -98,7 +99,7 @@ def __init__(self, pattern, custom_error=None, **flags): self.flags = [self.FLAGS[key] for key, value in flags.items() if key in self.FLAGS and value] - def validate(self, value): + def validate(self, value: str) -> None: """Validate value.""" flags = self._calculate_flags() @@ -112,10 +113,10 @@ def validate(self, value): raise self.custom_error raise RegexError(value, self.pattern) - def _calculate_flags(self): + def _calculate_flags(self) -> re._FlagsType: return reduce(lambda x, y: x | y, self.flags, 0) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['pattern'] = utilities.convert_python_regex_to_ecma( self.pattern, self.flags @@ -126,7 +127,7 @@ class Length(object): """Validator for length.""" - def __init__(self, minimum_value=None, maximum_value=None): + def __init__(self, minimum_value: int | None = None, maximum_value: int | None = None) -> None: """Init. Note that if no `minimum_value` neither `maximum_value` will be @@ -144,7 +145,7 @@ def __init__(self, minimum_value=None, maximum_value=None): self.minimum_value = minimum_value self.maximum_value = maximum_value - def validate(self, value): + def validate(self, value: Sized) -> None: """Validate value.""" len_ = len(value) @@ -154,24 +155,28 @@ def validate(self, value): if self.maximum_value is not None and len_ > self.maximum_value: raise MaxLengthError(value, self.maximum_value) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" is_array = field_schema.get('type') == 'array' if self.minimum_value: - key = 'minItems' if is_array else 'minLength' - field_schema[key] = self.minimum_value + if is_array: + field_schema['minItems'] = self.minimum_value + else: + field_schema['minLength'] = self.minimum_value if self.maximum_value: - key = 'maxItems' if is_array else 'maxLength' - field_schema[key] = self.maximum_value + if is_array: + field_schema['maxItems'] = self.maximum_value + else: + field_schema['maxLength'] = self.maximum_value class Enum(object): """Validator for enums.""" - def __init__(self, *choices): + def __init__(self, *choices: str) -> None: """Init. :param [] choices: Valid choices for the field. @@ -179,9 +184,9 @@ def __init__(self, *choices): self.choices = list(choices) - def validate(self, value): + def validate(self, value: str) -> None: if value not in self.choices: raise EnumError(value, self.choices) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: field_schema['enum'] = self.choices diff --git a/run_mypy.sh b/run_mypy.sh index da90413..61b9d3c 100755 --- a/run_mypy.sh +++ b/run_mypy.sh @@ -1,4 +1,3 @@ #!/bin/bash mypy -p jsonmodels -mypy tests/test_fields.py From ff01b3041f165ef5b97f53bcd75f46aece0d458c Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Sat, 27 Apr 2024 13:02:44 +0000 Subject: [PATCH 08/12] Add missing file --- jsonmodels/types.py | 157 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 jsonmodels/types.py diff --git a/jsonmodels/types.py b/jsonmodels/types.py new file mode 100644 index 0000000..acce75f --- /dev/null +++ b/jsonmodels/types.py @@ -0,0 +1,157 @@ +import datetime +from re import Pattern +from typing import Any, Callable, Dict, Generator, List, Literal, Protocol, Tuple, TypedDict, Union, runtime_checkable +from weakref import WeakKeyDictionary + +from jsonmodels.models import _CacheKey + +Value = Any + +JSONObject = Dict[str, "JSONValue"] +JSONValue = Union[None, bool, str, float, int, List["JSONValue"], JSONObject] + +# JSONSchema = JSONValue +# JSONSchemaDict = JSONObject + +JSONSchemaBasicTypeName = Literal["string"] | Literal["number"] | Literal["boolean"] | Literal["object"] | Literal["array"] | Literal["null"] +JSONSchemaTypeName = JSONSchemaBasicTypeName | List[JSONSchemaBasicTypeName | Literal["null"]] + +class JSONSchemaProperty(TypedDict, total=False): + type: JSONSchemaTypeName + format: str + default: Any + required: List[str] + items: JSONSchemaProperty + description: str + properties: Dict[str, str | JSONSchemaProperty] + additionalProperties: bool + definitions: Dict[str, JSONSchemaProperty] + + minItems: int + minLength: int + maxItems: int + maxLength: int + + minimum: int | float + exclusiveMinimum: bool + + maximum: int | float + exclusiveMaximum: bool + + pattern: str + enum: List[str] + +# JSONSchema = JSONSchemaProperty + +# BSON compatible types, which can be returned by toBsonEncodable. +BsonEncodable = Union[ + float, str, object, Dict, List, bytes, bool, datetime.datetime, None, + Pattern, int, bytes +] + +@runtime_checkable +class Builder(Protocol): + def register_type(self, model_type: type[Model], builder: "Builder") -> None: + ... + + def get_builder(self, model_type: type[Model]) -> "Builder": + ... + + def count_type(self, model_type: type[Model]) -> int: + ... + + def build(self) -> str | JSONSchemaProperty: + ... + + def add_definition(self, builder: "Builder") -> None: + ... + + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + ... + + @property + def is_definition(self) -> bool: + ... + + @property + def type_name(self) -> str: + ... + + +@runtime_checkable +class ValidatorObject(Protocol): + + def validate(self, value: Any) -> None: + ... + + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: + ... + +ValidatorFunction = Callable[[Any], None] +Validator = ValidatorFunction | ValidatorObject + +class Field(Protocol): + types: Tuple[Any, ...] + memory: WeakKeyDictionary + required: bool + validators: List[Validator] + item_validators: List[Validator] + help_text: str | None + nullable: bool + _default: Any + + @property + def has_default(self) -> bool: + ... + + def __set__(self, instance: Model, value: Any) -> None: + ... + + def __get__(self, instance: Model) -> Any: + ... + + def _finish_initialization(self, owner: type[Model]) -> None: + ... + + def to_struct(self, value: Any) -> JSONValue: + ... + + def structure_name(self, default: str) -> str: + ... + + def toBsonEncodable(self, value: Any) -> BsonEncodable: + ... + + def validate_for_object(self, obj: Model) -> None: + ... + + def parse_value(self, value: Value) -> Any: + ... + + def validate(self, value: Value) -> None: + ... + +class CollectionField(Field, Protocol): + def validate_single_value(self, value: Value) -> None: + ... + + +Fields = Tuple[str, Field] +FieldsWithNames = Tuple[str, str, Field] + +class Model(Protocol): + + # __name__: str + _cache_key: _CacheKey + + def validate(self) -> None: + ... + + @classmethod + def iterate_with_name(cls) -> Generator[FieldsWithNames, None, None]: + ... + + def to_struct(self) -> JSONValue: + ... + +EmbedType = Union[type[str], type[int], type[float], type[bool], type[list], type[dict], type[Model]] From 4eb11225710d2aac6a5d959f1fb0225ef371b322 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Sat, 27 Apr 2024 13:12:58 +0000 Subject: [PATCH 09/12] Cleans and minor improvements --- jsonmodels/errors.py | 2 +- jsonmodels/fields.py | 22 +++++++++++----------- jsonmodels/parsers.py | 2 +- jsonmodels/validators.py | 2 +- tests/test_fields.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/jsonmodels/errors.py b/jsonmodels/errors.py index 2e64413..411079e 100644 --- a/jsonmodels/errors.py +++ b/jsonmodels/errors.py @@ -1,6 +1,6 @@ from typing import Any, List, Sized, Tuple, Type -from .types import EmbedType, Model +from .types import EmbedType class ValidationError(RuntimeError): diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 3d004f4..97a7e56 100755 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -5,15 +5,12 @@ import re import six from dateutil.parser import parse -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast -from typing_extensions import Self - +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast from .collections import ModelCollection from .errors import AmbiguousTypeError, BadTypeError, RequiredFieldError from .types import BsonEncodable, EmbedType, Field, JSONValue, Model, Validator, ValidatorFunction, ValidatorObject, Value -T = TypeVar("T") # unique marker for "no default value specified". None is not good enough since # it is a completely valid default value. @@ -157,14 +154,14 @@ def to_struct(self, value: Any) -> JSONValue: """Cast value to Python structure.""" return cast(JSONValue, value) - def parse_value(self, value: Any) -> T | None: + def parse_value(self, value: Any) -> Any: """Parse value from primitive to desired format. Each field can parse value to form it wants it to be (like string or int). """ - return cast( T | None, value) + return value def _validate_with_custom_validators(self, value: Any) -> None: if value is None and self.nullable: @@ -253,7 +250,7 @@ class ListField(BaseField): items_types: tuple[EmbedType, ...] item_validators: List[Any] - def __init__(self, items_types: Optional[tuple[EmbedType, ...]]=None, item_validators: Union[Any, List[Any]]=[], + def __init__(self, items_types: EmbedType | tuple[EmbedType, ...] | List[EmbedType] | None=None, item_validators: Union[Any, List[Any]]=[], omit_empty: bool=False, *args: Any, **kwargs: Any): """Init. @@ -277,9 +274,12 @@ def get_default_value(self) -> Any: return ModelCollection(self) return default - def _assign_types(self, items_types: tuple[EmbedType, ...] | None) -> None: + def _assign_types(self, items_types: EmbedType | tuple[EmbedType, ...] | List[EmbedType] | None) -> None: if items_types: - self.items_types = tuple(items_types) + if isinstance(items_types, (tuple, list)): + self.items_types = tuple(items_types) + else: + self.items_types = (items_types, ) else: self.items_types = () @@ -409,11 +409,11 @@ class EmbeddedField(BaseField): """Field for embedded models.""" - def __init__(self, model_types: tuple[EmbedType | str, ...], *args: Any, **kwargs: Any) -> None: + def __init__(self, model_types: EmbedType | str | tuple[EmbedType | str, ...], *args: Any, **kwargs: Any) -> None: self._assign_model_types(model_types) super(EmbeddedField, self).__init__(*args, **kwargs) - def _assign_model_types(self, model_types: tuple[EmbedType | str, ...]) -> None: + def _assign_model_types(self, model_types: EmbedType | str | tuple[EmbedType | str, ...]) -> None: if not isinstance(model_types, (list, tuple)): model_types = (model_types,) diff --git a/jsonmodels/parsers.py b/jsonmodels/parsers.py index 250441a..62f8050 100644 --- a/jsonmodels/parsers.py +++ b/jsonmodels/parsers.py @@ -3,7 +3,7 @@ from typing import Any, cast from . import builders, errors, fields -from .types import Builder, CollectionField, Field, JSONSchemaProperty, JSONSchemaTypeName, JSONValue, Model +from .types import Builder, Field, JSONSchemaProperty, JSONSchemaTypeName, JSONValue, Model def to_struct(model: Model) -> JSONValue: diff --git a/jsonmodels/validators.py b/jsonmodels/validators.py index fa9935f..3909a6c 100644 --- a/jsonmodels/validators.py +++ b/jsonmodels/validators.py @@ -1,6 +1,6 @@ """Predefined validators.""" import re -from typing import Sized, cast +from typing import Sized from six.moves import reduce diff --git a/tests/test_fields.py b/tests/test_fields.py index c5bdc02..438b31d 100755 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -205,7 +205,7 @@ class Car(models.Base): class Person(models.Base): - names = fields.ListField[Union[str, int, float, bool, FullName, Car]]( + names = fields.ListField( [str, int, float, bool, FullName, Car], help_text='A list of names.', ) From 72bdf3726dc21cd12f1aab9b02f369cf95a11e37 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Mon, 29 Apr 2024 18:51:43 +0000 Subject: [PATCH 10/12] Mypy plugin --- jsonmodels/mypy_plugin.py | 101 +++++++++++++++++++++++++++++++++ mypy_plugin.ini | 9 +++ tests_mypy/__init__.py | 0 tests_mypy/case_date.py | 4 ++ tests_mypy/case_datetime.py | 4 ++ tests_mypy/case_derivedlist.py | 4 ++ tests_mypy/case_embedded.py | 7 +++ tests_mypy/case_int.py | 4 ++ tests_mypy/case_list.py | 4 ++ tests_mypy/case_nullable.py | 7 +++ tests_mypy/case_str.py | 5 ++ tests_mypy/models.py | 27 +++++++++ tests_mypy/test_mypy_plugin.py | 49 ++++++++++++++++ 13 files changed, 225 insertions(+) create mode 100644 jsonmodels/mypy_plugin.py create mode 100644 mypy_plugin.ini create mode 100644 tests_mypy/__init__.py create mode 100644 tests_mypy/case_date.py create mode 100644 tests_mypy/case_datetime.py create mode 100644 tests_mypy/case_derivedlist.py create mode 100644 tests_mypy/case_embedded.py create mode 100644 tests_mypy/case_int.py create mode 100644 tests_mypy/case_list.py create mode 100644 tests_mypy/case_nullable.py create mode 100644 tests_mypy/case_str.py create mode 100644 tests_mypy/models.py create mode 100644 tests_mypy/test_mypy_plugin.py diff --git a/jsonmodels/mypy_plugin.py b/jsonmodels/mypy_plugin.py new file mode 100644 index 0000000..10793bb --- /dev/null +++ b/jsonmodels/mypy_plugin.py @@ -0,0 +1,101 @@ +from typing import Callable, List, Type +import mypy +from mypy.plugin import Plugin, AttributeContext, FunctionContext +from mypy.types import Type as MypyType + +class JSONModelsPlugin(Plugin): + def get_function_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: + if fullname == "jsonmodels.fields.StringField": + return self._string_field_callback + if fullname == "jsonmodels.fields.IntField": + return self._int_field_callback + if fullname == "jsonmodels.fields.FloatField": + return self._float_field_callback + if fullname == "jsonmodels.fields.BoolField": + return self._bool_field_callback + if fullname == "jsonmodels.fields.TimeField": + return self._time_field_callback + if fullname == "jsonmodels.fields.DateField": + return self._date_field_callback + if fullname == "jsonmodels.fields.DateTimeField": + return self._datetime_field_callback + if fullname == "jsonmodels.fields.EmbeddedField": + return self._embedded_field_callback + if fullname == "jsonmodels.fields.ListField": + return self._list_field_callback + if fullname == "jsonmodels.fields.DerivedListField": + return self._list_field_callback + + return None + + def _wrap_nullable(self, ctx: FunctionContext, core_type: MypyType) -> MypyType: + try: + nullable_index = ctx.callee_arg_names.index("nullable") + except ValueError: + return core_type + + arg_value = ctx.args[nullable_index] + if len(arg_value) == 0: + return core_type + + nullable_value = arg_value[0] + if isinstance(nullable_value, mypy.nodes.NameExpr) and nullable_value.fullname == "builtins.True": + return mypy.types.UnionType([core_type, mypy.types.NoneType()]) + + return core_type + + def _string_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.str")) + + def _int_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.int")) + + def _float_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.float")) + + def _bool_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.bool")) + + def _time_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.time")) + + def _date_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.date")) + + def _datetime_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.datetime")) + + def _get_type_from_arg(self, ctx: FunctionContext, arg_name: str) -> MypyType: + try: + model_types_index = ctx.callee_arg_names.index(arg_name) + except ValueError: + return mypy.types.NoneType() + + arg_value = ctx.args[model_types_index] + if len(arg_value) == 0: + return mypy.types.NoneType() + + model_types_value = arg_value[0] + + if isinstance(model_types_value, mypy.nodes.NameExpr): + return ctx.api.named_type(model_types_value.fullname) + + if isinstance(model_types_value, mypy.nodes.TupleExpr): + accepted_types: List[MypyType] = [] + for item in model_types_value.items: + if isinstance(item, mypy.nodes.NameExpr): + accepted_types.append(ctx.api.named_type(item.fullname)) + return mypy.types.UnionType(accepted_types) + + return mypy.types.NoneType() + + def _embedded_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, self._get_type_from_arg(ctx, "model_types")) + + def _list_field_callback(self, ctx: FunctionContext) -> MypyType: + item_type = self._get_type_from_arg(ctx, "items_types") + list_type = ctx.api.named_generic_type("list", [item_type]) + return self._wrap_nullable(ctx, list_type) + +def plugin(version: str): + return JSONModelsPlugin diff --git a/mypy_plugin.ini b/mypy_plugin.ini new file mode 100644 index 0000000..3c0a5de --- /dev/null +++ b/mypy_plugin.ini @@ -0,0 +1,9 @@ +[mypy] +disallow_untyped_defs = True +disallow_any_unimported = True +no_implicit_optional = True +warn_return_any = True +warn_unused_configs = True +warn_unused_ignores = True +show_error_codes = True +plugins = jsonmodels/mypy_plugin.py diff --git a/tests_mypy/__init__.py b/tests_mypy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests_mypy/case_date.py b/tests_mypy/case_date.py new file mode 100644 index 0000000..aab2422 --- /dev/null +++ b/tests_mypy/case_date.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.dob) +# expect: datetime.date diff --git a/tests_mypy/case_datetime.py b/tests_mypy/case_datetime.py new file mode 100644 index 0000000..27edaa2 --- /dev/null +++ b/tests_mypy/case_datetime.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.last_update) +# expect: datetime.datetime diff --git a/tests_mypy/case_derivedlist.py b/tests_mypy/case_derivedlist.py new file mode 100644 index 0000000..9ce2cb5 --- /dev/null +++ b/tests_mypy/case_derivedlist.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.nicknames) +# expect: builtins.list[builtins.str] diff --git a/tests_mypy/case_embedded.py b/tests_mypy/case_embedded.py new file mode 100644 index 0000000..26dc711 --- /dev/null +++ b/tests_mypy/case_embedded.py @@ -0,0 +1,7 @@ +from models import person + +reveal_type(person.address) +# expect: models.Address + +reveal_type(person.transport) +# expect: Union[models.Car, models.Boat] diff --git a/tests_mypy/case_int.py b/tests_mypy/case_int.py new file mode 100644 index 0000000..4e8b370 --- /dev/null +++ b/tests_mypy/case_int.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.age) +# expect: builtins.int diff --git a/tests_mypy/case_list.py b/tests_mypy/case_list.py new file mode 100644 index 0000000..87bf2d4 --- /dev/null +++ b/tests_mypy/case_list.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.pet_names) +# expect: builtins.list[builtins.str] diff --git a/tests_mypy/case_nullable.py b/tests_mypy/case_nullable.py new file mode 100644 index 0000000..0af3d9e --- /dev/null +++ b/tests_mypy/case_nullable.py @@ -0,0 +1,7 @@ +from models import address + +reveal_type(address.line_1) +# expect: builtins.str + +reveal_type(address.line_2) +# expect: Union[builtins.str, None] diff --git a/tests_mypy/case_str.py b/tests_mypy/case_str.py new file mode 100644 index 0000000..dc5e077 --- /dev/null +++ b/tests_mypy/case_str.py @@ -0,0 +1,5 @@ + +from models import person + +reveal_type(person.name) +# expect: builtins.str diff --git a/tests_mypy/models.py b/tests_mypy/models.py new file mode 100644 index 0000000..b83f55b --- /dev/null +++ b/tests_mypy/models.py @@ -0,0 +1,27 @@ +from jsonmodels import models, fields + +class Address(models.Base): + line_1 = fields.StringField() + line_2 = fields.StringField(nullable=True) + city = fields.StringField() + +class Car(models.Base): + registration = fields.StringField() + +class Boat(models.Base): + name = fields.StringField() + +class Person(models.Base): + name = fields.StringField() + surname = fields.StringField() + age = fields.IntField() + dob = fields.DateField() + alive = fields.BoolField() + last_update = fields.DateTimeField() + address = fields.EmbeddedField(model_types=Address) + transport = fields.EmbeddedField(model_types=(Car, Boat)) + pet_names = fields.ListField(items_types=str) + nicknames = fields.DerivedListField(fields.StringField()) + +person = Person() +address = Address() diff --git a/tests_mypy/test_mypy_plugin.py b/tests_mypy/test_mypy_plugin.py new file mode 100644 index 0000000..88a9dff --- /dev/null +++ b/tests_mypy/test_mypy_plugin.py @@ -0,0 +1,49 @@ +from mypy import api +import os + + +EXPECT_LINE = "# expect: " +EXPECT_LINE_OUTPUT = "Revealed type is " + + +def test_file(directory: str, file_name: str) -> bool: + expected: list[str] = [] + file_path = os.path.join(directory, file_name) + with open(file_path, 'r') as f: + lines = f.readlines() + for line in lines: + if line.startswith(EXPECT_LINE): + expected.append(line[len(EXPECT_LINE):].strip()) + + result = api.run([ + "--config-file=../mypy_plugin.ini", + "--show-traceback", + file_path]) + + output_expected: list[str] = [] + for output_line in result[0].splitlines(): + index = output_line.find(EXPECT_LINE_OUTPUT) + if index > 0: + output_expected.append(output_line[index + len(EXPECT_LINE_OUTPUT):].strip().strip('"')) + + if expected == output_expected: + print(f"PASS {file_name}") + return True + else: + print(f"FAIL {file_name}\n") + print(f"Expected: {repr(expected)}") + print(f"Received: {repr(output_expected)}") + print("STDOUT----------------") + print(result[0]) + print(result[1]) + print("----------------------") + return False + +def main() -> None: + directory = '.' + files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.startswith("case_")] + + for file_name in files: + test_file(directory, file_name) + +main() From 06bee12ef0f5c5b30b5c1b06a7198e551b080205 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Mon, 29 Apr 2024 18:51:58 +0000 Subject: [PATCH 11/12] Add Mypy config for type checking main code --- mypy.ini | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..0fa0de0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +disallow_untyped_defs = True +disallow_any_unimported = True +no_implicit_optional = True +warn_return_any = True +warn_unused_configs = True +warn_unused_ignores = True +show_error_codes = True From 596f4884772d52971691d491f0b40842ab33e723 Mon Sep 17 00:00:00 2001 From: Simon Edwards Date: Thu, 2 May 2024 11:47:19 +0000 Subject: [PATCH 12/12] Add support for the rest of the field types --- jsonmodels/fields.py | 23 ++++++++-- jsonmodels/mypy_plugin.py | 83 +++++++++++++++++++++++++++++----- jsonmodels/types.py | 5 +- tests_mypy/case_alias.py | 11 +++++ tests_mypy/case_derivedlist.py | 3 ++ tests_mypy/case_generic.py | 4 ++ tests_mypy/case_map.py | 4 ++ tests_mypy/case_str.py | 1 - tests_mypy/case_subfield.py | 13 ++++++ tests_mypy/models.py | 9 +++- tests_mypy/test_mypy_plugin.py | 1 + 11 files changed, 137 insertions(+), 20 deletions(-) create mode 100644 tests_mypy/case_alias.py create mode 100644 tests_mypy/case_generic.py create mode 100644 tests_mypy/case_map.py create mode 100644 tests_mypy/case_subfield.py diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 97a7e56..e3602c2 100755 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -9,7 +9,7 @@ from .collections import ModelCollection from .errors import AmbiguousTypeError, BadTypeError, RequiredFieldError -from .types import BsonEncodable, EmbedType, Field, JSONValue, Model, Validator, ValidatorFunction, ValidatorObject, Value +from .types import BsonEncodable, EmbedType, Field, JSONValue, Model, PrimitiveTypeInstance, Validator, ValidatorFunction, ValidatorObject, Value # unique marker for "no default value specified". None is not good enough since @@ -358,12 +358,18 @@ class DerivedListField(ListField): A list field that has another field for its items. """ - def __init__(self, field: BaseField, *args: Any, **kwargs: Any): + def __init__(self, field: BaseField | PrimitiveTypeInstance, *args: Any, **kwargs: Any): """ - :param field: The field that will be in each of the items of the list. + :param field: The field instance that will be in each of the items of the list. :param help_text: The help text of the list field. :param validators: The validators for the list field. """ + # Note: It is a bit of a hack but the signature allows many primitive + # types even though in reality we only accept BaseField instances. + # The extra types are for the type checker and our Mypy plugin. + if not isinstance(field, BaseField): + raise BadTypeError(field, (BaseField,), is_list=False) + self._field = field fixed_kwargs = kwargs.copy() @@ -467,7 +473,7 @@ class MapField(BaseField): """ types: Tuple[Any, ...] = (dict,) - def __init__(self, key_field: Field, value_field: Field, + def __init__(self, key_field: BaseField | PrimitiveTypeInstance, value_field: BaseField | PrimitiveTypeInstance, **kwargs: Any): """ :param key_field: The field that is responsible for converting and @@ -477,7 +483,16 @@ def __init__(self, key_field: Field, value_field: Field, :param kwargs: Other keyword arguments to the base class. """ super(MapField, self).__init__(**kwargs) + + # Note: It is a bit of a hack but the signature allows many primitive + # types even though in reality we only accept BaseField instances. + # The extra types are for the type checker and our Mypy plugin. + if not isinstance(key_field, BaseField): + raise BadTypeError(key_field, (BaseField,), is_list=False) self._key_field = key_field + + if not isinstance(value_field, BaseField): + raise BadTypeError(value_field, (BaseField,), is_list=False) self._value_field = value_field def _finish_initialization(self, owner: type[Model]) -> None: diff --git a/jsonmodels/mypy_plugin.py b/jsonmodels/mypy_plugin.py index 10793bb..14ad722 100644 --- a/jsonmodels/mypy_plugin.py +++ b/jsonmodels/mypy_plugin.py @@ -2,29 +2,68 @@ import mypy from mypy.plugin import Plugin, AttributeContext, FunctionContext from mypy.types import Type as MypyType +from mypy.nodes import TypeInfo + + +JSONMODEL_TYPE = [ + "jsonmodels.fields.StringField", + "jsonmodels.fields.IntField", + "jsonmodels.fields.FloatField", + "jsonmodels.fields.BoolField", + "jsonmodels.fields.TimeField", + "jsonmodels.fields.DateField", + "jsonmodels.fields.DateTimeField", + "jsonmodels.fields.EmbeddedField", + "jsonmodels.fields.ListField", + "jsonmodels.fields.DerivedListField", + "jsonmodels.fields.MapField", + "jsonmodels.fields.GenericField" +] + class JSONModelsPlugin(Plugin): def get_function_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: - if fullname == "jsonmodels.fields.StringField": + jsonmodel_fullname: str + + sym = self.lookup_fully_qualified(fullname) + if sym is None: + return None + node = sym.node + if not isinstance(node, TypeInfo): + return None + + # Find a known jsonmodel field type in this type's class hierarchy. + for node in node.mro: + if node.fullname in JSONMODEL_TYPE: + jsonmodel_fullname = node.fullname + break + else: + return None + + if jsonmodel_fullname == "jsonmodels.fields.StringField": return self._string_field_callback - if fullname == "jsonmodels.fields.IntField": + if jsonmodel_fullname == "jsonmodels.fields.IntField": return self._int_field_callback - if fullname == "jsonmodels.fields.FloatField": + if jsonmodel_fullname == "jsonmodels.fields.FloatField": return self._float_field_callback - if fullname == "jsonmodels.fields.BoolField": + if jsonmodel_fullname == "jsonmodels.fields.BoolField": return self._bool_field_callback - if fullname == "jsonmodels.fields.TimeField": + if jsonmodel_fullname == "jsonmodels.fields.TimeField": return self._time_field_callback - if fullname == "jsonmodels.fields.DateField": + if jsonmodel_fullname == "jsonmodels.fields.DateField": return self._date_field_callback - if fullname == "jsonmodels.fields.DateTimeField": + if jsonmodel_fullname == "jsonmodels.fields.DateTimeField": return self._datetime_field_callback - if fullname == "jsonmodels.fields.EmbeddedField": + if jsonmodel_fullname == "jsonmodels.fields.EmbeddedField": return self._embedded_field_callback - if fullname == "jsonmodels.fields.ListField": - return self._list_field_callback - if fullname == "jsonmodels.fields.DerivedListField": + if jsonmodel_fullname == "jsonmodels.fields.ListField": return self._list_field_callback + if jsonmodel_fullname == "jsonmodels.fields.DerivedListField": + return self._derived_list_field_callback + if jsonmodel_fullname == "jsonmodels.fields.MapField": + return self._map_field_callback + if jsonmodel_fullname == "jsonmodels.fields.GenericField": + return self._generic_field_callback return None @@ -97,5 +136,27 @@ def _list_field_callback(self, ctx: FunctionContext) -> MypyType: list_type = ctx.api.named_generic_type("list", [item_type]) return self._wrap_nullable(ctx, list_type) + def _get_type_from_arg_type(self, ctx: FunctionContext, arg_name: str) -> MypyType: + try: + model_types_index = ctx.callee_arg_names.index(arg_name) + except ValueError: + return mypy.types.NoneType() + + return ctx.arg_types[model_types_index][0] + + def _derived_list_field_callback(self, ctx: FunctionContext) -> MypyType: + item_type = self._get_type_from_arg_type(ctx, "field") + list_type = ctx.api.named_generic_type("list", [item_type]) + return self._wrap_nullable(ctx, list_type) + + def _map_field_callback(self, ctx: FunctionContext) -> MypyType: + key_type = self._get_type_from_arg_type(ctx, "key_field") + value_type = self._get_type_from_arg_type(ctx, "value_field") + list_type = ctx.api.named_generic_type("dict", [key_type, value_type]) + return self._wrap_nullable(ctx, list_type) + + def _generic_field_callback(self, ctx: FunctionContext) -> MypyType: + return mypy.types.AnyType(mypy.types.TypeOfAny.special_form) + def plugin(version: str): return JSONModelsPlugin diff --git a/jsonmodels/types.py b/jsonmodels/types.py index acce75f..a225878 100644 --- a/jsonmodels/types.py +++ b/jsonmodels/types.py @@ -10,9 +10,6 @@ JSONObject = Dict[str, "JSONValue"] JSONValue = Union[None, bool, str, float, int, List["JSONValue"], JSONObject] -# JSONSchema = JSONValue -# JSONSchemaDict = JSONObject - JSONSchemaBasicTypeName = Literal["string"] | Literal["number"] | Literal["boolean"] | Literal["object"] | Literal["array"] | Literal["null"] JSONSchemaTypeName = JSONSchemaBasicTypeName | List[JSONSchemaBasicTypeName | Literal["null"]] @@ -155,3 +152,5 @@ def to_struct(self) -> JSONValue: ... EmbedType = Union[type[str], type[int], type[float], type[bool], type[list], type[dict], type[Model]] + +PrimitiveTypeInstance = str | int | float | bool diff --git a/tests_mypy/case_alias.py b/tests_mypy/case_alias.py new file mode 100644 index 0000000..f6a6cf6 --- /dev/null +++ b/tests_mypy/case_alias.py @@ -0,0 +1,11 @@ +from jsonmodels import models, fields + +first_name_field = fields.StringField + +class AliasModel(models.Base): + name = first_name_field() + +alias = AliasModel() + +reveal_type(alias.name) +# expect: builtins.str diff --git a/tests_mypy/case_derivedlist.py b/tests_mypy/case_derivedlist.py index 9ce2cb5..fefda8d 100644 --- a/tests_mypy/case_derivedlist.py +++ b/tests_mypy/case_derivedlist.py @@ -2,3 +2,6 @@ reveal_type(person.nicknames) # expect: builtins.list[builtins.str] + +reveal_type(person.alias_names) +# expect: builtins.list[builtins.str] diff --git a/tests_mypy/case_generic.py b/tests_mypy/case_generic.py new file mode 100644 index 0000000..46c1996 --- /dev/null +++ b/tests_mypy/case_generic.py @@ -0,0 +1,4 @@ +from models import car_registry + +reveal_type(car_registry.any_random) +# expect: Any diff --git a/tests_mypy/case_map.py b/tests_mypy/case_map.py new file mode 100644 index 0000000..dead5dc --- /dev/null +++ b/tests_mypy/case_map.py @@ -0,0 +1,4 @@ +from models import car_registry + +reveal_type(car_registry.registry) +# expect: builtins.dict[builtins.str, builtins.str] diff --git a/tests_mypy/case_str.py b/tests_mypy/case_str.py index dc5e077..fbc146f 100644 --- a/tests_mypy/case_str.py +++ b/tests_mypy/case_str.py @@ -1,4 +1,3 @@ - from models import person reveal_type(person.name) diff --git a/tests_mypy/case_subfield.py b/tests_mypy/case_subfield.py new file mode 100644 index 0000000..694d95f --- /dev/null +++ b/tests_mypy/case_subfield.py @@ -0,0 +1,13 @@ +from jsonmodels import models, fields + + +class SubStringField(fields.StringField): + pass + +class TestModel(models.Base): + name = SubStringField() + +test_instance = TestModel() + +reveal_type(test_instance.name) +# expect: builtins.str diff --git a/tests_mypy/models.py b/tests_mypy/models.py index b83f55b..8bc2e9b 100644 --- a/tests_mypy/models.py +++ b/tests_mypy/models.py @@ -21,7 +21,14 @@ class Person(models.Base): address = fields.EmbeddedField(model_types=Address) transport = fields.EmbeddedField(model_types=(Car, Boat)) pet_names = fields.ListField(items_types=str) - nicknames = fields.DerivedListField(fields.StringField()) + nicknames = fields.DerivedListField(field=fields.StringField()) + alias_names = fields.DerivedListField(fields.StringField()) + +class CarRegistry(models.Base): + registry = fields.MapField(key_field=fields.StringField(), value_field=fields.StringField()) + any_random = fields.GenericField() + person = Person() address = Address() +car_registry = CarRegistry() diff --git a/tests_mypy/test_mypy_plugin.py b/tests_mypy/test_mypy_plugin.py index 88a9dff..6238b43 100644 --- a/tests_mypy/test_mypy_plugin.py +++ b/tests_mypy/test_mypy_plugin.py @@ -42,6 +42,7 @@ def test_file(directory: str, file_name: str) -> bool: def main() -> None: directory = '.' files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.startswith("case_")] + files.sort() for file_name in files: test_file(directory, file_name)