diff --git a/Makefile b/Makefile index 8105fe2..30b0e64 100644 --- a/Makefile +++ b/Makefile @@ -8,11 +8,19 @@ install: ## Install the virtual environment and install the pre-commit hooks check: ## Run code quality tools. @echo "🚀 Checking lock file consistency with 'pyproject.toml'" @uv lock --locked + @echo "🚀 Running type check" + @uv run pyright @echo "🚀 Linting code: Running pre-commit" @uv run pre-commit run -a @echo "🚀 Checking for obsolete dependencies: Running deptry" @uv run deptry . +.PHONY: fmt +fmt: + @echo "Format and fix" + @uv run ruff check --fix + @uv run ruff format + .PHONY: test test: ## Test the code with pytest @echo "🚀 Testing code: Running pytest" diff --git a/punq/__init__.py b/punq/__init__.py index f3fd9ce..ae914e9 100644 --- a/punq/__init__.py +++ b/punq/__init__.py @@ -18,16 +18,17 @@ import contextlib import inspect -from collections import defaultdict, ChainMap +from collections import defaultdict from enum import Enum from importlib.metadata import PackageNotFoundError, version -from typing import Any, Callable, NamedTuple, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Self, TypeVar, get_type_hints, overload, Protocol -from ._compat import ensure_forward_ref, is_generic_list +from ._compat import ServiceKey, ensure_forward_ref, is_generic_list with contextlib.suppress(PackageNotFoundError): __version__ = version(__name__) +TService = TypeVar("TService", covariant=True) class MissingDependencyException(Exception): """Deprecated alias for MissingDependencyError.""" @@ -128,6 +129,7 @@ class InvalidForwardReferenceError(InvalidForwardReferenceException): pass + class RegistrationScope: """ Simple chained dictionary[service, list[implementation]]. @@ -154,8 +156,6 @@ def get(self, key): return self.__get(key, []) - - class Scope(Enum): """Controls the lifetime of resolved objects. @@ -168,12 +168,20 @@ class Scope(Enum): singleton = 1 -class _Registration(NamedTuple): +class _Registration(NamedTuple, Generic[TService]): + service: ServiceKey[TService] + scope: Scope + builder: Callable[..., TService] + needs: Any + args: dict + + +class _UntypedRegistration(NamedTuple): service: str scope: Scope - builder: Callable[[], Any] + builder: Callable[..., Any] needs: Any - args: list[Any] + args: dict class _Empty: @@ -248,8 +256,7 @@ def register_service_and_impl(self, service, scope, impl, resolve_args): Sending message via smtp: Hello """ self.__registrations.append( - service, - _Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args) + service, _Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args) ) def register_service_and_instance(self, service, instance): @@ -300,8 +307,7 @@ def register_concrete_service(self, service, scope, resolve_args=None): if not inspect.isclass(service): raise InvalidSelfRegistrationError(service) self.__registrations.append( - service, - _Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {}) + service, _Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {}) ) def build_context(self, key, existing=None): @@ -386,6 +392,26 @@ def __init__(self, registrations=None): self.register(Container, instance=self) self._singletons = {} + @overload + def register(self, service: str, factory: Callable[..., Any]) -> Self: + ... + + @overload + def register(self, service: str, *, instance: Any) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], factory: Callable[..., TService], *, scope=Scope.transient, **kwargs) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], *, instance: TService, scope=Scope.transient, **kwargs) -> Self: + ... + + @overload + def register(self, service: ServiceKey[TService], *, scope=Scope.transient, **kwargs) -> Self: + ... + def register(self, service, factory=empty, instance=empty, scope=Scope.transient, **kwargs): """Register a dependency into the container. @@ -524,6 +550,8 @@ def _resolve_impl(self, service_key, kwargs, context, default=None): target = context.target(service_key) + if TYPE_CHECKING: + assert target is not None if target.is_generic_list(): return self.resolve_all(target.generic_parameter) @@ -537,6 +565,14 @@ def _resolve_impl(self, service_key, kwargs, context, default=None): return self._build_impl(registration, kwargs, context) + @overload + def resolve(self, service_key: str, **kwargs) -> Any: + ... + + @overload + def resolve(self, service_key: ServiceKey[TService], **kwargs) -> TService: + ... + def resolve(self, service_key, **kwargs): """Build and return an instance of a registered service.""" context = self.registrations.build_context(service_key) diff --git a/punq/_compat.py b/punq/_compat.py index 5192eed..cd83605 100644 --- a/punq/_compat.py +++ b/punq/_compat.py @@ -1,6 +1,9 @@ -from typing import ForwardRef +import sys +import typing +import typing_extensions GenericListClass = list +ServiceKey = typing_extensions.TypeForm def is_generic_list(service): @@ -12,4 +15,4 @@ def is_generic_list(service): def ensure_forward_ref(self, service, factory, instance, **kwargs): if isinstance(service, str): - self.register(ForwardRef(service), factory, instance, **kwargs) + self.register(typing.ForwardRef(service), factory, instance, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 95511d0..46f4f5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ dev-dependencies = [ "attrs>=24.2.0", "mkdocs-include-markdown-plugin>=6.2.2", "xdoctest>=1.2.0", + "pyright>=1.1.391", + "typing-extensions>=4.12.2", ] [build-system] @@ -55,6 +57,9 @@ warn_return_any = true warn_unused_ignores = true show_error_codes = true +[tool.pyright] +enableExperimentalFeatures = true + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/tests/test_instance_creation.py b/tests/test_instance_creation.py index 489c23d..605245a 100644 --- a/tests/test_instance_creation.py +++ b/tests/test_instance_creation.py @@ -1,3 +1,4 @@ +import typing as t from tempfile import NamedTemporaryFile import pytest @@ -27,8 +28,11 @@ def test_dependencies_are_injected(): container.register(MessageSpeaker, HelloWorldSpeaker) speaker = container.resolve(MessageSpeaker) + t.assert_type(speaker, MessageSpeaker) expect(speaker).to(be_a(HelloWorldSpeaker)) + + speaker = t.cast(HelloWorldSpeaker, speaker) expect(speaker.writer).to(be_a(StdoutMessageWriter)) @@ -54,6 +58,9 @@ def test_can_register_with_a_custom_factory(): speaker = container.resolve(MessageSpeaker) + t.assert_type(speaker, MessageSpeaker) + speaker = t.cast(HelloWorldSpeaker, speaker) + expect(speaker).to(be_a(HelloWorldSpeaker)) expect(speaker.writer).to(equal("win")) @@ -104,7 +111,7 @@ def test_registering_an_instance_as_concrete_is_exception(): writer = MessageWriter() with pytest.raises(InvalidRegistrationError): - container.register(writer) + container.register(writer) # type: ignore def test_registering_an_instance_as_factory_is_exception(): @@ -116,7 +123,7 @@ def test_registering_an_instance_as_factory_is_exception(): writer = MessageWriter() with pytest.raises(InvalidRegistrationError): - container.register(MessageWriter, writer) + container.register(MessageWriter, writer) # type: ignore def test_registering_a_callable_as_concrete_is_exception(): @@ -129,7 +136,7 @@ def test_registering_a_callable_as_concrete_is_exception(): container = Container() with pytest.raises(InvalidRegistrationError): - container.register(lambda: "oops") + container.register(lambda: "oops") # type: ignore def test_can_provide_arguments_to_registrations(): @@ -137,8 +144,11 @@ def test_can_provide_arguments_to_registrations(): container.register(MessageWriter, FancyDbMessageWriter, cstr=lambda: "Hello world") writer = container.resolve(MessageWriter) + t.assert_type(writer, MessageWriter) expect(writer).to(be_a(FancyDbMessageWriter)) + + writer = t.cast(FancyDbMessageWriter, writer) expect(writer.connection_string).to(equal("Hello world")) @@ -147,6 +157,7 @@ def test_can_provide_arguments_to_resolve(): container.register(MessageWriter, TmpFileMessageWriter) instance = container.resolve(MessageWriter, path="foo") + instance = t.cast(TmpFileMessageWriter, instance) expect(instance.path).to(equal("foo")) @@ -156,6 +167,8 @@ def test_can_provide_arguments_to_resolve_having_dependencies(): container.register(MessageWriter, WrappingMessageWriter) instance = container.resolve(MessageWriter, context="bar") + + instance = t.cast(WrappingMessageWriter, instance) expect(instance.context).to(equal("bar")) @@ -165,15 +178,15 @@ def test_can_provide_typed_arguments_to_resolve(): container.register(TmpFileMessageWriter) container.register(HelloWorldSpeaker) - tmpfile = NamedTemporaryFile() + with NamedTemporaryFile() as tmpfile: - writer = container.resolve(MessageWriter, path=tmpfile.name) - speaker = container.resolve(HelloWorldSpeaker, writer=writer) + writer = container.resolve(MessageWriter, path=tmpfile.name) + speaker = container.resolve(HelloWorldSpeaker, writer=writer) - speaker.speak() + speaker.speak() - tmpfile.seek(0) - expect(tmpfile.read().decode()).to(equal("Hello World")) + tmpfile.seek(0) + expect(tmpfile.read().decode()).to(equal("Hello World")) def test_resolve_returns_the_latest_registration_for_a_service(): diff --git a/tests/test_kwonly_deps.py b/tests/test_kwonly_deps.py index 3e78664..0b72316 100644 --- a/tests/test_kwonly_deps.py +++ b/tests/test_kwonly_deps.py @@ -1,6 +1,8 @@ -from typing import Protocol +from typing import Protocol, Callable, reveal_type +from typing_extensions import TypeForm import punq +from punq._compat import ServiceKey class Parser(Protocol): @@ -42,6 +44,10 @@ def parse(self, val) -> str: def test_can_resolve_with_kwonlyargs(): container = punq.Container() result = [] + f: Callable[..., Parser] = ReverseParser + k: TypeForm[Parser] = Parser + + container.register(Parser, ReverseParser) container.register(Writer, instance=ListWriter(result)) container.register(Doer) diff --git a/tests/test_list_resolution.py b/tests/test_list_resolution.py index 135ed7d..6160a78 100644 --- a/tests/test_list_resolution.py +++ b/tests/test_list_resolution.py @@ -1,4 +1,5 @@ -from expects import expect, have_len +import typing +from expects import be_a, expect, have_len from punq import Container from tests.test_dependencies import MessageSpeaker, MessageWriter, StdoutMessageWriter, TmpFileMessageWriter @@ -24,4 +25,6 @@ def __init__(self, writers: list[MessageWriter]) -> None: instance = container.resolve(MessageSpeaker) + expect(instance).to(be_a(BroadcastSpeaker)) + instance = typing.cast(BroadcastSpeaker, instance) expect(instance.writers).to(have_len(2)) diff --git a/tests/test_resolution_scope.py b/tests/test_resolution_scope.py index e144279..f24eb89 100644 --- a/tests/test_resolution_scope.py +++ b/tests/test_resolution_scope.py @@ -7,7 +7,7 @@ In order to handle scoping, we need our own datastructure. """ -from collections import defaultdict + from punq import RegistrationScope @@ -19,6 +19,7 @@ def test_a_root_scope_returns_the_empty_list_when_nothing_is_registered(): scope = RegistrationScope() assert scope.get("some_key") == [] + def test_a_scope_contains_items(): """ We can add items into a scope and get them back. @@ -30,6 +31,7 @@ def test_a_scope_contains_items(): assert scope.get("some-key") == ["hello", "world"] + def test_a_child_scope_extends_its_parent(): """ When a child scope adds an item, it should be added to the list of @@ -39,11 +41,12 @@ def test_a_child_scope_extends_its_parent(): child = parent.child() parent.append("some-key", "hello") - child.append("some-key","world") + child.append("some-key", "world") assert child.get("some-key") == ["hello", "world"] assert parent.get("some-key") == ["hello"] + def test_resolution_can_skip_a_level(): """ If someone goes nuts, the registrations should inherit across multiple @@ -64,13 +67,13 @@ def test_resolution_can_skip_a_level(): assert parent.get("a") == [1] assert child.get("a") == [1] - assert grandparent.get("b") == [ 2 ] - assert parent.get("b") == [ 2 ] + assert grandparent.get("b") == [2] + assert parent.get("b") == [2] assert child.get("b") == [2, "x"] assert grandparent.get("c") == [] - assert parent.get("c") == [ 3 ] - assert child.get("c") == [ 3 ] + assert parent.get("c") == [3] + assert child.get("c") == [3] assert grandparent.get("d") == [] assert parent.get("d") == [] diff --git a/tests/test_scoped_resolution.py b/tests/test_scoped_resolution.py index 088e43d..2eaa211 100644 --- a/tests/test_scoped_resolution.py +++ b/tests/test_scoped_resolution.py @@ -1,16 +1,15 @@ -from expects import be_a, equal, expect +import typing + import pytest +from expects import be_a, expect -from punq import Container, InvalidRegistrationError, MissingDependencyError, Scope +from punq import Container, MissingDependencyError from tests.test_dependencies import ( ConnectionStringFactory, FancyDbMessageWriter, - HelloWorldSpeaker, - MessageSpeaker, MessageWriter, StdoutMessageWriter, TmpFileMessageWriter, - WrappingMessageWriter, ) @@ -21,6 +20,7 @@ def test_scoped_service_with_no_dependencies(): child = container.child() expect(child.resolve(MessageWriter)).to(be_a(StdoutMessageWriter)) + def test_when_overriding_a_service(): """ In this test, we replace the parent registration completely. @@ -89,8 +89,8 @@ def test_when_overriding_a_singleton_instance(): parent.register(ConnectionStringFactory, instance=lambda: "hello") child.register(ConnectionStringFactory, instance=lambda: "world") - assert parent.resolve(MessageWriter).connection_string == "hello" - assert child.resolve(MessageWriter).connection_string == "world" + assert typing.cast(FancyDbMessageWriter, parent.resolve(MessageWriter)).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, child.resolve(MessageWriter)).connection_string == "world" def test_when_inheriting_a_singleton_instance(): @@ -105,19 +105,19 @@ def test_when_inheriting_a_singleton_instance(): parent.register(MessageWriter, FancyDbMessageWriter) parent.register(ConnectionStringFactory, instance=lambda: "hello") - assert parent.resolve(MessageWriter).connection_string == "hello" - assert child.resolve(MessageWriter).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, parent.resolve(MessageWriter)).connection_string == "hello" + assert typing.cast(FancyDbMessageWriter, child.resolve(MessageWriter)).connection_string == "hello" ContextBag = dict -class ThingDoer: +class ThingDoer: def __init__(self, context: ContextBag): self.context = context -def test_when_registering_a_state_bag(): +def test_when_registering_a_state_bag(): parent = Container() parent.register(ThingDoer) diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..b51f7a5 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,32 @@ +from typing import Callable, Generic, Type, TypeVar, Protocol, Union + +# A flexible TypeVar for TService +TService = TypeVar("TService", bound=object) # Must be an object type +TProtocol = TypeVar("TProtocol", bound=Protocol) + +class _Registration(Generic[TService]): + def __init__( + self, + service: Union[Type[TService], Protocol], # Supports a class or a protocol + builder: Callable[..., TService], # Builder must return TService or compatible type + ): + self.service = service + self.builder = builder + + # Runtime validation for classes (if service is a class) + if isinstance(service, type): + instance = builder() + if not isinstance(instance, service): + raise TypeError(f"Builder does not produce an instance of {service}") + + +class Parent(Protocol): + def do_something(self) -> None: + pass + +class Child: + def do_something(self) -> None: + print("Child doing something") + + +f: Callable[..., Parent] = Child diff --git a/uv.lock b/uv.lock index 75cfb8d..0b2521a 100644 --- a/uv.lock +++ b/uv.lock @@ -748,11 +748,13 @@ dev = [ { name = "mkdocs-material" }, { name = "mkdocstrings", extra = ["python"] }, { name = "pre-commit" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, { name = "sqlalchemy" }, { name = "tox-uv" }, + { name = "typing-extensions" }, { name = "xdoctest" }, ] @@ -768,11 +770,13 @@ dev = [ { name = "mkdocs-material", specifier = ">=8.5.10" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.26.1" }, { name = "pre-commit", specifier = ">=2.20.0" }, + { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=7.2.0" }, { name = "pytest-cov", specifier = ">=4.0.0" }, { name = "ruff", specifier = ">=0.6.9" }, { name = "sqlalchemy", specifier = ">=2.0.36" }, { name = "tox-uv", specifier = ">=1.11.3" }, + { name = "typing-extensions", specifier = ">=4.12.2" }, { name = "xdoctest", specifier = ">=1.2.0" }, ] @@ -811,6 +815,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/f4/3c4ddfcc0c19c217c6de513842d286de8021af2f2ab79bbb86c00342d778/pyproject_api-1.8.0-py3-none-any.whl", hash = "sha256:3d7d347a047afe796fd5d1885b1e391ba29be7169bd2f102fcd378f04273d228", size = 13100 }, ] +[[package]] +name = "pyright" +version = "1.1.391" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, +] + [[package]] name = "pytest" version = "8.3.3"