Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@ install: ## Install the virtual environment and install the pre-commit hooks
check: ## Run code quality tools.
@echo "🚀 Checking lock file consistency with 'pyproject.toml'"
@uv lock --locked
@echo "🚀 Running type check"
@uv run pyright
@echo "🚀 Linting code: Running pre-commit"
@uv run pre-commit run -a
@echo "🚀 Checking for obsolete dependencies: Running deptry"
@uv run deptry .

.PHONY: fmt
fmt:
@echo "Format and fix"
@uv run ruff check --fix
@uv run ruff format

.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
Expand Down
60 changes: 48 additions & 12 deletions punq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

import contextlib
import inspect
from collections import defaultdict, ChainMap
from collections import defaultdict
from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from typing import Any, Callable, NamedTuple, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Self, TypeVar, get_type_hints, overload, Protocol

from ._compat import ensure_forward_ref, is_generic_list
from ._compat import ServiceKey, ensure_forward_ref, is_generic_list

with contextlib.suppress(PackageNotFoundError):
__version__ = version(__name__)

TService = TypeVar("TService", covariant=True)

class MissingDependencyException(Exception):
"""Deprecated alias for MissingDependencyError."""
Expand Down Expand Up @@ -128,6 +129,7 @@ class InvalidForwardReferenceError(InvalidForwardReferenceException):

pass


class RegistrationScope:
"""
Simple chained dictionary[service, list[implementation]].
Expand All @@ -154,8 +156,6 @@ def get(self, key):
return self.__get(key, [])




class Scope(Enum):
"""Controls the lifetime of resolved objects.

Expand All @@ -168,12 +168,20 @@ class Scope(Enum):
singleton = 1


class _Registration(NamedTuple):
class _Registration(NamedTuple, Generic[TService]):
service: ServiceKey[TService]
scope: Scope
builder: Callable[..., TService]
needs: Any
args: dict


class _UntypedRegistration(NamedTuple):
service: str
scope: Scope
builder: Callable[[], Any]
builder: Callable[..., Any]
needs: Any
args: list[Any]
args: dict


class _Empty:
Expand Down Expand Up @@ -248,8 +256,7 @@ def register_service_and_impl(self, service, scope, impl, resolve_args):
Sending message via smtp: Hello
"""
self.__registrations.append(
service,
_Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args)
service, _Registration(service, scope, impl, self._get_needs_for_ctor(impl), resolve_args)
)

def register_service_and_instance(self, service, instance):
Expand Down Expand Up @@ -300,8 +307,7 @@ def register_concrete_service(self, service, scope, resolve_args=None):
if not inspect.isclass(service):
raise InvalidSelfRegistrationError(service)
self.__registrations.append(
service,
_Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {})
service, _Registration(service, scope, service, self._get_needs_for_ctor(service), resolve_args or {})
)

def build_context(self, key, existing=None):
Expand Down Expand Up @@ -386,6 +392,26 @@ def __init__(self, registrations=None):
self.register(Container, instance=self)
self._singletons = {}

@overload
def register(self, service: str, factory: Callable[..., Any]) -> Self:
...

@overload
def register(self, service: str, *, instance: Any) -> Self:
...

@overload
def register(self, service: ServiceKey[TService], factory: Callable[..., TService], *, scope=Scope.transient, **kwargs) -> Self:
...

@overload
def register(self, service: ServiceKey[TService], *, instance: TService, scope=Scope.transient, **kwargs) -> Self:
...

@overload
def register(self, service: ServiceKey[TService], *, scope=Scope.transient, **kwargs) -> Self:
...

def register(self, service, factory=empty, instance=empty, scope=Scope.transient, **kwargs):
"""Register a dependency into the container.

Expand Down Expand Up @@ -524,6 +550,8 @@ def _resolve_impl(self, service_key, kwargs, context, default=None):

target = context.target(service_key)

if TYPE_CHECKING:
assert target is not None
if target.is_generic_list():
return self.resolve_all(target.generic_parameter)

Expand All @@ -537,6 +565,14 @@ def _resolve_impl(self, service_key, kwargs, context, default=None):

return self._build_impl(registration, kwargs, context)

@overload
def resolve(self, service_key: str, **kwargs) -> Any:
...

@overload
def resolve(self, service_key: ServiceKey[TService], **kwargs) -> TService:
...

def resolve(self, service_key, **kwargs):
"""Build and return an instance of a registered service."""
context = self.registrations.build_context(service_key)
Expand Down
7 changes: 5 additions & 2 deletions punq/_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import ForwardRef
import sys
import typing
import typing_extensions

GenericListClass = list
ServiceKey = typing_extensions.TypeForm


def is_generic_list(service):
Expand All @@ -12,4 +15,4 @@ def is_generic_list(service):

def ensure_forward_ref(self, service, factory, instance, **kwargs):
if isinstance(service, str):
self.register(ForwardRef(service), factory, instance, **kwargs)
self.register(typing.ForwardRef(service), factory, instance, **kwargs)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dev-dependencies = [
"attrs>=24.2.0",
"mkdocs-include-markdown-plugin>=6.2.2",
"xdoctest>=1.2.0",
"pyright>=1.1.391",
"typing-extensions>=4.12.2",
]

[build-system]
Expand All @@ -55,6 +57,9 @@ warn_return_any = true
warn_unused_ignores = true
show_error_codes = true

[tool.pyright]
enableExperimentalFeatures = true

[tool.pytest.ini_options]
testpaths = ["tests"]

Expand Down
31 changes: 22 additions & 9 deletions tests/test_instance_creation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing as t
from tempfile import NamedTemporaryFile

import pytest
Expand Down Expand Up @@ -27,8 +28,11 @@ def test_dependencies_are_injected():
container.register(MessageSpeaker, HelloWorldSpeaker)

speaker = container.resolve(MessageSpeaker)
t.assert_type(speaker, MessageSpeaker)

expect(speaker).to(be_a(HelloWorldSpeaker))

speaker = t.cast(HelloWorldSpeaker, speaker)
expect(speaker.writer).to(be_a(StdoutMessageWriter))


Expand All @@ -54,6 +58,9 @@ def test_can_register_with_a_custom_factory():

speaker = container.resolve(MessageSpeaker)

t.assert_type(speaker, MessageSpeaker)
speaker = t.cast(HelloWorldSpeaker, speaker)

expect(speaker).to(be_a(HelloWorldSpeaker))
expect(speaker.writer).to(equal("win"))

Expand Down Expand Up @@ -104,7 +111,7 @@ def test_registering_an_instance_as_concrete_is_exception():
writer = MessageWriter()

with pytest.raises(InvalidRegistrationError):
container.register(writer)
container.register(writer) # type: ignore


def test_registering_an_instance_as_factory_is_exception():
Expand All @@ -116,7 +123,7 @@ def test_registering_an_instance_as_factory_is_exception():
writer = MessageWriter()

with pytest.raises(InvalidRegistrationError):
container.register(MessageWriter, writer)
container.register(MessageWriter, writer) # type: ignore


def test_registering_a_callable_as_concrete_is_exception():
Expand All @@ -129,16 +136,19 @@ def test_registering_a_callable_as_concrete_is_exception():
container = Container()

with pytest.raises(InvalidRegistrationError):
container.register(lambda: "oops")
container.register(lambda: "oops") # type: ignore


def test_can_provide_arguments_to_registrations():
container = Container()
container.register(MessageWriter, FancyDbMessageWriter, cstr=lambda: "Hello world")

writer = container.resolve(MessageWriter)
t.assert_type(writer, MessageWriter)

expect(writer).to(be_a(FancyDbMessageWriter))

writer = t.cast(FancyDbMessageWriter, writer)
expect(writer.connection_string).to(equal("Hello world"))


Expand All @@ -147,6 +157,7 @@ def test_can_provide_arguments_to_resolve():
container.register(MessageWriter, TmpFileMessageWriter)

instance = container.resolve(MessageWriter, path="foo")
instance = t.cast(TmpFileMessageWriter, instance)
expect(instance.path).to(equal("foo"))


Expand All @@ -156,6 +167,8 @@ def test_can_provide_arguments_to_resolve_having_dependencies():
container.register(MessageWriter, WrappingMessageWriter)

instance = container.resolve(MessageWriter, context="bar")

instance = t.cast(WrappingMessageWriter, instance)
expect(instance.context).to(equal("bar"))


Expand All @@ -165,15 +178,15 @@ def test_can_provide_typed_arguments_to_resolve():
container.register(TmpFileMessageWriter)
container.register(HelloWorldSpeaker)

tmpfile = NamedTemporaryFile()
with NamedTemporaryFile() as tmpfile:

writer = container.resolve(MessageWriter, path=tmpfile.name)
speaker = container.resolve(HelloWorldSpeaker, writer=writer)
writer = container.resolve(MessageWriter, path=tmpfile.name)
speaker = container.resolve(HelloWorldSpeaker, writer=writer)

speaker.speak()
speaker.speak()

tmpfile.seek(0)
expect(tmpfile.read().decode()).to(equal("Hello World"))
tmpfile.seek(0)
expect(tmpfile.read().decode()).to(equal("Hello World"))


def test_resolve_returns_the_latest_registration_for_a_service():
Expand Down
8 changes: 7 additions & 1 deletion tests/test_kwonly_deps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Protocol
from typing import Protocol, Callable, reveal_type
from typing_extensions import TypeForm

import punq
from punq._compat import ServiceKey


class Parser(Protocol):
Expand Down Expand Up @@ -42,6 +44,10 @@ def parse(self, val) -> str:
def test_can_resolve_with_kwonlyargs():
container = punq.Container()
result = []
f: Callable[..., Parser] = ReverseParser
k: TypeForm[Parser] = Parser


container.register(Parser, ReverseParser)
container.register(Writer, instance=ListWriter(result))
container.register(Doer)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_list_resolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from expects import expect, have_len
import typing
from expects import be_a, expect, have_len

from punq import Container
from tests.test_dependencies import MessageSpeaker, MessageWriter, StdoutMessageWriter, TmpFileMessageWriter
Expand All @@ -24,4 +25,6 @@ def __init__(self, writers: list[MessageWriter]) -> None:

instance = container.resolve(MessageSpeaker)

expect(instance).to(be_a(BroadcastSpeaker))
instance = typing.cast(BroadcastSpeaker, instance)
expect(instance.writers).to(have_len(2))
15 changes: 9 additions & 6 deletions tests/test_resolution_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

In order to handle scoping, we need our own datastructure.
"""
from collections import defaultdict

from punq import RegistrationScope


Expand All @@ -19,6 +19,7 @@ def test_a_root_scope_returns_the_empty_list_when_nothing_is_registered():
scope = RegistrationScope()
assert scope.get("some_key") == []


def test_a_scope_contains_items():
"""
We can add items into a scope and get them back.
Expand All @@ -30,6 +31,7 @@ def test_a_scope_contains_items():

assert scope.get("some-key") == ["hello", "world"]


def test_a_child_scope_extends_its_parent():
"""
When a child scope adds an item, it should be added to the list of
Expand All @@ -39,11 +41,12 @@ def test_a_child_scope_extends_its_parent():
child = parent.child()

parent.append("some-key", "hello")
child.append("some-key","world")
child.append("some-key", "world")

assert child.get("some-key") == ["hello", "world"]
assert parent.get("some-key") == ["hello"]


def test_resolution_can_skip_a_level():
"""
If someone goes nuts, the registrations should inherit across multiple
Expand All @@ -64,13 +67,13 @@ def test_resolution_can_skip_a_level():
assert parent.get("a") == [1]
assert child.get("a") == [1]

assert grandparent.get("b") == [ 2 ]
assert parent.get("b") == [ 2 ]
assert grandparent.get("b") == [2]
assert parent.get("b") == [2]
assert child.get("b") == [2, "x"]

assert grandparent.get("c") == []
assert parent.get("c") == [ 3 ]
assert child.get("c") == [ 3 ]
assert parent.get("c") == [3]
assert child.get("c") == [3]

assert grandparent.get("d") == []
assert parent.get("d") == []
Expand Down
Loading