From 98c93d29dce05f06413285ee1d56bd2854f0ebd5 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 13:29:53 +0000 Subject: [PATCH 1/8] Ruffen --- Makefile | 6 ++++++ punq/__init__.py | 11 ++++------- tests/test_instance_creation.py | 12 ++++++------ tests/test_resolution_scope.py | 15 +++++++++------ tests/test_scoped_resolution.py | 12 +++++------- 5 files changed, 30 insertions(+), 26 deletions(-) diff --git a/Makefile b/Makefile index 8105fe2..0d36307 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,12 @@ check: ## Run code quality tools. @echo "🚀 Checking for obsolete dependencies: Running deptry" @uv run deptry . +.PHONY: fmt +fmt: + @echo "Format and fix" + @uv run ruff check --fix + @uv run ruff format + .PHONY: test test: ## Test the code with pytest @echo "🚀 Testing code: Running pytest" diff --git a/punq/__init__.py b/punq/__init__.py index f3fd9ce..3fc019e 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -18,7 +18,7 @@ import contextlib import inspect -from collections import defaultdict, ChainMap +from collections import defaultdict from enum import Enum from importlib.metadata import PackageNotFoundError, version from typing import Any, Callable, NamedTuple, get_type_hints @@ -128,6 +128,7 @@ class InvalidForwardReferenceError(InvalidForwardReferenceException): pass + class RegistrationScope: """ Simple chained dictionary[service, list[implementation]]. @@ -154,8 +155,6 @@ def get(self, key): return self.__get(key, []) - - class Scope(Enum): """Controls the lifetime of resolved objects. @@ -248,8 +247,7 @@ def register_service_and_impl(self, service, scope, impl, resolve_args): Sending message via smtp: Hello """ self.__registrations.append( - service, - _Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args) + service, _Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args) ) def register_service_and_instance(self, service, instance): @@ -300,8 +298,7 @@ def register_concrete_service(self, service, scope, resolve_args=None): if not inspect.isclass(service): raise InvalidSelfRegistrationError(service) self.__registrations.append( - service, - _Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {}) + service, _Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {}) ) def build_context(self, key, existing=None): diff --git a/tests/test_instance_creation.py b/tests/test_instance_creation.py index 489c23d..01b7b28 100644 --- a/tests/test_instance_creation.py +++ b/tests/test_instance_creation.py @@ -165,15 +165,15 @@ def test_can_provide_typed_arguments_to_resolve(): container.register(TmpFileMessageWriter) container.register(HelloWorldSpeaker) - tmpfile = NamedTemporaryFile() + with NamedTemporaryFile() as tmpfile: - writer = container.resolve(MessageWriter, path=tmpfile.name) - speaker = container.resolve(HelloWorldSpeaker, writer=writer) + writer = container.resolve(MessageWriter, path=tmpfile.name) + speaker = container.resolve(HelloWorldSpeaker, writer=writer) - speaker.speak() + speaker.speak() - tmpfile.seek(0) - expect(tmpfile.read().decode()).to(equal("Hello World")) + tmpfile.seek(0) + expect(tmpfile.read().decode()).to(equal("Hello World")) def test_resolve_returns_the_latest_registration_for_a_service(): diff --git a/tests/test_resolution_scope.py b/tests/test_resolution_scope.py index e144279..f24eb89 100644 --- a/tests/test_resolution_scope.py +++ b/tests/test_resolution_scope.py @@ -7,7 +7,7 @@ In order to handle scoping, we need our own datastructure. """ -from collections import defaultdict + from punq import RegistrationScope @@ -19,6 +19,7 @@ def test_a_root_scope_returns_the_empty_list_when_nothing_is_registered(): scope = RegistrationScope() assert scope.get("some_key") == [] + def test_a_scope_contains_items(): """ We can add items into a scope and get them back. @@ -30,6 +31,7 @@ def test_a_scope_contains_items(): assert scope.get("some-key") == ["hello", "world"] + def test_a_child_scope_extends_its_parent(): """ When a child scope adds an item, it should be added to the list of @@ -39,11 +41,12 @@ def test_a_child_scope_extends_its_parent(): child = parent.child() parent.append("some-key", "hello") - child.append("some-key","world") + child.append("some-key", "world") assert child.get("some-key") == ["hello", "world"] assert parent.get("some-key") == ["hello"] + def test_resolution_can_skip_a_level(): """ If someone goes nuts, the registrations should inherit across multiple @@ -64,13 +67,13 @@ def test_resolution_can_skip_a_level(): assert parent.get("a") == [1] assert child.get("a") == [1] - assert grandparent.get("b") == [ 2 ] - assert parent.get("b") == [ 2 ] + assert grandparent.get("b") == [2] + assert parent.get("b") == [2] assert child.get("b") == [2, "x"] assert grandparent.get("c") == [] - assert parent.get("c") == [ 3 ] - assert child.get("c") == [ 3 ] + assert parent.get("c") == [3] + assert child.get("c") == [3] assert grandparent.get("d") == [] assert parent.get("d") == [] diff --git a/tests/test_scoped_resolution.py b/tests/test_scoped_resolution.py index 088e43d..cc6684e 100644 --- a/tests/test_scoped_resolution.py +++ b/tests/test_scoped_resolution.py @@ -1,16 +1,13 @@ -from expects import be_a, equal, expect import pytest +from expects import be_a, expect -from punq import Container, InvalidRegistrationError, MissingDependencyError, Scope +from punq import Container, MissingDependencyError from tests.test_dependencies import ( ConnectionStringFactory, FancyDbMessageWriter, - HelloWorldSpeaker, - MessageSpeaker, MessageWriter, StdoutMessageWriter, TmpFileMessageWriter, - WrappingMessageWriter, ) @@ -21,6 +18,7 @@ def test_scoped_service_with_no_dependencies(): child = container.child() expect(child.resolve(MessageWriter)).to(be_a(StdoutMessageWriter)) + def test_when_overriding_a_service(): """ In this test, we replace the parent registration completely. @@ -111,13 +109,13 @@ def test_when_inheriting_a_singleton_instance(): ContextBag = dict -class ThingDoer: +class ThingDoer: def __init__(self, context: ContextBag): self.context = context -def test_when_registering_a_state_bag(): +def test_when_registering_a_state_bag(): parent = Container() parent.register(ThingDoer) From b0ec91fab423c0f22e02ea42befed43c0b492378 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 13:36:20 +0000 Subject: [PATCH 2/8] 1 down, 56 to go! --- Makefile | 2 ++ punq/__init__.py | 2 +- pyproject.toml | 1 + uv.lock | 15 +++++++++++++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 0d36307..30b0e64 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ install: ## Install the virtual environment and install the pre-commit hooks check: ## Run code quality tools. @echo "🚀 Checking lock file consistency with 'pyproject.toml'" @uv lock --locked + @echo "🚀 Running type check" + @uv run pyright @echo "🚀 Linting code: Running pre-commit" @uv run pre-commit run -a @echo "🚀 Checking for obsolete dependencies: Running deptry" diff --git a/punq/__init__.py b/punq/__init__.py index 3fc019e..52aafbc 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -276,7 +276,7 @@ def register_service_and_instance(self, service, instance): ... ) """ - self.__registrations.append(service, _Registration(service, Scope.singleton, lambda: instance, {}, {})) + self.__registrations.append(service, _Registration(service, Scope.singleton, lambda: instance, {}, [])) def register_concrete_service(self, service, scope, resolve_args=None): """Register a service as its own implementation. diff --git a/pyproject.toml b/pyproject.toml index 95511d0..a6ee98c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dev-dependencies = [ "attrs>=24.2.0", "mkdocs-include-markdown-plugin>=6.2.2", "xdoctest>=1.2.0", + "pyright>=1.1.391", ] [build-system] diff --git a/uv.lock b/uv.lock index 75cfb8d..4c7574c 100644 --- a/uv.lock +++ b/uv.lock @@ -748,6 +748,7 @@ dev = [ { name = "mkdocs-material" }, { name = "mkdocstrings", extra = ["python"] }, { name = "pre-commit" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, @@ -768,6 +769,7 @@ dev = [ { name = "mkdocs-material", specifier = ">=8.5.10" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.26.1" }, { name = "pre-commit", specifier = ">=2.20.0" }, + { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=7.2.0" }, { name = "pytest-cov", specifier = ">=4.0.0" }, { name = "ruff", specifier = ">=0.6.9" }, @@ -811,6 +813,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/f4/3c4ddfcc0c19c217c6de513842d286de8021af2f2ab79bbb86c00342d778/pyproject_api-1.8.0-py3-none-any.whl", hash = "sha256:3d7d347a047afe796fd5d1885b1e391ba29be7169bd2f102fcd378f04273d228", size = 13100 }, ] +[[package]] +name = "pyright" +version = "1.1.391" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, +] + [[package]] name = "pytest" version = "8.3.3" From d8efa4c22cde3e360d68f754e81a845b94b7945c Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 14:03:45 +0000 Subject: [PATCH 3/8] temp --- punq/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/punq/__init__.py b/punq/__init__.py index 52aafbc..692957d 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -29,6 +29,8 @@ __version__ = version(__name__) +ServiceKey = type + class MissingDependencyException(Exception): """Deprecated alias for MissingDependencyError.""" From 5b5dc667b8f773934ca10f229ddc688d87607495 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 14:46:42 +0000 Subject: [PATCH 4/8] Add first test for registration --- punq/__init__.py | 12 ++++++------ punq/_compat.py | 9 +++++++-- tests/test_typing.py | 8 ++++++++ 3 files changed, 21 insertions(+), 8 deletions(-) create mode 100644 tests/test_typing.py diff --git a/punq/__init__.py b/punq/__init__.py index 692957d..cdacb19 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -21,15 +21,15 @@ from collections import defaultdict from enum import Enum from importlib.metadata import PackageNotFoundError, version -from typing import Any, Callable, NamedTuple, get_type_hints +from typing import Any, Callable, NamedTuple, get_type_hints, TypeVar, Generic -from ._compat import ensure_forward_ref, is_generic_list +from ._compat import ensure_forward_ref, is_generic_list, ServiceKey with contextlib.suppress(PackageNotFoundError): __version__ = version(__name__) +TService = TypeVar("TService") -ServiceKey = type class MissingDependencyException(Exception): """Deprecated alias for MissingDependencyError.""" @@ -169,10 +169,10 @@ class Scope(Enum): singleton = 1 -class _Registration(NamedTuple): - service: str +class _Registration(NamedTuple, Generic[TService]): + service: ServiceKey[TService] scope: Scope - builder: Callable[[], Any] + builder: Callable[..., TService] needs: Any args: list[Any] diff --git a/punq/_compat.py b/punq/_compat.py index 5192eed..1c62119 100644 --- a/punq/_compat.py +++ b/punq/_compat.py @@ -1,6 +1,11 @@ -from typing import ForwardRef +import sys +import typing GenericListClass = list +if sys.version_info >= (3, 11): + ServiceKey = type +else: + ServiceKey = typing.Type def is_generic_list(service): @@ -12,4 +17,4 @@ def is_generic_list(service): def ensure_forward_ref(self, service, factory, instance, **kwargs): if isinstance(service, str): - self.register(ForwardRef(service), factory, instance, **kwargs) + self.register(typing.ForwardRef(service), factory, instance, **kwargs) diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..bc83f58 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,8 @@ +from typing import assert_type + +import punq as pq +from . import test_dependencies as d + +container = pq.Container() + +registration = pq._Registration(type[d.MessageWriter], pq.Scope.transient, d.TmpFileMessageWriter, pq.empty, []) From 4d95ef610f61939b525b1761f53854babf777cc6 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 18:04:54 +0000 Subject: [PATCH 5/8] Making headway --- punq/__init__.py | 46 ++++++++++++++++++++++++++++++--- tests/test_instance_creation.py | 19 +++++++++++--- tests/test_typing.py | 8 +++++- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/punq/__init__.py b/punq/__init__.py index cdacb19..a7d7a23 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -21,14 +21,14 @@ from collections import defaultdict from enum import Enum from importlib.metadata import PackageNotFoundError, version -from typing import Any, Callable, NamedTuple, get_type_hints, TypeVar, Generic +from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Self, TypeVar, get_type_hints, overload -from ._compat import ensure_forward_ref, is_generic_list, ServiceKey +from ._compat import ServiceKey, ensure_forward_ref, is_generic_list with contextlib.suppress(PackageNotFoundError): __version__ = version(__name__) -TService = TypeVar("TService") +TService = TypeVar("TService", covariant=True) class MissingDependencyException(Exception): @@ -174,7 +174,15 @@ class _Registration(NamedTuple, Generic[TService]): scope: Scope builder: Callable[..., TService] needs: Any - args: list[Any] + args: dict + + +class _UntypedRegistration(NamedTuple): + service: str + scope: Scope + builder: Callable[..., Any] + needs: Any + args: dict class _Empty: @@ -385,6 +393,26 @@ def __init__(self, registrations=None): self.register(Container, instance=self) self._singletons = {} + @overload + def register(self, service: str, factory: Callable[..., Any]) -> Self: + ... + + @overload + def register(self, service: str, *, instance: Any) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], factory: Callable[..., TService], *, scope=Scope.transient, **kwargs) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], *, instance: TService, scope=Scope.transient, **kwargs) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], *, scope=Scope.transient, **kwargs) -> Self: + ... + def register(self, service, factory=empty, instance=empty, scope=Scope.transient, **kwargs): """Register a dependency into the container. @@ -523,6 +551,8 @@ def _resolve_impl(self, service_key, kwargs, context, default=None): target = context.target(service_key) + if TYPE_CHECKING: + assert target is not None if target.is_generic_list(): return self.resolve_all(target.generic_parameter) @@ -536,6 +566,14 @@ def _resolve_impl(self, service_key, kwargs, context, default=None): return self._build_impl(registration, kwargs, context) + @overload + def resolve(self, service_key: str, **kwargs) -> Any: + ... + + @overload + def resolve(self, service_key: ServiceKey[TService], **kwargs) -> TService: + ... + def resolve(self, service_key, **kwargs): """Build and return an instance of a registered service.""" context = self.registrations.build_context(service_key) diff --git a/tests/test_instance_creation.py b/tests/test_instance_creation.py index 01b7b28..605245a 100644 --- a/tests/test_instance_creation.py +++ b/tests/test_instance_creation.py @@ -1,3 +1,4 @@ +import typing as t from tempfile import NamedTemporaryFile import pytest @@ -27,8 +28,11 @@ def test_dependencies_are_injected(): container.register(MessageSpeaker, HelloWorldSpeaker) speaker = container.resolve(MessageSpeaker) + t.assert_type(speaker, MessageSpeaker) expect(speaker).to(be_a(HelloWorldSpeaker)) + + speaker = t.cast(HelloWorldSpeaker, speaker) expect(speaker.writer).to(be_a(StdoutMessageWriter)) @@ -54,6 +58,9 @@ def test_can_register_with_a_custom_factory(): speaker = container.resolve(MessageSpeaker) + t.assert_type(speaker, MessageSpeaker) + speaker = t.cast(HelloWorldSpeaker, speaker) + expect(speaker).to(be_a(HelloWorldSpeaker)) expect(speaker.writer).to(equal("win")) @@ -104,7 +111,7 @@ def test_registering_an_instance_as_concrete_is_exception(): writer = MessageWriter() with pytest.raises(InvalidRegistrationError): - container.register(writer) + container.register(writer) # type: ignore def test_registering_an_instance_as_factory_is_exception(): @@ -116,7 +123,7 @@ def test_registering_an_instance_as_factory_is_exception(): writer = MessageWriter() with pytest.raises(InvalidRegistrationError): - container.register(MessageWriter, writer) + container.register(MessageWriter, writer) # type: ignore def test_registering_a_callable_as_concrete_is_exception(): @@ -129,7 +136,7 @@ def test_registering_a_callable_as_concrete_is_exception(): container = Container() with pytest.raises(InvalidRegistrationError): - container.register(lambda: "oops") + container.register(lambda: "oops") # type: ignore def test_can_provide_arguments_to_registrations(): @@ -137,8 +144,11 @@ def test_can_provide_arguments_to_registrations(): container.register(MessageWriter, FancyDbMessageWriter, cstr=lambda: "Hello world") writer = container.resolve(MessageWriter) + t.assert_type(writer, MessageWriter) expect(writer).to(be_a(FancyDbMessageWriter)) + + writer = t.cast(FancyDbMessageWriter, writer) expect(writer.connection_string).to(equal("Hello world")) @@ -147,6 +157,7 @@ def test_can_provide_arguments_to_resolve(): container.register(MessageWriter, TmpFileMessageWriter) instance = container.resolve(MessageWriter, path="foo") + instance = t.cast(TmpFileMessageWriter, instance) expect(instance.path).to(equal("foo")) @@ -156,6 +167,8 @@ def test_can_provide_arguments_to_resolve_having_dependencies(): container.register(MessageWriter, WrappingMessageWriter) instance = container.resolve(MessageWriter, context="bar") + + instance = t.cast(WrappingMessageWriter, instance) expect(instance.context).to(equal("bar")) diff --git a/tests/test_typing.py b/tests/test_typing.py index bc83f58..63afb64 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,8 +1,14 @@ -from typing import assert_type +import typing as t import punq as pq + from . import test_dependencies as d container = pq.Container() registration = pq._Registration(type[d.MessageWriter], pq.Scope.transient, d.TmpFileMessageWriter, pq.empty, []) + + +registration = pq._UntypedRegistration("My type", pq.Scope.transient, d.TmpFileMessageWriter, pq.empty, []) + +t.assert_type(registration.builder, t.Callable[..., t.Any]) From 0e5a6a3f2a0eff3a12cbd27f6f5bde7efae16cb8 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Sun, 12 Jan 2025 20:55:57 +0000 Subject: [PATCH 6/8] go bakc to dict --- punq/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/punq/__init__.py b/punq/__init__.py index a7d7a23..a94e714 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -286,7 +286,7 @@ def register_service_and_instance(self, service, instance): ... ) """ - self.__registrations.append(service, _Registration(service, Scope.singleton, lambda: instance, {}, [])) + self.__registrations.append(service, _Registration(service, Scope.singleton, lambda: instance, {}, {})) def register_concrete_service(self, service, scope, resolve_args=None): """Register a service as its own implementation. From e300892d24eefa45a4eb749720a819687ebd389c Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Mon, 13 Jan 2025 02:05:28 +0000 Subject: [PATCH 7/8] Use typing_extensions --- punq/__init__.py | 3 +-- punq/_compat.py | 5 +++-- pyproject.toml | 4 ++++ tests/test_kwonly_deps.py | 8 +++++++- tests/test_typing.py | 32 +++++++++++++++++++++++++------- uv.lock | 2 ++ 6 files changed, 42 insertions(+), 12 deletions(-) diff --git a/punq/__init__.py b/punq/__init__.py index a94e714..ae914e9 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -21,7 +21,7 @@ from collections import defaultdict from enum import Enum from importlib.metadata import PackageNotFoundError, version -from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Self, TypeVar, get_type_hints, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Self, TypeVar, get_type_hints, overload, Protocol from ._compat import ServiceKey, ensure_forward_ref, is_generic_list @@ -30,7 +30,6 @@ TService = TypeVar("TService", covariant=True) - class MissingDependencyException(Exception): """Deprecated alias for MissingDependencyError.""" diff --git a/punq/_compat.py b/punq/_compat.py index 1c62119..31b07f3 100644 --- a/punq/_compat.py +++ b/punq/_compat.py @@ -1,11 +1,12 @@ import sys import typing +import typing_extensions GenericListClass = list if sys.version_info >= (3, 11): - ServiceKey = type + ServiceKey = typing_extensions.TypeForm else: - ServiceKey = typing.Type + ServiceKey = typing_extensions.TypeForm def is_generic_list(service): diff --git a/pyproject.toml b/pyproject.toml index a6ee98c..46f4f5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dev-dependencies = [ "mkdocs-include-markdown-plugin>=6.2.2", "xdoctest>=1.2.0", "pyright>=1.1.391", + "typing-extensions>=4.12.2", ] [build-system] @@ -56,6 +57,9 @@ warn_return_any = true warn_unused_ignores = true show_error_codes = true +[tool.pyright] +enableExperimentalFeatures = true + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/tests/test_kwonly_deps.py b/tests/test_kwonly_deps.py index 3e78664..0b72316 100644 --- a/tests/test_kwonly_deps.py +++ b/tests/test_kwonly_deps.py @@ -1,6 +1,8 @@ -from typing import Protocol +from typing import Protocol, Callable, reveal_type +from typing_extensions import TypeForm import punq +from punq._compat import ServiceKey class Parser(Protocol): @@ -42,6 +44,10 @@ def parse(self, val) -> str: def test_can_resolve_with_kwonlyargs(): container = punq.Container() result = [] + f: Callable[..., Parser] = ReverseParser + k: TypeForm[Parser] = Parser + + container.register(Parser, ReverseParser) container.register(Writer, instance=ListWriter(result)) container.register(Doer) diff --git a/tests/test_typing.py b/tests/test_typing.py index 63afb64..b51f7a5 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,14 +1,32 @@ -import typing as t +from typing import Callable, Generic, Type, TypeVar, Protocol, Union -import punq as pq +# A flexible TypeVar for TService +TService = TypeVar("TService", bound=object) # Must be an object type +TProtocol = TypeVar("TProtocol", bound=Protocol) -from . import test_dependencies as d +class _Registration(Generic[TService]): + def __init__( + self, + service: Union[Type[TService], Protocol], # Supports a class or a protocol + builder: Callable[..., TService], # Builder must return TService or compatible type + ): + self.service = service + self.builder = builder -container = pq.Container() + # Runtime validation for classes (if service is a class) + if isinstance(service, type): + instance = builder() + if not isinstance(instance, service): + raise TypeError(f"Builder does not produce an instance of {service}") -registration = pq._Registration(type[d.MessageWriter], pq.Scope.transient, d.TmpFileMessageWriter, pq.empty, []) +class Parent(Protocol): + def do_something(self) -> None: + pass -registration = pq._UntypedRegistration("My type", pq.Scope.transient, d.TmpFileMessageWriter, pq.empty, []) +class Child: + def do_something(self) -> None: + print("Child doing something") -t.assert_type(registration.builder, t.Callable[..., t.Any]) + +f: Callable[..., Parent] = Child diff --git a/uv.lock b/uv.lock index 4c7574c..0b2521a 100644 --- a/uv.lock +++ b/uv.lock @@ -754,6 +754,7 @@ dev = [ { name = "ruff" }, { name = "sqlalchemy" }, { name = "tox-uv" }, + { name = "typing-extensions" }, { name = "xdoctest" }, ] @@ -775,6 +776,7 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, { name = "sqlalchemy", specifier = ">=2.0.36" }, { name = "tox-uv", specifier = ">=1.11.3" }, + { name = "typing-extensions", specifier = ">=4.12.2" }, { name = "xdoctest", specifier = ">=1.2.0" }, ] From 2a285015e1b0e64e378b6cba2cf8748d935674e5 Mon Sep 17 00:00:00 2001 From: Bob Gregory Date: Mon, 13 Jan 2025 02:11:03 +0000 Subject: [PATCH 8/8] Progress --- punq/_compat.py | 5 +---- tests/test_list_resolution.py | 5 ++++- tests/test_scoped_resolution.py | 10 ++++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/punq/_compat.py b/punq/_compat.py index 31b07f3..cd83605 100644 --- a/punq/_compat.py +++ b/punq/_compat.py @@ -3,10 +3,7 @@ import typing_extensions GenericListClass = list -if sys.version_info >= (3, 11): - ServiceKey = typing_extensions.TypeForm -else: - ServiceKey = typing_extensions.TypeForm +ServiceKey = typing_extensions.TypeForm def is_generic_list(service): diff --git a/tests/test_list_resolution.py b/tests/test_list_resolution.py index 135ed7d..6160a78 100644 --- a/tests/test_list_resolution.py +++ b/tests/test_list_resolution.py @@ -1,4 +1,5 @@ -from expects import expect, have_len +import typing +from expects import be_a, expect, have_len from punq import Container from tests.test_dependencies import MessageSpeaker, MessageWriter, StdoutMessageWriter, TmpFileMessageWriter @@ -24,4 +25,6 @@ def __init__(self, writers: list[MessageWriter]) -> None: instance = container.resolve(MessageSpeaker) + expect(instance).to(be_a(BroadcastSpeaker)) + instance = typing.cast(BroadcastSpeaker, instance) expect(instance.writers).to(have_len(2)) diff --git a/tests/test_scoped_resolution.py b/tests/test_scoped_resolution.py index cc6684e..2eaa211 100644 --- a/tests/test_scoped_resolution.py +++ b/tests/test_scoped_resolution.py @@ -1,3 +1,5 @@ +import typing + import pytest from expects import be_a, expect @@ -87,8 +89,8 @@ def test_when_overriding_a_singleton_instance(): parent.register(ConnectionStringFactory, instance=lambda: "hello") child.register(ConnectionStringFactory, instance=lambda: "world") - assert parent.resolve(MessageWriter).connection_string == "hello" - assert child.resolve(MessageWriter).connection_string == "world" + assert typing.cast(FancyDbMessageWriter, parent.resolve(MessageWriter)).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, child.resolve(MessageWriter)).connection_string == "world" def test_when_inheriting_a_singleton_instance(): @@ -103,8 +105,8 @@ def test_when_inheriting_a_singleton_instance(): parent.register(MessageWriter, FancyDbMessageWriter) parent.register(ConnectionStringFactory, instance=lambda: "hello") - assert parent.resolve(MessageWriter).connection_string == "hello" - assert child.resolve(MessageWriter).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, parent.resolve(MessageWriter)).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, child.resolve(MessageWriter)).connection_string == "hello" ContextBag = dict