From 68697806d537987709a0f58cb1563684258d8c72 Mon Sep 17 00:00:00 2001 From: Konstantin Alekseev Date: Mon, 25 May 2026 09:39:46 +0300 Subject: [PATCH] sync core/exceptions and test/runner --- django-stubs/core/exceptions.pyi | 34 ++--- django-stubs/test/runner.pyi | 230 ++++++++++++++++++++----------- django-stubs/test/testcases.pyi | 2 +- django-stubs/test/utils.pyi | 189 ++++++++++++++++--------- s/fix-sync.yml | 68 ++++++++- s/sync | 1 + s/sync-files.txt | 3 + 7 files changed, 353 insertions(+), 174 deletions(-) diff --git a/django-stubs/core/exceptions.pyi b/django-stubs/core/exceptions.pyi index 9f0c5b26e..ceaada382 100644 --- a/django-stubs/core/exceptions.pyi +++ b/django-stubs/core/exceptions.pyi @@ -1,5 +1,5 @@ -from collections.abc import Iterator, Mapping -from typing import Any, TypeAlias +from collections.abc import Iterator +from typing import Any, Literal from django.utils.functional import _StrOrPromise @@ -7,8 +7,9 @@ class FieldDoesNotExist(Exception): ... class AppRegistryNotReady(Exception): ... class ObjectDoesNotExist(Exception): - silent_variable_failure: bool = ... + silent_variable_failure: bool +class ObjectNotUpdated(Exception): ... class MultipleObjectsReturned(Exception): ... class SuspiciousOperation(Exception): ... class SuspiciousMultipartForm(SuspiciousOperation): ... @@ -26,32 +27,27 @@ class MiddlewareNotUsed(Exception): ... class ImproperlyConfigured(Exception): ... class FieldError(Exception): ... -NON_FIELD_ERRORS: str - -ValidationErrorMessageArg: TypeAlias = ( - _StrOrPromise | ValidationError | dict[str, ValidationErrorMessageArg] | list[ValidationErrorMessageArg] -) +NON_FIELD_ERRORS: Literal["__all__"] class ValidationError(Exception): - error_dict: dict[str, list[ValidationError]] | None - error_list: list[ValidationError] | None - message: _StrOrPromise | None + error_dict: dict[str, list[ValidationError]] + error_list: list[ValidationError] + message: _StrOrPromise code: str | None - params: Mapping[str, Any] | None + params: dict[str, Any] | None def __init__( self, - message: ValidationErrorMessageArg, - code: str | None = ..., - params: Mapping[str, Any] | None = ..., + # Accepts arbitrarily nested data structure, mypy doesn't allow describing it accurately. + message: _StrOrPromise | ValidationError | dict[str, Any] | list[Any], + code: str | None = None, + params: dict[str, Any] | None = None, ) -> None: ... @property def message_dict(self) -> dict[str, list[str]]: ... @property def messages(self) -> list[str]: ... - def update_error_dict( - self, error_dict: Mapping[str, list[ValidationError]] - ) -> Mapping[str, list[ValidationError]]: ... - def __iter__(self) -> Iterator[tuple[str, list[ValidationError]] | str]: ... + def update_error_dict(self, error_dict: dict[str, list[ValidationError]]) -> dict[str, list[ValidationError]]: ... + def __iter__(self) -> Iterator[tuple[str, list[str]] | str]: ... class EmptyResultSet(Exception): ... class FullResultSet(Exception): ... diff --git a/django-stubs/test/runner.pyi b/django-stubs/test/runner.pyi index 34ae1b961..2d33d38ef 100644 --- a/django-stubs/test/runner.pyi +++ b/django-stubs/test/runner.pyi @@ -1,101 +1,152 @@ import logging +import sys from argparse import ArgumentParser -from collections.abc import Sequence -from io import StringIO -from typing import Any -from unittest import TestCase as _TestCase -from unittest import TestSuite, TextTestResult +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager +from typing import Any, Literal +from unittest import TestCase, TestLoader, TestResult, TestSuite, TextTestResult, TextTestRunner from django.db.backends.base.base import BaseDatabaseWrapper -from django.test.testcases import SimpleTestCase, TestCase -from django.utils.datastructures import OrderedSet +from django.test.testcases import SimpleTestCase +from django.test.testcases import TestCase as DjangoTestCase +from django.test.utils import TimeKeeperProtocol +from typing_extensions import override + +class QueryFormatter(logging.Formatter): ... class DebugSQLTextTestResult(TextTestResult): - buffer: bool - descriptions: bool - dots: bool - expectedFailures: list[Any] - failfast: bool - shouldStop: bool - showAll: bool - skipped: list[Any] - tb_locals: bool - testsRun: int - unexpectedSuccesses: list[Any] - logger: logging.Logger = ... + logger: logging.Logger def __init__(self, stream: Any, descriptions: bool, verbosity: int) -> None: ... - debug_sql_stream: StringIO = ... - handler: logging.FileHandler = ... - def startTest(self, test: _TestCase) -> None: ... - def stopTest(self, test: _TestCase) -> None: ... + handler: logging.StreamHandler[Any] + @override + def startTest(self, test: TestCase) -> None: ... + @override + def stopTest(self, test: TestCase) -> None: ... + @override def addError(self, test: Any, err: Any) -> None: ... + @override def addFailure(self, test: Any, err: Any) -> None: ... -class RemoteTestResult: - events: list[Any] = ... - failfast: bool = ... - shouldStop: bool = ... - testsRun: int = ... - def __init__(self) -> None: ... +class PDBDebugResult(TextTestResult): + def debug(self, error: tuple[type[BaseException], BaseException, Any]) -> None: ... + +class DummyList: + __slots__ = () + def append(self, item: Any) -> None: ... + +class RemoteTestResult(TestResult): + events: list[Any] + def __init__(self, *args: Any, **kwargs: Any) -> None: ... @property - def test_index(self) -> Any: ... - def check_picklable(self, test: Any, err: Any) -> None: ... + def test_index(self) -> int: ... def _confirm_picklable(self, obj: Any) -> None: ... + def check_picklable(self, test: Any, err: Any) -> None: ... def check_subtest_picklable(self, test: Any, subtest: Any) -> None: ... - def stop_if_failfast(self) -> None: ... - def stop(self) -> None: ... + @override def startTestRun(self) -> None: ... + @override def stopTestRun(self) -> None: ... + @override def startTest(self, test: Any) -> None: ... + @override def stopTest(self, test: Any) -> None: ... + if sys.version_info >= (3, 12): + @override + def addDuration(self, test: Any, elapsed: Any) -> None: ... + else: + def addDuration(self, test: Any, elapsed: Any) -> None: ... + @override def addError(self, test: Any, err: Any) -> None: ... + @override def addFailure(self, test: Any, err: Any) -> None: ... + @override def addSubTest(self, test: Any, subtest: Any, err: Any) -> None: ... + @override def addSuccess(self, test: Any) -> None: ... + @override def addSkip(self, test: Any, reason: Any) -> None: ... + @override def addExpectedFailure(self, test: Any, err: Any) -> None: ... + @override def addUnexpectedSuccess(self, test: Any) -> None: ... + @override + def wasSuccessful(self) -> bool: ... class RemoteTestRunner: - resultclass: Any = ... - failfast: Any = ... - def __init__(self, failfast: bool = ..., resultclass: Any | None = ...) -> None: ... + resultclass: Any + failfast: bool + buffer: bool + def __init__(self, failfast: bool = ..., resultclass: Any | None = ..., buffer: bool = ...) -> None: ... def run(self, test: Any) -> Any: ... -def default_test_processes() -> int: ... +def get_max_test_processes() -> int: ... +def parallel_type(value: str) -> int | Literal["auto"]: ... class ParallelTestSuite(TestSuite): - init_worker: Any = ... - run_subsuite: Any = ... - runner_class: Any = ... - subsuites: Any = ... - processes: Any = ... - failfast: Any = ... - def __init__(self, suite: Any, processes: Any, failfast: bool = ...) -> None: ... + init_worker: Callable[..., Any] + process_setup: Callable[..., Any] + process_setup_args: tuple[Any, ...] + run_subsuite: Callable[..., Any] + runner_class: type[RemoteTestRunner] + subsuites: list[TestSuite] + processes: int + failfast: bool + debug_mode: bool + buffer: bool + initial_settings: dict[str, dict[str, Any]] | None + serialized_contents: dict[str, str] | None + used_aliases: set[str] | None + def __init__( + self, + subsuites: list[TestSuite], + processes: int, + failfast: bool = ..., + debug_mode: bool = ..., + buffer: bool = ..., + ) -> None: ... + @override def run(self, result: Any) -> Any: ... # type: ignore[override] + def handle_event(self, result: Any, tests: list[TestSuite], event: Sequence[Any]) -> None: ... + def initialize_suite(self) -> None: ... + +class Shuffler: + hash_algorithm: str + seed: int + seed_source: str + def __init__(self, seed: int | None = ...) -> None: ... + @property + def seed_display(self) -> str: ... + def shuffle(self, items: Iterable[Any], key: Callable[[Any], str]) -> list[Any]: ... class DiscoverRunner: - test_suite: Any = ... - parallel_test_suite: Any = ... - test_runner: Any = ... - test_loader: Any = ... - reorder_by: Any = ... - pattern: str | None = ... - top_level: None = ... - verbosity: int = ... - interactive: bool = ... - failfast: bool = ... - keepdb: bool = ... - reverse: bool = ... - debug_mode: bool = ... - debug_sql: bool = ... - parallel: int = ... - tags: set[str] = ... - exclude_tags: set[str] = ... + test_suite: type[TestSuite] + parallel_test_suite: type[ParallelTestSuite] + test_runner: type[TextTestRunner] + test_loader: TestLoader + reorder_by: tuple[type[DjangoTestCase], type[SimpleTestCase]] + pattern: str | None + top_level: str | None + verbosity: int + interactive: bool + failfast: bool + keepdb: bool + reverse: bool + debug_mode: bool + debug_sql: bool + parallel: int + tags: set[str] + exclude_tags: set[str] + pdb: bool + buffer: bool + test_name_patterns: set[str] | None + time_keeper: TimeKeeperProtocol + shuffle: int | Literal[False] + logger: logging.Logger | None + durations: int | None def __init__( self, pattern: str | None = ..., - top_level: None = ..., + top_level: str | None = ..., verbosity: int = ..., interactive: bool = ..., failfast: bool = ..., @@ -106,35 +157,50 @@ class DiscoverRunner: parallel: int = ..., tags: list[str] | None = ..., exclude_tags: list[str] | None = ..., + test_name_patterns: list[str] | None = ..., + pdb: bool = ..., + buffer: bool = ..., + enable_faulthandler: bool = ..., + timing: bool = ..., + shuffle: int | Literal[False] = ..., + logger: logging.Logger | None = ..., + durations: int | None = ..., **kwargs: Any, ) -> None: ... @classmethod def add_arguments(cls, parser: ArgumentParser) -> None: ... + @property + def shuffle_seed(self) -> int | None: ... + def log(self, msg: str, level: int | None = ...) -> None: ... def setup_test_environment(self, **kwargs: Any) -> None: ... - def build_suite( - self, test_labels: Sequence[str] = ..., extra_tests: list[Any] | None = ..., **kwargs: Any - ) -> TestSuite: ... + def setup_shuffler(self) -> None: ... + def load_with_patterns(self) -> AbstractContextManager[None]: ... + def load_tests_for_label(self, label: str, discover_kwargs: dict[str, str]) -> TestSuite: ... + def build_suite(self, test_labels: Sequence[str] | None = ..., **kwargs: Any) -> TestSuite: ... def setup_databases(self, **kwargs: Any) -> list[tuple[BaseDatabaseWrapper, str, bool]]: ... - def get_resultclass(self) -> type[DebugSQLTextTestResult] | None: ... - def get_test_runner_kwargs(self) -> dict[str, int | None]: ... - def run_checks(self) -> None: ... + def get_resultclass(self) -> type[TextTestResult] | None: ... + def get_test_runner_kwargs(self) -> dict[str, Any]: ... + def run_checks(self, databases: set[str]) -> None: ... def run_suite(self, suite: TestSuite, **kwargs: Any) -> TextTestResult: ... def teardown_databases(self, old_config: list[tuple[BaseDatabaseWrapper, str, bool]], **kwargs: Any) -> None: ... def teardown_test_environment(self, **kwargs: Any) -> None: ... def suite_result(self, suite: TestSuite, result: TextTestResult, **kwargs: Any) -> int: ... - def run_tests(self, test_labels: list[str], extra_tests: list[Any] = ..., **kwargs: Any) -> int: ... + def _get_databases(self, suite: TestSuite) -> set[str]: ... + def get_databases(self, suite: TestSuite) -> set[str]: ... + def run_tests(self, test_labels: list[str], **kwargs: Any) -> int: ... -def is_discoverable(label: str) -> bool: ... -def reorder_suite( - suite: TestSuite, - classes: tuple[type[TestCase], type[SimpleTestCase]], - reverse: bool = ..., -) -> TestSuite: ... -def partition_suite_by_type( - suite: TestSuite, - classes: tuple[type[TestCase], type[SimpleTestCase]], - bins: list[OrderedSet[Any]], +def try_importing(label: str) -> tuple[bool, bool]: ... +def find_top_level(top_level: str) -> str: ... +def shuffle_tests(tests: Iterable[TestCase], shuffler: Shuffler) -> Iterator[TestCase]: ... +def reorder_test_bin( + tests: Sequence[TestCase], shuffler: Shuffler | None = ..., reverse: bool = ... +) -> Iterator[TestCase]: ... +def reorder_tests( + tests: Iterable[TestCase], + classes: Sequence[type[TestCase]], reverse: bool = ..., -) -> None: ... -def partition_suite_by_case(suite: Any) -> Any: ... -def filter_tests_by_tags(suite: TestSuite, tags: set[str], exclude_tags: set[str]) -> TestSuite: ... + shuffler: Shuffler | None = ..., +) -> Iterator[TestCase]: ... +def partition_suite_by_case(suite: TestSuite) -> list[TestSuite]: ... +def test_match_tags(test: TestCase, tags: set[str], exclude_tags: set[str]) -> bool: ... +def filter_tests_by_tags(tests: Iterable[TestCase], tags: set[str], exclude_tags: set[str]) -> Iterator[TestCase]: ... diff --git a/django-stubs/test/testcases.pyi b/django-stubs/test/testcases.pyi index d209ae3b7..b6c67aa89 100644 --- a/django-stubs/test/testcases.pyi +++ b/django-stubs/test/testcases.pyi @@ -29,7 +29,7 @@ class _AssertTemplateUsedContext: template_name: str = ... rendered_templates: list[Template] = ... rendered_template_names: list[str] = ... - context: ContextList[Any] = ... + context: ContextList = ... def __init__(self, test_case: Any, template_name: Any) -> None: ... def on_template_render(self, sender: Any, signal: Any, template: Any, context: Any, **kwargs: Any) -> None: ... def test(self) -> Any: ... diff --git a/django-stubs/test/utils.pyi b/django-stubs/test/utils.pyi index 19f569e16..647a7ffd2 100644 --- a/django-stubs/test/utils.pyi +++ b/django-stubs/test/utils.pyi @@ -2,105 +2,100 @@ import decimal from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import AbstractContextManager from decimal import Decimal +from io import StringIO +from logging import Logger from types import TracebackType -from typing import Any, TypeAlias, TypeVar, overload +from typing import Any, Protocol, SupportsIndex, TypeAlias, type_check_only from django.apps.registry import Apps from django.conf import LazySettings, Settings from django.core.checks.registry import CheckRegistry from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models.lookups import Lookup, Transform +from django.db.models.query import _SupportsContains +from django.db.models.query_utils import RegisterLookupMixin from django.test.runner import DiscoverRunner from django.test.testcases import SimpleTestCase -from typing_extensions import Self +from typing_extensions import Self, TypeVar, override _TestClass: TypeAlias = type[SimpleTestCase] + _DecoratedTest: TypeAlias = Callable[..., Any] | _TestClass -_C = TypeVar("_C", bound=Callable[..., Any]) -_T = TypeVar("_T") -_U = TypeVar("_U") -_TestClassGeneric = TypeVar("_TestClassGeneric", bound=_TestClass) +_DT = TypeVar("_DT", bound=_DecoratedTest) +_C = TypeVar("_C", bound=Callable[..., Any]) # Any callable -TZ_SUPPORT: bool = ... +TZ_SUPPORT: bool class Approximate: - val: decimal.Decimal | float = ... - places: int = ... + val: decimal.Decimal | float + places: int def __init__(self, val: Decimal | float, places: int = ...) -> None: ... -class ContextList(list[Mapping[str, _T]]): - def get(self, key: str, default: _U | None = ...) -> _T | _U | None: ... +class ContextList(list[dict[str, Any]]): + @override + def __getitem__(self, key: str | SupportsIndex | slice) -> Any: ... + def get(self, key: str, default: Any | None = ...) -> Any: ... + @override + def __contains__(self, key: object) -> bool: ... def keys(self) -> set[str]: ... class _TestState: ... def setup_test_environment(debug: bool | None = ...) -> None: ... def teardown_test_environment() -> None: ... -def setup_databases( - verbosity: int, - interactive: bool, - *, - time_keeper: Any | None = ..., - keepdb: bool = ..., - debug_sql: bool = ..., - parallel: int = ..., - aliases: Iterable[str] | None = ..., - **kwargs: Any, -) -> list[tuple[BaseDatabaseWrapper, str, bool]]: ... def get_runner(settings: LazySettings, test_runner_class: str | None = ...) -> type[DiscoverRunner]: ... class TestContextDecorator: - attr_name: str | None = ... - kwarg_name: str | None = ... + attr_name: str | None + kwarg_name: str | None def __init__(self, attr_name: str | None = ..., kwarg_name: str | None = ...) -> None: ... - def enable(self) -> Any | None: ... + def enable(self) -> Any: ... def disable(self) -> None: ... - def __enter__(self) -> Any | None: ... + def __enter__(self) -> Apps | None: ... def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: TracebackType | None, + exc_tb: TracebackType | None, ) -> None: ... - def decorate_class(self, cls: _TestClassGeneric) -> _TestClassGeneric: ... + def decorate_class(self, cls: _TestClass) -> _TestClass: ... def decorate_callable(self, func: _C) -> _C: ... - @overload - def __call__(self, decorated: _TestClassGeneric) -> _TestClassGeneric: ... - @overload - def __call__(self, decorated: _C) -> _C: ... + def __call__(self, decorated: _DT) -> _DT: ... class override_settings(TestContextDecorator): - enable_exception: bool | None = ... - wrapped: Settings = ... - options: dict[str, Any] = ... + enable_exception: Exception | None + options: dict[str, Any] def __init__(self, **kwargs: Any) -> None: ... + wrapped: Settings def save_options(self, test_func: _DecoratedTest) -> None: ... + @override + def decorate_class(self, cls: type) -> type: ... class modify_settings(override_settings): wrapped: Settings - operations: list[tuple[str, dict[str, list[str] | str]]] = ... - options: dict[str, list[tuple[str, str] | str]] = ... + operations: list[tuple[str, dict[str, list[str] | str]]] def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @override def save_options(self, test_func: _DecoratedTest) -> None: ... + options: dict[str, list[tuple[str, str] | str]] class override_system_checks(TestContextDecorator): - registry: CheckRegistry = ... - new_checks: list[Callable[..., Any]] = ... - deployment_checks: list[Callable[..., Any]] | None = ... - old_checks: set[Callable[..., Any]] = ... - old_deployment_checks: set[Callable[..., Any]] = ... + registry: CheckRegistry + new_checks: list[Callable[..., Any]] + deployment_checks: list[Callable[..., Any]] | None def __init__( - self, - new_checks: list[Callable[..., Any]], - deployment_checks: list[Callable[..., Any]] | None = ..., + self, new_checks: list[Callable[..., Any]], deployment_checks: list[Callable[..., Any]] | None = ... ) -> None: ... - -class CaptureQueriesContext: - connection: Any = ... - force_debug_cursor: bool = ... - initial_queries: int = ... - final_queries: int | None = ... - def __init__(self, connection: Any) -> None: ... - def __iter__(self) -> Any: ... + old_checks: set[Callable[..., Any]] + old_deployment_checks: set[Callable[..., Any]] + +class CaptureQueriesContext(_SupportsContains[dict[str, str]]): + connection: BaseDatabaseWrapper + force_debug_cursor: bool + initial_queries: int + final_queries: int | None + def __init__(self, connection: BaseDatabaseWrapper) -> None: ... + def __iter__(self) -> Iterator[dict[str, str]]: ... def __getitem__(self, index: int) -> dict[str, str]: ... def __len__(self) -> int: ... @property @@ -110,34 +105,94 @@ class CaptureQueriesContext: self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: TracebackType | None, + exc_tb: TracebackType | None, ) -> None: ... class ignore_warnings(TestContextDecorator): - ignore_kwargs: dict[str, Any] = ... - filter_func: Callable[..., Any] = ... + ignore_kwargs: dict[str, Any] + filter_func: Callable[..., Any] def __init__(self, **kwargs: Any) -> None: ... - catch_warnings: AbstractContextManager[list[Any] | None] = ... + catch_warnings: AbstractContextManager[list[Any] | None] requires_tz_support: Any -def isolate_lru_cache(lru_cache_object: Callable[..., Any]) -> Iterator[None]: ... +def isolate_lru_cache(lru_cache_object: Callable[..., Any]) -> AbstractContextManager[None]: ... class override_script_prefix(TestContextDecorator): - prefix: str = ... + prefix: str def __init__(self, prefix: str) -> None: ... - old_prefix: str = ... + old_prefix: str class LoggingCaptureMixin: - logger: Any = ... - old_stream: Any = ... - logger_output: Any = ... + logger: Logger + old_stream: Any + logger_output: Any def setUp(self) -> None: ... def tearDown(self) -> None: ... class isolate_apps(TestContextDecorator): - installed_apps: tuple[str] = ... + installed_apps: tuple[str, ...] def __init__(self, *installed_apps: Any, **kwargs: Any) -> None: ... - old_apps: Apps = ... - -def tag(*tags: str) -> Callable[[_T], _T]: ... + old_apps: Apps + +def extend_sys_path(*paths: str) -> AbstractContextManager[None]: ... +def captured_output(stream_name: str) -> AbstractContextManager[StringIO]: ... +def captured_stdin() -> AbstractContextManager[StringIO]: ... +def captured_stdout() -> AbstractContextManager[StringIO]: ... +def captured_stderr() -> AbstractContextManager[StringIO]: ... +def freeze_time(t: float) -> AbstractContextManager[None]: ... +def tag(*tags: str) -> Callable[[_C], _C]: ... + +_Signature: TypeAlias = str +_TestDatabase: TypeAlias = tuple[str, list[str]] + +@type_check_only +class TimeKeeperProtocol(Protocol): + def timed(self, name: Any) -> AbstractContextManager[None]: ... + def print_results(self) -> None: ... + +def dependency_ordered( + test_databases: Iterable[tuple[_Signature, _TestDatabase]], dependencies: Mapping[str, list[str]] +) -> list[tuple[_Signature, _TestDatabase]]: ... +def get_unique_databases_and_mirrors( + aliases: set[str] | None = ..., +) -> tuple[dict[_Signature, _TestDatabase], dict[str, Any]]: ... +def setup_databases( + verbosity: int, + interactive: bool, + *, + time_keeper: TimeKeeperProtocol | None = ..., + keepdb: bool = ..., + debug_sql: bool = ..., + parallel: int = ..., + aliases: Mapping[str, Any] | None = ..., + serialized_aliases: Iterable[str] | None = ..., + **kwargs: Any, +) -> list[tuple[BaseDatabaseWrapper, str, bool]]: ... +def teardown_databases( + old_config: Iterable[tuple[Any, str, bool]], verbosity: int, parallel: int = ..., keepdb: bool = ... +) -> None: ... +def require_jinja2(test_func: _C) -> _C: ... +def register_lookup( + field: type[RegisterLookupMixin], *lookups: type[Lookup[Any] | Transform], lookup_name: str | None = ... +) -> AbstractContextManager[None]: ... +def garbage_collect() -> None: ... + +__all__ = ( + "Approximate", + "CaptureQueriesContext", + "ContextList", + "garbage_collect", + "get_runner", + "ignore_warnings", + "isolate_apps", + "isolate_lru_cache", + "modify_settings", + "override_settings", + "override_system_checks", + "requires_tz_support", + "setup_databases", + "setup_test_environment", + "tag", + "teardown_test_environment", +) diff --git a/s/fix-sync.yml b/s/fix-sync.yml index 302f04339..920ba0f10 100644 --- a/s/fix-sync.yml +++ b/s/fix-sync.yml @@ -1,10 +1,28 @@ -# ast-grep rules to fix bare Callable -> Callable[..., Any] -# and bare tuple -> tuple[Any, ...] -id: fix-callable-simple +# ast-grep rules to fix common bare generics in synced stubs. +id: fix-callable-type language: python rule: - pattern: "$NAME: Callable" -fix: "$NAME: Callable[..., Any]" + pattern: Callable + inside: + kind: type +fix: Callable[..., Any] +--- +id: fix-callable-typevar-bound +language: python +rule: + all: + - pattern: Callable + - inside: + stopBy: end + pattern: TypeVar($NAME, bound=$$$) + - not: + inside: + kind: generic_type + - not: + inside: + stopBy: end + pattern: Callable[$$$] +fix: Callable[..., Any] --- id: fix-callable-union-before language: python @@ -22,6 +40,46 @@ rule: pattern: "|" fix: Callable[..., Any] --- +id: fix-list-union-before +language: python +rule: + pattern: list + precedes: + pattern: "|" +fix: list[Any] +--- +id: fix-list-union-after +language: python +rule: + pattern: list + follows: + pattern: "|" +fix: list[Any] +--- +id: fix-lookup-type +language: python +rule: + pattern: Lookup + inside: + kind: type +fix: Lookup[Any] +--- +id: fix-lookup-union-before +language: python +rule: + pattern: Lookup + precedes: + pattern: "|" +fix: Lookup[Any] +--- +id: fix-lookup-union-after +language: python +rule: + pattern: Lookup + follows: + pattern: "|" +fix: Lookup[Any] +--- id: fix-tuple-type language: python rule: diff --git a/s/sync b/s/sync index 40f06448f..d8d0abe4c 100755 --- a/s/sync +++ b/s/sync @@ -46,6 +46,7 @@ while IFS= read -r file || [ -n "$file" ]; do ast-grep scan --rule "$SCRIPT_DIR/fix-sync.yml" --update-all "$file" done < "$SYNC_FILES" +uv run ruff format tests django-stubs uv run ruff check --fix echo "Sync complete!" diff --git a/s/sync-files.txt b/s/sync-files.txt index 539a80886..5c636d07a 100644 --- a/s/sync-files.txt +++ b/s/sync-files.txt @@ -42,3 +42,6 @@ django-stubs/db/models/functions/math.pyi #django-stubs/db/models/functions/text.pyi django-stubs/db/__init__.pyi django-stubs/db/migrations/executor.pyi +django-stubs/core/exceptions.pyi +django-stubs/test/runner.pyi +django-stubs/test/utils.pyi