diff --git a/.gitignore b/.gitignore index dab7335..918c131 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,12 @@ vv .tox/ venv/ +# Local scratch / repro scripts — not part of the package, kept out of the +# tracked tree so `ruff check .` (which respects .gitignore) stays green. +scripts/ +.coverage +.vscode/ + node_modules npm-debug.log .DS_Store diff --git a/CHANGELOG.md b/CHANGELOG.md index ef3f3d2..d22d605 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [1.3.0] - 2026-06-19 + +### Added +- FastAPI (async) backend support for `BasicAuth` and `OIDCAuth` — run on Starlette/ASGI with `Dash(__name__, backend="fastapi")`. Auth is enforced by pure-ASGI middleware, sessions use Starlette `SessionMiddleware`, and public routes/callbacks are stored on the app's `server.state` +- Authenticated WebSocket callbacks on the FastAPI backend: each `callback_request` is authorized via `Backend.ws_identity` and fails closed, while public callbacks still stream unauthenticated +- WebSocket authentication now reconnects on login. When a browser authenticates over HTTP, its stale pre-login WebSocket is retired so the renderer reconnects with the authenticated session — eliminating the "first click is dropped / only works on the second click" behaviour for callbacks invoked over a socket opened before login. Applies to both the FastAPI and Quart backends + +### Changed +- The WebSocket `callback_map` is migrated lazily on the first `callback_request` rather than at `Auth(...)` construction, so a global `@callback` registered after `Auth(...)` is still picked up and a WebSocket-first client no longer hits an empty map + +### Fixed +- The public-route helpers resolve the backend from `app.server` instead of a process-global fallback, keeping routing correct when several apps share a process +- `secure_session` is honoured through `setup_session`, and the FastAPI session lookup is hardened so it raises a clear error (rather than `KeyError`) under `python -O` +- FastAPI OIDC views annotate their request parameter so Starlette injects the request, and the ASGI body-replay emits `http.disconnect` once the cached body is consumed + ## [1.2.1] - 2026-06-17 ### Fixed diff --git a/README.md b/README.md index ffe360e..2308d98 100644 --- a/README.md +++ b/README.md @@ -21,17 +21,20 @@ How this fork compares to upstream [`dash-auth`](https://github.com/plotly/dash- | --- | :---: | :---: | | Flask backend | ✅ | ✅ | | Quart backend | ❌ | ✅ | -| FastAPI backend | ❌ | 🚧 1 | -| Custom backends | ❌ | ✅ 2 | +| FastAPI backend | ❌ | ✅ | +| Custom backends | ❌ | ✅ 1 | | Protected / public callbacks | ✅ | ✅ | -| Async callbacks | ❌ | ✅ 3 | +| Async callbacks | ❌ | ✅ 2 | | Authenticated WebSocket callbacks | ❌ | ✅ 3 | -✅ supported · 🚧 on the roadmap · ❌ not supported +✅ supported · ❌ not supported + +1 `detect_backend` resolves Flask/Quart/FastAPI automatically; any other server is supported by supplying your own `Backend` instance. + +2 Provided by the Quart and FastAPI backends. + +3 Provided by the Quart and FastAPI backends. WebSocket auth is a no-op on Flask, which has no WebSocket callback transport. -1 A `dash-auth-async[fastapi]` extra is declared and a native FastAPI backend is on the roadmap. In the meantime you can support it by implementing the `Backend` ABC and passing `Auth(..., backend=MyBackend())`. -2 `detect_backend` resolves Flask/Quart automatically; any other server is supported by supplying your own `Backend` instance. -3 Provided by the Quart backend. WebSocket auth is a no-op on Flask, which has no WebSocket callback transport. For local testing, install [uv](https://docs.astral.sh/uv/getting-started/installation/), then install the dev dependencies and run individual tests: @@ -276,7 +279,22 @@ if __name__ == "__main__": app.run(debug=True) ``` -### Quart (async) Backend +### Async backends + +#### Known Limitations + +> **⚠️ WebSocket callbacks & auth:** +> Do **not** enable `websocket_callbacks=True` +> globally on an authenticated `use_pages` app. The global flag routes *every* +> callback — including Dash's built-in page-routing callback — over the WebSocket, +> which bypasses the HTTP `before_request` auth guard where the login challenge is +> issued. Navigating to a protected page then hangs (the socket closes with `4401` +> and reconnect-loops) instead of prompting for login — the prompt appears only after +> a full page reload. Opt **individual** streaming callbacks into `websocket=True` +> instead, so routing and login stay on HTTP. + + +#### Quart (async) Backend `dash-auth-async` supports [Dash's Quart backend](https://dash.plotly.com/) for fully async request handling. Install the `quart` extra to pull in the required dependencies: @@ -288,7 +306,7 @@ pip install dash-auth-async[quart] Then pass `backend="quart"` when creating your Dash app. The auth setup is identical to the Flask examples above — no code changes required beyond the backend flag. -#### BasicAuth with Quart +##### BasicAuth with Quart ```python from dash import Dash @@ -302,7 +320,7 @@ if __name__ == "__main__": app.run(host="127.0.0.1", port=8050, debug=True) ``` -#### OIDCAuth with Quart +##### OIDCAuth with Quart ```python import os @@ -332,6 +350,74 @@ if __name__ == "__main__": > **Note:** The Quart backend requires Dash >= 4.2.0 and Python >= 3.10. +#### FastAPI (async) Backend + +`dash-auth-async` supports [Dash's FastAPI backend](https://dash.plotly.com/) too. +Install the `fastapi` extra to pull in the required dependencies: + +``` +pip install dash-auth-async[fastapi] +``` + +Then pass `backend="fastapi"` when creating your Dash app. `BasicAuth` and +`OIDCAuth` work exactly as on Flask/Quart — no code changes beyond the backend flag. + +##### BasicAuth with FastAPI + +```python +from dash import Dash +from dash_auth_async import BasicAuth + +app = Dash(__name__, backend="fastapi") + +BasicAuth( + app, + {"admin": "admin", "viewer": "viewer123"}, + secret_key="aStaticSecretKey!", # enables sessions (SessionMiddleware) +) + +if __name__ == "__main__": + app.run(host="127.0.0.1", port=8050, debug=True) +``` + +##### OIDCAuth with FastAPI + +```python +import os +from dash import Dash, html +from dash_auth_async import OIDCAuth + +app = Dash(__name__, backend="fastapi") + +app.layout = html.Div([ + html.H2("OIDCAuth + FastAPI"), + html.A("Logout", href="/oidc/logout"), +]) + +auth = OIDCAuth(app, secret_key="aStaticSecretKey!") +auth.register_provider( + "myidp", + client_id=os.environ["OIDC_CLIENT_ID"], + client_secret=os.environ["OIDC_CLIENT_SECRET"], + server_metadata_url=os.environ["OIDC_METADATA_URL"], + token_endpoint_auth_method="client_secret_post", + client_kwargs={"scope": "openid email profile"}, +) + +if __name__ == "__main__": + app.run(host="127.0.0.1", port=8050, debug=True) +``` + +Notes: + +- A `secret_key` installs Starlette's `SessionMiddleware` automatically. If you + add your own `SessionMiddleware`, `dash-auth-async` defers to it. +- `Auth`/`OIDCAuth` must be constructed before the server starts serving + (Starlette forbids adding middleware after startup) — the normal usage pattern. +- OIDC uses authlib's official `starlette_client`; no extra client module required. + +> The FastAPI backend requires Dash >= 4.2.0 and Python >= 3.10. + ### User-group-based permissions `dash_auth_async` provides a convenient way to secure parts of your app based on user groups. diff --git a/dash_auth_async/backends.py b/dash_auth_async/backends.py index 5278af7..6b2e50a 100644 --- a/dash_auth_async/backends.py +++ b/dash_auth_async/backends.py @@ -3,8 +3,10 @@ from __future__ import annotations import inspect +import re from abc import ABC, abstractmethod from collections.abc import Callable, MutableMapping +from contextvars import ContextVar from typing import Any import flask @@ -17,6 +19,29 @@ quart: Any = None HAS_QUART = False +try: + import fastapi + from starlette.middleware.sessions import SessionMiddleware + from starlette.requests import Request as StarletteRequest + from starlette.responses import ( + HTMLResponse, + PlainTextResponse, + RedirectResponse, + Response as StarletteResponse, + ) + + HAS_FASTAPI = True +except ImportError: + fastapi: Any = None + StarletteRequest: Any = None + SessionMiddleware: Any = None + HAS_FASTAPI = False + + +# dash-auth-async owns its own request ContextVar, set in its own ASGI +# middleware — independent of Dash's private get_current_request. +_current_request_var: ContextVar = ContextVar("dash_auth_request", default=None) + class Backend(ABC): """Framework adapter isolating everything Flask/Quart-specific. @@ -29,6 +54,11 @@ class Backend(ABC): request: Any session: MutableMapping + # Whether OIDC views must be async (awaited) on this backend. Lets + # OIDCAuth pick the sync vs async view set polymorphically instead of + # branching on isinstance(backend, ...). + is_async: bool = False + @abstractmethod def has_request_context(self) -> bool: """Whether a request context is currently active.""" @@ -58,6 +88,149 @@ def url_for(self, endpoint: str, **values) -> str: def redirect(self, location: str) -> Any: """Return a redirect response to the given location.""" + # --- Operations that diverge on FastAPI; defaults reproduce the + # --- Flask/Quart behavior so those backends inherit unchanged. These + # --- are concrete interface defaults, not all of which use `self`, so + # --- PLR6301 (no-self-use) is suppressed where that applies. + + def coerce_response(self, result: Any) -> Any: # noqa: PLR6301 + """Convert an ``_authorize`` return into a real response. + + Flask and Quart accept ``(body, status, headers)`` tuples, bare + strings, and framework responses natively, so the default is a + pass-through. FastAPI overrides this to build a Starlette response. + + Returns: + The response value unchanged. + """ + return result + + def setup_session( # noqa: PLR6301 + self, server, secret_key: str | None, secure_session: bool = False + ) -> None: + """Install session support on the server. + + Flask/Quart store a ``secret_key`` attribute and harden the session + cookie via ``SESSION_COOKIE_*`` config. FastAPI overrides this to add + ``SessionMiddleware`` (and wires ``secure_session`` to ``https_only``). + """ + if secret_key is not None: + server.secret_key = secret_key + if secure_session: + server.config["SESSION_COOKIE_SECURE"] = True + server.config["SESSION_COOKIE_HTTPONLY"] = True + + def session_configured(self, server) -> bool: # noqa: PLR6301 + """Whether the server can store a session. + + Returns: + True if a session secret is configured. + """ + return getattr(server, "secret_key", None) is not None + + def current_host(self) -> str: + """Host (netloc) of the active request, for proxy host rewrites. + + Returns: + The request host string. + """ + return self.request.host + + def current_path(self) -> str: + """Path of the active request. + + Returns: + The request path string. + """ + return self.request.path + + def add_route( # noqa: PLR6301 + self, server, rule: str, view_func, endpoint: str, methods + ) -> None: + """Register an OIDC route on the server.""" + server.add_url_rule( + rule, endpoint=endpoint, view_func=view_func, methods=methods + ) + + def make_oauth(self, server) -> Any: # noqa: PLR6301 + """Build the authlib OAuth registry for this backend. + + Returns: + A flask_client ``OAuth`` registry bound to ``server``. + """ + from authlib.integrations.flask_client import OAuth # noqa: PLC0415 + + return OAuth(server) + + def get_oauth(self, server) -> Any: # noqa: PLR6301 + """Retrieve the authlib OAuth registry stored by :meth:`make_oauth`. + + Symmetric with ``make_oauth``: each backend knows where it put its own + registry, so the module-level ``get_oauth`` doesn't have to guess. + + Returns: + The registry, or None if ``OIDCAuth`` has not run yet. + """ + return getattr(server, "extensions", {}).get( + "authlib.integrations.flask_client" + ) + + def oauth_authorize_redirect( # noqa: PLR6301 + self, client, redirect_uri: str, **kwargs + ) -> Any: + """Start the OAuth authorize-redirect on the authlib ``client``. + + Encapsulates how each backend's authlib client is invoked: the Flask + client is synchronous and reads the request from a context global, so + the default just calls through. Quart/FastAPI override to await, and + FastAPI additionally passes the active request explicitly. + + Returns: + The authorize-redirect response (awaitable on async backends). + """ + return client.authorize_redirect(redirect_uri, **kwargs) + + def oauth_authorize_access_token(self, client, **kwargs) -> Any: # noqa: PLR6301 + """Exchange the OAuth callback for a token on the authlib ``client``. + + Mirror of :meth:`oauth_authorize_redirect` for the callback leg. + + Returns: + The token (awaitable on async backends). + """ + return client.authorize_access_token(**kwargs) + + def store_config(self, server, key: str, value: Any) -> None: # noqa: PLR6301 + """Stash an app-scoped config value (public routes/callbacks). + + Flask/Quart expose a dict-like ``server.config``. FastAPI has no + such attribute and overrides this to use ``server.state``. + """ + server.config[key] = value + + def read_config( # noqa: PLR6301 + self, server, key: str, default: Any = None + ) -> Any: + """Read an app-scoped config value set by :meth:`store_config`. + + Returns: + The stored value, or ``default`` when unset. + """ + return server.config.get(key, default) + + def ws_identity(self, ws) -> tuple[Any, dict | None]: + """Resolve ``(owning_server, session_user)`` for a WS callback_request. + + WS-capable backends override this. The owning server is the key into + ``_AUTH_BY_SERVER``; the user is ``session["user"]`` or ``None``. + Flask has no WebSocket transport, so the default raises -- the WS hook's + fail-closed boundary turns that into a rejection if it is ever reached. + + Raises: + NotImplementedError: on backends without WebSocket support. + """ + raise NotImplementedError + class FlaskBackend(Backend): """Backend adapter for a Flask server.""" @@ -126,6 +299,8 @@ def redirect(self, location: str) -> Any: # noqa: PLR6301 class QuartBackend(Backend): """Backend adapter for a Quart (async) server.""" + is_async = True + def __init__(self) -> None: """Create the Quart backend, requiring the optional ``quart`` extra. @@ -171,8 +346,16 @@ def register_auth_hook(self, server, needs_body, decide) -> None: # noqa: PLR63 """Register the before-request auth hook on a Quart server. Awaits both the request body and the (possibly coroutine) decision so - async auth logic is preserved. + async auth logic is preserved. Also mints the browser-stable + ``dac_client`` cookie on first contact and, when a browser authenticates, + retires its stale pre-login WebSocket so the renderer reconnects + authenticated (see websocket_auth) -- the Quart half of Design A. """ + from .websocket_auth import ( # noqa: PLC0415 — avoid import cycle + WS_CLIENT_COOKIE, + close_anonymous_ws, + mint_ws_client_id, + ) @server.before_request async def before_request_auth(): @@ -183,8 +366,30 @@ async def before_request_auth(): ) result = decide(quart.request.path, body) if inspect.isawaitable(result): - return await result - return result + result = await result + if result is not None: + return result + + # Authorized. If this browser is now logged in, retire any stale + # anonymous socket it opened before login so the renderer reconnects + # authenticated before the first click. + if quart.session.get("user") is not None: + close_anonymous_ws(quart.request.cookies.get(WS_CLIENT_COOKIE)) + return None + + @server.after_request + def set_client_cookie(response): + # Minted on every response that lacks it -- including the BasicAuth + # 401 challenge -- so it predates a later pre-login WS handshake. + if quart.request.cookies.get(WS_CLIENT_COOKIE) is None: + response.set_cookie( + WS_CLIENT_COOKIE, + mint_ws_client_id(), + httponly=True, + samesite="Lax", + path="/", + ) + return response def url_for(self, endpoint: str, **values) -> str: # noqa: PLR6301 """Build a URL for a Quart endpoint. @@ -202,12 +407,428 @@ def redirect(self, location: str) -> Any: # noqa: PLR6301 """ return quart.redirect(location) + def make_oauth(self, server) -> Any: # noqa: PLR6301 + """Build the authlib OAuth registry for the Quart backend. + + Returns: + A custom Quart ``OAuth`` registry bound to ``server``. + """ + # Imported lazily so flask-only installs never import quart/httpx. + from dash_auth_async import quart_client # noqa: PLC0415 + + return quart_client.OAuth(server) + + def get_oauth(self, server) -> Any: # noqa: PLR6301 + """Retrieve the Quart OAuth registry stored by :meth:`make_oauth`. + + Returns: + The registry, or None if ``OIDCAuth`` has not run yet. + """ + return getattr(server, "extensions", {}).get( + "authlib.integrations.quart_client" + ) + + async def oauth_authorize_redirect( # noqa: PLR6301 + self, client, redirect_uri: str, **kwargs + ) -> Any: + """Await the Quart authlib client's authorize-redirect. + + Returns: + The authorize-redirect response. + """ + return await client.authorize_redirect(redirect_uri, **kwargs) + + async def oauth_authorize_access_token(self, client, **kwargs) -> Any: # noqa: PLR6301 + """Await the Quart authlib client's token exchange. + + Returns: + The token. + """ + return await client.authorize_access_token(**kwargs) + + def ws_identity(self, ws) -> tuple[Any, dict | None]: # noqa: PLR6301 + """Resolve the Quart app and session user for a WS callback_request. + + Uses Quart's context globals (correct when several apps share a + process); ``ws`` is unused on this backend. + + Returns: + ``(quart_app, session["user"] or None)``. + """ + import quart # noqa: PLC0415 — quart is an optional dependency + + # ``quart.current_app`` is a proxy; ``_get_current_object`` unwraps it to + # the real Quart app (the key in ``_AUTH_BY_SERVER``). Go through + # ``getattr`` because the attribute is absent from the proxy's type stub. + app = getattr(quart.current_app, "_get_current_object")() + return app, quart.session.get("user") + + +class FastAPIBackend(Backend): + """Adapter for Dash's FastAPI backend (Dash 4.2+). + + Unlike Flask/Quart there is no global request/session proxy: Starlette + passes the request explicitly. This backend resolves the active request + from a ContextVar set by its own ASGI auth middleware, and reads the + session off ``request.session`` (populated by SessionMiddleware). + """ + + is_async = True + + def __init__(self) -> None: + """Create the FastAPI backend, requiring the optional ``fastapi`` extra. + + Raises: + ImportError: if FastAPI is not installed. + """ + if not HAS_FASTAPI: + raise ImportError( + "FastAPI is not installed. Please install it with " + "`pip install dash-auth-async[fastapi]` to use the FastAPI backend." + ) + + @property + def request(self) -> Any: + """The active Starlette request, resolved from the ContextVar. + + Returns: + The current request, or None outside a request context. + """ + return _current_request_var.get() + + @property + def session(self) -> MutableMapping: + """The session mapping off the active request. + + Returns: + The Starlette session mapping. + + Raises: + RuntimeError: if SessionMiddleware is not installed. + """ + # Starlette signals "no SessionMiddleware" via a bare `assert + # "session" in scope`, which `python -O` strips — the next line then + # raises KeyError instead. Check the scope directly so the + # RuntimeError translation holds under -O too, keeping the existing + # `except RuntimeError` guards working identically to the Flask path. + request = self.request + if "session" not in getattr(request, "scope", {}): + raise RuntimeError("Session is not available. Have you set a secret key?") + return request.session + + def has_request_context(self) -> bool: # noqa: PLR6301 + """Whether a request is currently bound to the ContextVar. + + Returns: + True if a request context is active. + """ + return _current_request_var.get() is not None + + def url_for(self, endpoint: str, **values) -> str: + """Build an absolute URL for a Starlette endpoint. + + Maps the Flask-style ``_external``/``_scheme`` kwargs used by + ``OIDCAuth._create_redirect_uri`` onto Starlette's ``url_for``. + + Returns: + The URL string for ``endpoint``. + """ + values.pop("_external", None) # Starlette url_for is always absolute + scheme = values.pop("_scheme", None) + url = self.request.url_for(endpoint, **values) + if scheme: + url = url.replace(scheme=scheme) + return str(url) + + def redirect(self, location: str) -> Any: # noqa: PLR6301 + """Build a Starlette redirect response to ``location``. + + Returns: + A 302 ``RedirectResponse``. + """ + return RedirectResponse(location, status_code=302) + + def current_host(self) -> str: + """Host (netloc) of the active request. + + Returns: + The request netloc string. + """ + return self.request.url.netloc + + def current_path(self) -> str: + """Path of the active Starlette request. + + Returns: + The request path string. + """ + return self.request.url.path + + def coerce_response(self, result: Any) -> Any: # noqa: PLR6301 + """Build a Starlette response from an ``_authorize``/view return value. + + This is the single coercion boundary for the FastAPI path: every OIDC + view and the auth middleware funnel their return value through here, so + the framework-response knowledge lives in exactly one place. + + A bare ``str`` becomes an ``HTMLResponse`` (matching Flask/Quart, which + render returned strings as ``text/html`` — e.g. the OIDC logout page), + while ``(body, status[, headers])`` tuples become a ``PlainTextResponse`` + carrying the status/headers (e.g. the Basic-auth 401 challenge). + + Returns: + A Starlette response (passthrough if already one). + """ + if isinstance(result, StarletteResponse): + return result + if isinstance(result, tuple): + body, *rest = result + status = rest[0] if rest else 200 + headers = rest[1] if len(rest) > 1 else None + return PlainTextResponse(body, status_code=status, headers=headers) + if isinstance(result, str): + return HTMLResponse(result) + return PlainTextResponse(str(result)) + + @staticmethod + def _has_session_middleware(server) -> bool: + return any( + getattr(m, "cls", None) is SessionMiddleware + for m in getattr(server, "user_middleware", []) + ) + + def setup_session( + self, server, secret_key: str | None, secure_session: bool = False + ) -> None: + """Install Starlette ``SessionMiddleware`` from ``secret_key``. + + ``secure_session`` is wired to ``https_only`` so the parity with + Flask/Quart's ``SESSION_COOKIE_SECURE`` is honored rather than + silently dropped. Starlette always sets ``HttpOnly``. Defers to a + user-installed ``SessionMiddleware`` (opt-out/override) and is + idempotent — never adds a second instance. + """ + if secret_key is None: + return + if self._has_session_middleware(server): + return + server.add_middleware( + SessionMiddleware, secret_key=secret_key, https_only=secure_session + ) + + def session_configured(self, server) -> bool: + """Whether a ``SessionMiddleware`` is installed on the server. + + Returns: + True if session storage is available. + """ + return self._has_session_middleware(server) + + def store_config(self, server, key: str, value: Any) -> None: # noqa: PLR6301 + """Stash an app-scoped config value on the FastAPI ``server.state``.""" + setattr(server.state, key, value) + + def read_config( # noqa: PLR6301 + self, server, key: str, default: Any = None + ) -> Any: + """Read an app-scoped config value off the FastAPI ``server.state``. + + Returns: + The stored value, or ``default`` when unset. + """ + return getattr(server.state, key, default) + + def add_route( # noqa: PLR6301 + self, server, rule: str, view_func, endpoint: str, methods + ) -> None: + """Register an OIDC route, translating Flask ```` to ``{idp}``.""" + fastapi_rule = re.sub(r"<([^>]+)>", r"{\1}", rule) + server.add_api_route( + fastapi_rule, + view_func, + methods=methods, + name=endpoint, + include_in_schema=False, + ) + + def make_oauth(self, server) -> Any: # noqa: PLR6301 + """Build authlib's Starlette OAuth registry, stashed on ``server.state``. + + Returns: + A ``starlette_client`` ``OAuth`` registry. + """ + from authlib.integrations.starlette_client import ( # noqa: PLC0415 + OAuth as StarletteOAuth, + ) + + oauth = StarletteOAuth() + # The Starlette registry doesn't attach to app.extensions, so stash + # it where get_oauth can find it. + server.state.dash_auth_oauth = oauth + return oauth + + def get_oauth(self, server) -> Any: # noqa: PLR6301 + """Retrieve the Starlette OAuth registry stashed on ``server.state``. + + Returns: + The registry, or None if ``OIDCAuth`` has not run yet. + """ + return getattr(getattr(server, "state", None), "dash_auth_oauth", None) + + async def oauth_authorize_redirect( + self, client, redirect_uri: str, **kwargs + ) -> Any: + """Await the Starlette authlib client's authorize-redirect. + + The Starlette OAuth client takes the request explicitly (no context + global), so the active request is resolved from the ContextVar and + passed through. + + Returns: + The authorize-redirect response. + """ + return await client.authorize_redirect(self.request, redirect_uri, **kwargs) + + async def oauth_authorize_access_token(self, client, **kwargs) -> Any: + """Await the Starlette authlib client's token exchange. + + Passes the ContextVar-resolved request, as the Starlette client + requires. + + Returns: + The token. + """ + return await client.authorize_access_token(self.request, **kwargs) + + def ws_identity(self, ws) -> tuple[Any, dict | None]: # noqa: PLR6301 + """Resolve the FastAPI app and session user for a WS callback_request. + + Reads from the Starlette ``WebSocket``: ``ws.app`` is the owning server + and ``ws.session`` carries ``session["user"]`` (SessionMiddleware runs on + websocket scopes, so the handshake cookie is available). The + ``"session" in ws.scope`` guard mirrors the ``session`` property: with no + SessionMiddleware installed the user is ``None`` (fail-closed for + protected callbacks, still allows public) rather than raising. + + Returns: + ``(fastapi_app, session["user"] or None)``. + """ + user = ws.session.get("user") if "session" in ws.scope else None + return ws.app, user + + def register_auth_hook(self, server, needs_body, decide) -> None: + """Register the before-request auth hook as pure-ASGI middleware. + + Pure ASGI (not ``BaseHTTPMiddleware``) so the request ContextVar set + here is visible inside the Dash callback, which runs in the same + task/context via the inner DashMiddleware's ``copy_context()``. + """ + backend = self + + class _AuthMiddleware: + def __init__(self, app) -> None: + self.app = app + + async def __call__(self, scope, receive, send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = StarletteRequest(scope, receive) + token = _current_request_var.set(request) + + # Browser-stable client id, minted on first contact so it is + # already present at a later pre-login WS handshake. This ties an + # anonymous socket back to the browser that authenticates here, + # letting login retire the stale socket (see websocket_auth). + from .websocket_auth import ( # noqa: PLC0415 — avoid import cycle + WS_CLIENT_COOKIE, + close_anonymous_ws, + mint_ws_client_id, + ) + + client_id = request.cookies.get(WS_CLIENT_COOKIE) + set_client_cookie = client_id is None + if set_client_cookie: + client_id = mint_ws_client_id() + + async def send(message, _send=send): + if set_client_cookie and message["type"] == "http.response.start": + message.setdefault("headers", []).append( + ( + b"set-cookie", + f"{WS_CLIENT_COOKIE}={client_id}; Path=/; " + f"HttpOnly; SameSite=Lax".encode("latin-1"), + ) + ) + await _send(message) + + try: + body = None + downstream_receive = receive + if needs_body(request.url.path): + # Consuming the body drains the receive stream; cache + # the bytes and replay them so DashMiddleware (inner) + # can re-parse the callback JSON. + raw = await request.body() + body_replayed = False + + # Must be a coroutine to satisfy the ASGI `receive` + # interface, even though this replay never awaits. + async def downstream_receive(): # noqa: RUF029 + nonlocal body_replayed + if body_replayed: + # Body already delivered; further reads see a + # disconnect, per the ASGI contract — not the + # same body event replayed forever (which would + # spin an app that polls receive() to detect + # client disconnect). + return {"type": "http.disconnect"} + body_replayed = True + return { + "type": "http.request", + "body": raw, + "more_body": False, + } + + try: + body = await request.json() + except Exception: # unparseable == no body + body = None + + result = decide(request.url.path, body) + if inspect.isawaitable(result): + result = await result + + if result is not None: + response = backend.coerce_response(result) + await response(scope, receive, send) + return + + # Authorized. If this browser is now logged in, retire any + # stale anonymous sockets it opened before login so the + # renderer reconnects authenticated before the first click. + try: + logged_in = backend.session.get("user") is not None + except RuntimeError: + logged_in = False # no SessionMiddleware -> nothing to retire + if logged_in: + close_anonymous_ws(client_id) + + await self.app(scope, downstream_receive, send) + finally: + _current_request_var.reset(token) + + server.add_middleware(_AuthMiddleware) + def detect_backend(server: Any) -> Backend: - """Return the matching backend for a Flask or Quart server.""" + """Return the matching backend for a Flask, Quart, or FastAPI server.""" if quart is not None: if isinstance(server, quart.Quart): return QuartBackend() + if HAS_FASTAPI and isinstance(server, fastapi.FastAPI): + return FastAPIBackend() if isinstance(server, flask.Flask): return FlaskBackend() @@ -220,7 +841,6 @@ def detect_backend(server: Any) -> Backend: # One backend per process, matching how Dash apps are deployed. _active_backend: Backend | None = None -_DEFAULT_BACKEND = FlaskBackend() def set_active_backend(backend: Backend) -> None: @@ -234,5 +854,16 @@ def set_active_backend(backend: Backend) -> None: def get_active_backend() -> Backend: - """Return the active backend, defaulting to Flask when none is set.""" - return _active_backend if _active_backend is not None else _DEFAULT_BACKEND + """Return the active backend registered by ``Auth.__init__``. + + Falls back to a Flask backend for the legacy Flask-only path where no + ``Auth`` has registered one yet. The fallback is constructed lazily on + first use rather than at import, so Flask isn't cemented as the default + at module load. Note this still assumes a single backend per process: in + a non-Flask process the active backend must be set (which ``Auth.__init__`` + does) before any request-context helper runs. + """ + global _active_backend # noqa: PLW0603 — one backend per process, by design + if _active_backend is None: + _active_backend = FlaskBackend() + return _active_backend diff --git a/dash_auth_async/basic_auth.py b/dash_auth_async/basic_auth.py index 6bbf275..2a0077a 100644 --- a/dash_auth_async/basic_auth.py +++ b/dash_auth_async/basic_auth.py @@ -63,8 +63,6 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor else: self._user_groups_dict = None self._user_groups_func = user_groups # Callable or None after dict excluded - if secret_key is not None: - app.server.secret_key = secret_key if self._auth_func is not None: if username_password_list is not None: @@ -86,6 +84,12 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor ) super().__init__(app, public_routes=public_routes) + # After super().__init__: self.backend now exists, and the auth + # middleware is registered, so SessionMiddleware (added here) lands + # outermost on FastAPI — making request.session available to both + # the auth layer and the Dash callback. + self.backend.setup_session(app.server, secret_key) + def is_authorized(self): """Return whether the request carries valid Basic credentials. diff --git a/dash_auth_async/group_protection.py b/dash_auth_async/group_protection.py index 73f9b11..c56ddea 100644 --- a/dash_auth_async/group_protection.py +++ b/dash_auth_async/group_protection.py @@ -39,7 +39,12 @@ def _current_user() -> dict | None: backend = get_active_backend() if backend.has_request_context(): # Normal HTTP path: read the user from the framework session. - return backend.session.get("user") + try: + return backend.session.get("user") + except RuntimeError: + # Session unavailable (e.g. FastAPI without SessionMiddleware): + # treat as not authenticated rather than crashing. + return None # WebSocket worker path: no framework context here, so read the user the # websocket_message hook stashed for this dispatch. return _WS_AUTH_USER.get() diff --git a/dash_auth_async/oidc_auth.py b/dash_auth_async/oidc_auth.py index ee0c420..959c273 100644 --- a/dash_auth_async/oidc_auth.py +++ b/dash_auth_async/oidc_auth.py @@ -14,7 +14,7 @@ from dash_auth_async.auth import Auth from dash_auth_async.public_routes import get_url_base -from .backends import QuartBackend +from .backends import detect_backend if TYPE_CHECKING: from authlib.integrations.flask_client.apps import ( @@ -22,8 +22,7 @@ FlaskOAuth2App, ) - from dash_auth_async.quart_client import OAuth as QuartOAuth - from dash_auth_async.quart_client import QuartOAuth2App + from dash_auth_async.quart_client import OAuth as QuartOAuth, QuartOAuth2App class OIDCAuth(Auth): @@ -92,9 +91,10 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor Page seen by the user after logging out, by default None which will default to a simple logged out message secure_session: bool, optional - Whether to ensure the session is secure, setting the flasck config - SESSION_COOKIE_SECURE and SESSION_COOKIE_HTTPONLY to True, - by default False + Whether to restrict the session cookie to HTTPS, by default False. + On Flask/Quart this sets SESSION_COOKIE_SECURE and + SESSION_COOKIE_HTTPONLY; on FastAPI it sets the Starlette + SessionMiddleware ``https_only`` flag (HttpOnly is always on). Raises: RuntimeError: if ``app.server.secret_key`` is not defined. @@ -117,10 +117,9 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor self.idp_selection_route = idp_selection_route self.logout_page = logout_page - if secret_key is not None: - app.server.secret_key = secret_key + self.backend.setup_session(app.server, secret_key, secure_session) - if app.server.secret_key is None: + if not self.backend.session_configured(app.server): raise RuntimeError(""" app.server.secret_key is missing. Generate a secret key in your Python session @@ -136,18 +135,7 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor that key in your code/via a secret. """) - if secure_session: - app.server.config["SESSION_COOKIE_SECURE"] = True - app.server.config["SESSION_COOKIE_HTTPONLY"] = True - - if isinstance(self.backend, QuartBackend): - # Imported lazily so flask-only installs never import - # quart/httpx (quart_client raises ImportError without them). - from dash_auth_async import quart_client # noqa: PLC0415 - - self.oauth: OAuth | quart_client.OAuth = quart_client.OAuth(app.server) - else: - self.oauth = OAuth(app.server) + self.oauth = self.backend.make_oauth(app.server) # Check that the login and callback rules have an placeholder if not re.findall(r"/(?=/|$)", login_route): @@ -155,7 +143,7 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor if not re.findall(r"/(?=/|$)", callback_route): raise Exception("The callback route must contain a placeholder.") - if isinstance(self.backend, QuartBackend): + if self.backend.is_async: login_view = self._login_request_async logout_view = self._logout_async callback_view = self._callback_async @@ -164,23 +152,14 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor logout_view = self.logout callback_view = self.callback - app.server.add_url_rule( - login_route, - endpoint="oidc_login", - view_func=login_view, - methods=["GET"], + self.backend.add_route( + app.server, login_route, login_view, "oidc_login", ["GET"] ) - app.server.add_url_rule( - logout_route, - endpoint="oidc_logout", - view_func=logout_view, - methods=["GET"], + self.backend.add_route( + app.server, logout_route, logout_view, "oidc_logout", ["GET"] ) - app.server.add_url_rule( - callback_route, - endpoint="oidc_callback", - view_func=callback_view, - methods=["GET"], + self.backend.add_route( + app.server, callback_route, callback_view, "oidc_callback", ["GET"] ) def register_provider(self, idp_name: str, **kwargs): @@ -283,7 +262,7 @@ def _create_redirect_uri(self, idp: str): ) host = self.request.headers.get("X-Forwarded-Host") if host: - redirect_uri = redirect_uri.replace(self.request.host, host, 1) + redirect_uri = redirect_uri.replace(self.backend.current_host(), host, 1) return redirect_uri def login_request(self, idp: str | None = None): @@ -298,7 +277,8 @@ def login_request(self, idp: str | None = None): """ # `idp` can be none here as login_request is called # without arguments in the before_request hook - if isinstance(self.backend, QuartBackend): + if self.backend.is_async: + # Returns a coroutine; the async before-request hook / route awaits it. return self._login_request_async(idp) idp, response = self._resolve_idp(idp) @@ -308,28 +288,34 @@ def login_request(self, idp: str | None = None): redirect_uri = self._create_redirect_uri(idp) oauth_client = self.get_oauth_client(idp) oauth_kwargs = self.get_oauth_kwargs(idp) - return oauth_client.authorize_redirect( + return self.backend.oauth_authorize_redirect( + oauth_client, redirect_uri, **oauth_kwargs.get("authorize_redirect_kwargs", {}), ) async def _login_request_async(self, idp: str | None = None): - """Async login view for the Quart path. + """Async login view shared by the Quart and FastAPI paths. + + Backend-agnostic: the backend supplies the (possibly request-injecting) + authlib call and coerces the result to a framework response. Returns: The authorize-redirect response. """ idp, response = self._resolve_idp(idp) if response is not None: - return response + return self.backend.coerce_response(response) redirect_uri = self._create_redirect_uri(idp) oauth_client = self.get_oauth_client(idp) oauth_kwargs = self.get_oauth_kwargs(idp) - return await oauth_client.authorize_redirect( + result = await self.backend.oauth_authorize_redirect( + oauth_client, redirect_uri, **oauth_kwargs.get("authorize_redirect_kwargs", {}), ) + return self.backend.coerce_response(result) def logout(self): # pylint: disable=C0116 """Logout the user. @@ -352,12 +338,15 @@ def logout(self): # pylint: disable=C0116 return page async def _logout_async(self): - """Async logout view for the Quart path; the body is sync. + """Async logout view shared by the Quart and FastAPI paths. + + The body is sync; the backend coerces the HTML page to a framework + response (passthrough on Quart, ``HTMLResponse`` on FastAPI). Returns: - The logged-out page content. + The logged-out page response. """ - return self.logout() + return self.backend.coerce_response(self.logout()) def callback(self, idp: str): # pylint: disable=C0116 """Handle the OIDC dance and post-login actions. @@ -371,7 +360,8 @@ def callback(self, idp: str): # pylint: disable=C0116 oauth_client = self.get_oauth_client(idp) oauth_kwargs = self.get_oauth_kwargs(idp) try: - token = oauth_client.authorize_access_token( + token = self.backend.oauth_authorize_access_token( + oauth_client, **oauth_kwargs.get("authorize_token_kwargs", {}), ) except OAuthError as err: @@ -381,25 +371,31 @@ def callback(self, idp: str): # pylint: disable=C0116 return self.after_logged_in(user, idp, token) async def _callback_async(self, idp: str): - """Async OIDC callback view for the Quart path. + """Async OIDC callback view shared by the Quart and FastAPI paths. + + Backend-agnostic: the backend supplies the (possibly request-injecting) + token exchange and coerces every return through the single boundary. Returns: - The post-login redirect, or an error tuple on failure. + The post-login redirect, or an error response on failure. """ if idp not in self.oauth._registry: - return f"'{idp}' is not a valid registered idp", 400 + return self.backend.coerce_response( + (f"'{idp}' is not a valid registered idp", 400) + ) oauth_client = self.get_oauth_client(idp) oauth_kwargs = self.get_oauth_kwargs(idp) try: - token = await oauth_client.authorize_access_token( + token = await self.backend.oauth_authorize_access_token( + oauth_client, **oauth_kwargs.get("authorize_token_kwargs", {}), ) except OAuthError as err: - return str(err), 401 + return self.backend.coerce_response((str(err), 401)) user = token.get("userinfo") - return self.after_logged_in(user, idp, token) + return self.backend.coerce_response(self.after_logged_in(user, idp, token)) def after_logged_in(self, user: dict | None, idp: str, token: dict): """Run post-login actions after successful OIDC authentication. @@ -444,7 +440,7 @@ def is_authorized(self): # pylint: disable=C0116 if x ] ).bind("") - return map_adapter.test(self.request.path) or "user" in self.session + return map_adapter.test(self.backend.current_path()) or "user" in self.session def get_oauth(app: dash.Dash | None = None) -> "OAuth | QuartOAuth": @@ -463,14 +459,12 @@ def get_oauth(app: dash.Dash | None = None) -> "OAuth | QuartOAuth": if app is None: app = dash.get_app() - extensions = getattr(app.server, "extensions", {}) - for extension_key in ( - "authlib.integrations.flask_client", - "authlib.integrations.quart_client", - ): - oauth = extensions.get(extension_key) - if oauth is not None: - return oauth + # Retrieval is symmetric with storage: each backend knows where its own + # make_oauth stashed the registry (extensions vs server.state), so there's + # no need to probe every framework's location here. + oauth = detect_backend(app.server).get_oauth(app.server) + if oauth is not None: + return oauth raise RuntimeError( "OAuth object is not yet defined. `OIDCAuth(app, **kwargs)` needs " diff --git a/dash_auth_async/public_routes.py b/dash_auth_async/public_routes.py index 29363da..cf6a33a 100644 --- a/dash_auth_async/public_routes.py +++ b/dash_auth_async/public_routes.py @@ -9,6 +9,8 @@ from dash._callback import GLOBAL_CALLBACK_MAP # noqa: PLC2701 from werkzeug.routing import Map, MapAdapter, Rule +from .backends import detect_backend + DASH_PUBLIC_ASSETS_EXTENSIONS = "js,css" BASE_PUBLIC_ROUTES = [ f"/assets/.{ext}" @@ -80,7 +82,7 @@ def add_public_routes(app: Dash, routes: list): full_route = url_base.rstrip("/") + full_route public_routes.map.add(Rule(full_route)) - app.server.config[PUBLIC_ROUTES] = public_routes + detect_backend(app.server).store_config(app.server, PUBLIC_ROUTES, public_routes) def public_callback(*callback_args, **callback_kwargs): @@ -107,10 +109,12 @@ def decorator(func): ) try: app = get_app() - app.server.config[PUBLIC_CALLBACKS] = [ - *get_public_callbacks(app), - callback_id, - ] + backend = detect_backend(app.server) + backend.store_config( + app.server, + PUBLIC_CALLBACKS, + [*backend.read_config(app.server, PUBLIC_CALLBACKS, []), callback_id], + ) except Exception: print( "Could not set up the public callback as the Dash object " @@ -131,7 +135,9 @@ def get_public_routes(app: Dash) -> MapAdapter: Returns: The MapAdapter holding the app's registered public routes. """ - return app.server.config.get(PUBLIC_ROUTES, Map([]).bind("")) + return detect_backend(app.server).read_config( + app.server, PUBLIC_ROUTES, Map([]).bind("") + ) def get_public_callbacks(app: Dash) -> list: @@ -140,4 +146,4 @@ def get_public_callbacks(app: Dash) -> list: Returns: The list of whitelisted public callback ids. """ - return app.server.config.get(PUBLIC_CALLBACKS, []) + return detect_backend(app.server).read_config(app.server, PUBLIC_CALLBACKS, []) diff --git a/dash_auth_async/version.py b/dash_auth_async/version.py index e461518..13e7faf 100644 --- a/dash_auth_async/version.py +++ b/dash_auth_async/version.py @@ -1,3 +1,3 @@ """Single source of truth for the package version.""" -__version__ = "1.2.1" +__version__ = "1.3.0" diff --git a/dash_auth_async/websocket_auth.py b/dash_auth_async/websocket_auth.py index de8fdfa..6124ed4 100644 --- a/dash_auth_async/websocket_auth.py +++ b/dash_auth_async/websocket_auth.py @@ -8,13 +8,23 @@ from __future__ import annotations +import asyncio import contextvars +import secrets import threading from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar from typing import Any from weakref import WeakKeyDictionary +from .backends import get_active_backend + +# Name of the pre-login, browser-stable cookie that ties an anonymous WebSocket +# back to the browser that later authenticates over HTTP. Minted on first +# contact by the auth hook (see the backends' before-request middleware) so it +# is already present at a pre-login WS handshake. +WS_CLIENT_COOKIE = "dac_client" + # The authenticated user (session["user"] dict) for the callback currently being # dispatched over a WebSocket. Set by the websocket_message hook in the WS # context and propagated into Dash's callback worker by the context-copying @@ -47,6 +57,83 @@ def submit(self, fn, /, *args, **kwargs): _hook_registered = False +class _WSEntry: + """A registered anonymous WebSocket plus the loop it runs on.""" + + __slots__ = ("loop", "ws") + + def __init__(self, ws: Any, loop: Any) -> None: + self.ws = ws + self.loop = loop + + +# client-id cookie -> the browser's current anonymous (pre-login) socket. +# +# Only *anonymous* sockets are tracked: a login retires those and only those, so +# an already-authenticated socket would only ever bloat the map. And only one +# per browser -- a browser drives a single shared SharedWorker socket, so a +# reconnect replaces the prior entry rather than stacking. Together these bound +# the map to (at most) one entry per distinct pre-login browser. +_WS_BY_CLIENT: dict[str, _WSEntry] = {} +_ws_registry_lock = threading.Lock() + + +def mint_ws_client_id() -> str: + """Return a fresh, opaque browser-stable client id for the cookie.""" + return secrets.token_urlsafe(16) + + +def register_anonymous_ws(client_id: str | None, ws: Any, loop: Any) -> None: + """Track a browser's current pre-login socket, replacing any stale prior. + + No-op without a client id (e.g. a backend that never minted the cookie): + such a socket can't be correlated to a later login and falls back to the + reactive 4401-then-reconnect path. + """ + if not client_id: + return + with _ws_registry_lock: + _WS_BY_CLIENT[client_id] = _WSEntry(ws, loop) + + +def close_anonymous_ws(client_id: str | None) -> None: + """Close the browser's pre-login socket so the renderer reconnects authed. + + Called when ``client_id`` authenticates over HTTP. The close code (``4408``) + is outside the renderer's no-reconnect set (``1000``/``4001``), so the + SharedWorker dials a fresh handshake that carries the now-present session + cookie -- before the user's first click rather than after a sacrificed one. + Popping the entry also deregisters it, keeping the map from leaking. + """ + if not client_id: + return + with _ws_registry_lock: + entry = _WS_BY_CLIENT.pop(client_id, None) + if entry is None: + return + _schedule_ws_close(entry.loop, entry.ws) + + +def _schedule_ws_close(loop: Any, ws: Any) -> None: + """Schedule ``ws.close`` onto the socket's own event loop, thread-safely. + + The login runs in a separate HTTP task (and possibly thread); a socket can + only be closed from its own loop, so we hop onto it via + ``call_soon_threadsafe`` rather than awaiting the close inline. + """ + + async def _safe_close() -> None: + try: + await ws.close(code=4408) + except Exception: # pylint: disable=broad-exception-caught + pass # already closing/closed -- nothing to recover + + try: + loop.call_soon_threadsafe(lambda: asyncio.ensure_future(_safe_close())) + except Exception: # pylint: disable=broad-exception-caught + pass # loop gone (socket already torn down) -- nothing to close + + def _ws_message_hook(ws: Any, message: Any): """Global Dash websocket_message hook: authorize each callback_request. @@ -61,38 +148,43 @@ def _ws_message_hook(ws: Any, message: Any): if not isinstance(message, dict) or message.get("type") != "callback_request": return True try: - return _authorize_ws_message(message) + return _authorize_ws_message(ws, message) except Exception: # pylint: disable=broad-exception-caught # Fail closed on any unexpected error. return (4401, "Unauthorized") -def _authorize_ws_message(message: dict) -> bool | tuple[int, str]: - """Authorize one WebSocket ``callback_request`` for the current Quart app. +def _authorize_ws_message(ws: Any, message: dict) -> bool | tuple[int, str]: + """Authorize one WebSocket ``callback_request`` for the owning app. - Resolves the owning app via ``quart.current_app`` so it is correct when - several apps share the process; inert for apps that do not use - dash-auth-async. + The owning server and the session user are resolved through the active + backend's :meth:`Backend.ws_identity` (Quart uses its context globals, + FastAPI reads ``ws.app``/``ws.session``), keeping this module + framework-agnostic. Inert for apps that do not use dash-auth-async. Returns: ``True`` to allow, or a ``(code, reason)`` tuple to reject the socket. """ - import quart # noqa: PLC0415 — quart is an optional dependency - - # ``quart.current_app`` is a proxy; ``_get_current_object`` unwraps it to - # the real Quart app (the key in ``_AUTH_BY_SERVER``). The attribute is - # present at runtime but absent from the proxy's type stub, so go through - # ``getattr`` to keep the static type checker happy. - current_app: Any = quart.current_app - app = getattr(current_app, "_get_current_object")() - auth = _AUTH_BY_SERVER.get(app) + backend = get_active_backend() + server, user = backend.ws_identity(ws) + auth = _AUTH_BY_SERVER.get(server) if auth is None: # Not a dash-auth-async app: nothing to enforce. Safe because the - # registry entry is created by the developer's ``Auth(app, ...)`` - # call, not by the client -- an attacker cannot evict their own app. + # registry entry is created by the developer's ``Auth(app, ...)`` call, + # not by the client -- an attacker cannot evict their own app. return True + # Migrate Dash's GLOBAL_CALLBACK_MAP into app.callback_map before Dash + # validates this request against it -- validation runs *after* the + # websocket_message hooks return (see _fastapi.py / _quart.py). On FastAPI + # our auth middleware can short-circuit the page/readiness GET with a 401 + # before Dash's own _setup_server before-hook ever runs, so a WS-first + # client would otherwise hit an empty map ("Callback function not found"). + # Doing it here is lazy (fires on the first real WS message, after every + # module-level @callback is registered -- so unlike a construction-time + # call it never freezes the map early) and idempotent (a self-guarded flag + # makes every call after the first a single boolean check). + auth.app._setup_server() payload = message.get("payload", {}) or {} - user = quart.session.get("user") if auth.authorize_ws(payload, user): # Load-bearing invariant: this hook runs before every callback_request # is submitted to the executor, so the context-copying executor always @@ -103,15 +195,51 @@ def _authorize_ws_message(message: dict) -> bool | tuple[int, str]: return (4401, "Unauthorized") +def _ws_connect_hook(ws: Any): + """Global Dash websocket_connect hook: register the socket by client id. + + Runs at the handshake (before accept) for every socket. Tracks only sockets + that handshake *anonymously*, keyed by browser (the ``dac_client`` cookie), + so a later login can retire them; an already-authenticated socket needs no + retiring and is left untracked. Always allows the connection -- public pages + legitimately stream over an unauthenticated socket, so this hook must never + reject. + + Returns: + ``True`` to allow the connection (bookkeeping never blocks a socket). + """ + try: + _track_anonymous_ws(ws) + except Exception: # pylint: disable=broad-exception-caught + pass # never let bookkeeping block a connection + return True + + +def _track_anonymous_ws(ws: Any) -> None: + """Register ``ws`` if its handshake was anonymous (see ``_ws_connect_hook``).""" + backend = get_active_backend() + _, user = backend.ws_identity(ws) + if user is not None: + return # already authenticated -- nothing to retire later + client_id = getattr(ws, "cookies", {}).get(WS_CLIENT_COOKIE) + # Resolve the concrete socket: Quart hands us a context-local proxy that + # can't be closed from the later (different-task) login, so we must keep the + # underlying object. Starlette's WebSocket has no such proxy and is stored + # as-is. + real_ws = ws._get_current_object() if hasattr(ws, "_get_current_object") else ws + register_anonymous_ws(client_id, real_ws, asyncio.get_running_loop()) + + def _ensure_hook_registered() -> None: - """Register the global websocket_message hook exactly once per process.""" - global _hook_registered # noqa: PLW0603 — register the hook once per process + """Register the global websocket hooks exactly once per process.""" + global _hook_registered # noqa: PLW0603 — register the hooks once per process with _hook_lock: if _hook_registered: return from dash import hooks # noqa: PLC0415 — lazy import to avoid an import cycle hooks.websocket_message()(_ws_message_hook) + hooks.websocket_connect()(_ws_connect_hook) _hook_registered = True @@ -119,8 +247,13 @@ def enable_ws_auth(auth: Any, app: Any) -> None: """Wire WebSocket auth for a dash-auth-async app. No-op on backends without WebSocket support (e.g. Flask). For WS-capable - backends it records the app->Auth mapping, installs the context-copying + backends it records the server->Auth mapping, installs the context-copying executor (before any dispatch), and registers the global hook once. + + Note Dash's ``callback_map`` is *not* populated here. It is migrated lazily + on the first WS ``callback_request`` by the message hook (see + ``_authorize_ws_message``), so a global ``@callback`` registered after + ``Auth(...)`` is still picked up and the server is never set up twice. """ backend = getattr(app, "backend", None) if backend is None or not getattr(backend, "websocket_capability", False): diff --git a/examples/websocket_auth_quart/README.md b/examples/websocket_auth_quart/README.md new file mode 100644 index 0000000..1ec910c --- /dev/null +++ b/examples/websocket_auth_quart/README.md @@ -0,0 +1,42 @@ +# Quart WebSockets + public/private auth example + +A minimal multi-page Dash app on the **Quart** backend that combines: + +- **WebSocket streaming callbacks** (`Dash(backend="quart", websocket_callbacks=True)`) +- **Public vs authenticated pages** via `dash-auth-async` `BasicAuth` + +## Pages + +| Route | Access | What it shows | +|------------|---------------|----------------------------------------| +| `/` | public | Landing page + navigation | +| `/live` | public | Live counter/clock streamed over WS | +| `/private` | authenticated | Simulated 0→100% progress task over WS | + +## Run + +```bash +pip install "dash-auth-async[quart]" # or: uv sync, from the repo root +python examples/websocket_auth_quart/app.py +``` + +Open . The `/private` page prompts for login: + +- `admin` / `admin` +- `viewer` / `viewer123` + +(Use `127.0.0.1` rather than `localhost`.) + +## Note on WebSocket-layer authentication + +Authentication is enforced at the **HTTP page level**: `/private` is a normal +HTTP GET that is auth-checked, so an anonymous user cannot load the page or start +its stream. This is what the public/private split demonstrates. + +However, `dash-auth-async`'s auth hook is registered as `@server.before_request`, +which does **not** fire for WebSocket connections (Quart uses separate +`before_websocket` / `websocket_connect` hooks). The WebSocket callback route +(`/_dash-ws-callback`) is therefore **not independently auth-gated** — a +hand-crafted raw WS connection would bypass the check. Closing this gap properly +would mean adding auth via Dash's `websocket_connect` hook in the library, which +is out of scope for this example. diff --git a/examples/websocket_auth_quart/app.py b/examples/websocket_auth_quart/app.py new file mode 100644 index 0000000..983e5f5 --- /dev/null +++ b/examples/websocket_auth_quart/app.py @@ -0,0 +1,53 @@ +"""Example: Dash + Quart backend with WebSocket streaming and public/private auth. + +Run: + python examples/websocket_auth_quart/app.py + +Then open http://127.0.0.1:8050/ in a browser. + +Credentials for the private page: + admin / admin + viewer / viewer123 +""" + +from dash import Dash, dcc, html, page_container + +from dash_auth_async import BasicAuth + +app = Dash( + __name__, + backend="quart", + use_pages=True, + websocket_callbacks=True, + suppress_callback_exceptions=True, +) + +app.layout = html.Div( + [ + html.Div( + [ + dcc.Link("Home", href="/"), + dcc.Link("Live (public)", href="/live"), + dcc.Link("Private", href="/private"), + ], + style={ + "display": "flex", + "gap": "1rem", + "background": "#eee", + "padding": "0.5rem 1rem", + }, + ), + page_container, + ], + style={"display": "flex", "flexDirection": "column", "fontFamily": "sans-serif"}, +) + +BasicAuth( + app, + {"admin": "admin", "viewer": "viewer123"}, + secret_key="example-secret-not-for-production", + public_routes=["/", "/live"], +) + +if __name__ == "__main__": + app.run(host="127.0.0.1", port=8050, debug=True) diff --git a/examples/websocket_auth_quart/pages/__init__.py b/examples/websocket_auth_quart/pages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/websocket_auth_quart/pages/home.py b/examples/websocket_auth_quart/pages/home.py new file mode 100644 index 0000000..239d32f --- /dev/null +++ b/examples/websocket_auth_quart/pages/home.py @@ -0,0 +1,23 @@ +"""Public landing page for the WebSocket + auth example.""" + +from dash import html, register_page + +register_page(__name__, path="/", name="Home") + +layout = html.Div( + [ + html.H1("Quart + WebSockets + dash-auth-async"), + html.P( + "This example runs on the Dash Quart backend with WebSocket " + "streaming callbacks, and demonstrates public vs authenticated " + "pages using dash-auth-async BasicAuth." + ), + html.Ul( + [ + html.Li("'/live' is public — anyone can watch the live counter."), + html.Li("'/private' requires login (try admin / admin)."), + ] + ), + ], + style={"padding": "1rem"}, +) diff --git a/examples/websocket_auth_quart/pages/live.py b/examples/websocket_auth_quart/pages/live.py new file mode 100644 index 0000000..5365541 --- /dev/null +++ b/examples/websocket_auth_quart/pages/live.py @@ -0,0 +1,49 @@ +"""Public page: a live counter/clock streamed over the WebSocket. + +Uses public_callback so the callback is whitelisted by dash-auth-async even if a +leg is served over HTTP. When routed over the WebSocket the whitelist is simply +unused. The streaming pattern (async + set_props + is_shutdown) follows +https://dash.plotly.com/websocket-callbacks. +""" + +import asyncio +from datetime import datetime + +import dash +from dash import Input, Output, ctx, html, register_page, set_props + +from dash_auth_async import public_callback + +register_page(__name__, path="/live", name="Live") + +layout = html.Div( + [ + html.H1("Live counter / clock (public)"), + html.P("Anyone can view this page and start the stream — no login required."), + html.Button("Start stream", id="counter-start"), + html.Div( + "Press start.", + id="counter-out", + style={"marginTop": "1rem", "fontSize": "1.5rem"}, + ), + ], + style={"padding": "1rem"}, +) + + +@public_callback( + Output("counter-out", "children"), + Input("counter-start", "n_clicks"), + prevent_initial_call=True, +) +async def stream_counter(_n_clicks): + ws = getattr(ctx, "websocket", None) + for i in range(1, 11): + if ws is not None and ws.is_shutdown: + return dash.no_update + set_props( + "counter-out", + {"children": f"Tick {i}/10 — {datetime.now():%H:%M:%S}"}, + ) + await asyncio.sleep(1) + return "Done." diff --git a/examples/websocket_auth_quart/pages/private.py b/examples/websocket_auth_quart/pages/private.py new file mode 100644 index 0000000..4d6e679 --- /dev/null +++ b/examples/websocket_auth_quart/pages/private.py @@ -0,0 +1,44 @@ +"""Private page: a simulated long-running task streamed over the WebSocket. + +This page is NOT in public_routes, so loading it requires BasicAuth login. The +callback is a plain @callback: the page is only reachable when authenticated and +no user-group gating is required. Streaming follows the same async + set_props + +is_shutdown pattern as the public page. +""" + +import asyncio + +import dash +from dash import Input, Output, callback, ctx, html, register_page, set_props + +register_page(__name__, path="/private", name="Private") + +layout = html.Div( + [ + html.H1("Simulated task (private)"), + html.P("You are logged in — this page sits behind BasicAuth."), + html.Button("Run task", id="task-start"), + html.Div( + html.Progress(id="task-bar", value="0", max="100"), + style={"marginTop": "1rem"}, + ), + html.Div("Idle.", id="task-status", style={"marginTop": "1rem"}), + ], + style={"padding": "1rem"}, +) + + +@callback( + Output("task-status", "children"), + Input("task-start", "n_clicks"), + prevent_initial_call=True, +) +async def run_task(_n_clicks): + ws = getattr(ctx, "websocket", None) + for pct in range(0, 101, 10): + if ws is not None and ws.is_shutdown: + return dash.no_update + set_props("task-bar", {"value": str(pct)}) + set_props("task-status", {"children": f"Working… {pct}%"}) + await asyncio.sleep(0.5) + return "Complete!" diff --git a/pyproject.toml b/pyproject.toml index 9ab916b..6e6266e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dash-auth-async" -version = "1.2.1" +version = "1.3.0" description = "Dash Authorization Package." readme = "README.md" requires-python = ">=3.10" @@ -66,7 +66,13 @@ dev = [ "ruff>=0.15.16", "pytest-cov>=7.1.0", "websocket-client>=1.9.0", - "requests[security]>=2.34.2" + "requests[security]>=2.34.2", + "fastapi", + "starlette", + "uvicorn", + "httpx>=0.23.0", + "itsdangerous", + "python-multipart" ] [tool.uv] @@ -96,6 +102,13 @@ preview = true select = ["FAST","I","D","DOC","PL","UP","PERF","RUF"] [tool.ruff.lint.per-file-ignores] +# Examples are runnable demos, not shipped library code (setuptools only +# packages dash_auth_async*). Like tests, they read better without forced +# docstring coverage on every demo page/callback. +"examples/**" = [ + "D", # demo scripts don't need module/function docstrings + "DOC", # ...nor docstring sections (Returns, etc.) +] # Tests are self-describing and have different idioms than library code: "tests/**" = [ "D", # don't require docstrings on every test/fixture/module @@ -114,5 +127,10 @@ select = ["FAST","I","D","DOC","PL","UP","PERF","RUF"] quote-style = "double" indent-style = "space" +[tool.ruff.lint.isort] +# Keep aliased imports on the same line as their siblings so a single +# optional-dependency import guard stays one statement (avoids PLW0717). +combine-as-imports = true + [tool.ruff.lint.pydocstyle] convention = "google" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 9440ac8..9a4338b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,6 +38,26 @@ def _stop_quart_gracefully(runner) -> bool: return not runner.thread.is_alive() +def _stop_fastapi_gracefully(runner) -> bool: + """Shut down a fastapi-backend dash test server via uvicorn's should_exit. + + Dash's FastAPI backend stores the uvicorn Server on the Dash app as + ``_uvicorn_server`` when run threaded (_fastapi.py:384). Setting + ``should_exit`` lets uvicorn's serve loop return so the thread exits + cleanly instead of being killed mid-flight. + + Returns True if the server thread exited; False means fall back to the + original kill-based stop. + """ + dash_app = getattr(runner, "_app", None) + server = getattr(dash_app, "_uvicorn_server", None) + if server is None: + return False + server.should_exit = True + runner.thread.join(timeout=runner.stop_timeout) + return not runner.thread.is_alive() + + _original_init = _runners.BaseDashRunner.__init__ @@ -61,15 +81,15 @@ def _init_with_ipv4_host( _original_stop = _runners.ThreadedRunner.stop -def _stop_with_graceful_quart(self: Any) -> Any: - if _stop_quart_gracefully(self): +def _stop_with_graceful_async(self: Any) -> Any: + if _stop_quart_gracefully(self) or _stop_fastapi_gracefully(self): self._app = None self.started = False return return _original_stop(self) -_runners.ThreadedRunner.stop = _stop_with_graceful_quart # type: ignore +_runners.ThreadedRunner.stop = _stop_with_graceful_async # type: ignore @pytest.fixture(autouse=True) diff --git a/tests/integration/test_basic_auth_integration_auth_func.py b/tests/integration/test_basic_auth_integration_auth_func.py index 4631ece..7e8722b 100644 --- a/tests/integration/test_basic_auth_integration_auth_func.py +++ b/tests/integration/test_basic_auth_integration_auth_func.py @@ -30,7 +30,7 @@ def auth_function(username, password): }, ], ) -def test_ba002_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): +def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): app = Dash(__name__, **kwargs) app.layout = html.Div( [dcc.Input(id="input", value="initial value"), html.Div(id="output")] @@ -104,7 +104,7 @@ def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs): }, ], ) -def test_ba003_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): +def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): with pytest.raises(ValueError): both_dict_and_func(dash_br, dash_thread_server, **kwargs) with pytest.raises(ValueError): diff --git a/tests/integration/test_basic_auth_integration_auth_func_fastapi.py b/tests/integration/test_basic_auth_integration_auth_func_fastapi.py new file mode 100644 index 0000000..5f34036 --- /dev/null +++ b/tests/integration/test_basic_auth_integration_auth_func_fastapi.py @@ -0,0 +1,100 @@ +import pytest +import requests +from dash import Dash, Input, Output, dcc, html + +from dash_auth_async import basic_auth + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +TEST_USERS = { + "valid": [["hello", "world"], ["hello2", "wo:rld"]], + "invalid": [["hello", "password"]], +} + + +def auth_function(username, password): + return [username, password] in TEST_USERS["valid"] + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(new_value): + return new_value + + basic_auth.BasicAuth(app, auth_func=auth_function) + + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + path_prefix + + def test_failed_views(url): + assert requests.get(url).status_code == 401 + assert requests.get(url.strip("/") + "/_dash-layout").status_code == 401 + + test_failed_views(base_url) + + for user, password in TEST_USERS["invalid"]: + test_failed_views(base_url.replace("//", f"//{user}:{password}@")) + + for user, password in TEST_USERS["valid"]: + dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@")) + dash_br.driver.get(base_url) + dash_br.wait_for_text_to_equal("#output", "initial value") + + +def both_dict_and_func(dash_br, dash_thread_server, **kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + basic_auth.BasicAuth(app, TEST_USERS["valid"], auth_func=auth_function) + return True + + +def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + basic_auth.BasicAuth(app) + return True + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): + with pytest.raises(ValueError): + both_dict_and_func(dash_br, dash_thread_server, **kwargs) + with pytest.raises(ValueError): + both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs) diff --git a/tests/integration/test_basic_auth_integration_auth_func_quart.py b/tests/integration/test_basic_auth_integration_auth_func_quart.py index ac87026..d2c493a 100644 --- a/tests/integration/test_basic_auth_integration_auth_func_quart.py +++ b/tests/integration/test_basic_auth_integration_auth_func_quart.py @@ -32,7 +32,7 @@ def auth_function(username, password): }, ], ) -def test_ba002_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): +def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): app = Dash(__name__, backend="quart", **kwargs) app.layout = html.Div( [dcc.Input(id="input", value="initial value"), html.Div(id="output")] @@ -106,7 +106,7 @@ def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs): }, ], ) -def test_ba003_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): +def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): with pytest.raises(ValueError): both_dict_and_func(dash_br, dash_thread_server, **kwargs) with pytest.raises(ValueError): diff --git a/tests/integration/test_basic_auth_integration_fastapi.py b/tests/integration/test_basic_auth_integration_fastapi.py new file mode 100644 index 0000000..b8c419b --- /dev/null +++ b/tests/integration/test_basic_auth_integration_fastapi.py @@ -0,0 +1,120 @@ +import pytest +import requests +from dash import Dash, Input, Output, dcc, html + +from dash_auth_async import BasicAuth, add_public_routes, protected + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +TEST_USERS = { + "valid": [["hello", "world"], ["hello2", "wo:rld"]], + "invalid": [["hello", "password"]], +} + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(new_value): + return new_value + + BasicAuth(app, TEST_USERS["valid"], public_routes=["/home"]) + add_public_routes(app, ["/user//public"]) + + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + path_prefix + + def test_failed_views(url): + assert requests.get(url).status_code == 401 + + def test_successful_views(url): + assert requests.get(url.rstrip("/") + "/_dash-layout").status_code == 200 + assert requests.get(url.rstrip("/") + "/home").status_code == 200 + assert requests.get(url.rstrip("/") + "/user/john123/public").status_code == 200 + + test_failed_views(base_url) + test_successful_views(base_url) + + for user, password in TEST_USERS["invalid"]: + test_failed_views(base_url.replace("//", f"//{user}:{password}@")) + test_successful_views(base_url.replace("//", f"//{user}:{password}@")) + + for user, password in TEST_USERS["valid"]: + dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@")) + dash_br.driver.get(base_url) + dash_br.wait_for_text_to_equal("#output", "initial value") + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +def test_ba002_basic_auth_groups(dash_br, dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), + Input("input", "value"), + groups=["admin"], + ) + @protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + ) + def update_output(new_value): + return new_value + + BasicAuth( + app, + TEST_USERS["valid"], + public_routes=["/home"], + user_groups={"hello": ["admin"]}, + secret_key="Test!", + ) + + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + path_prefix + + for user, password in TEST_USERS["valid"]: + dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@")) + dash_br.driver.get(base_url) + expected = "initial value" if user == "hello" else "forbidden" + dash_br.wait_for_text_to_equal("#output", expected) diff --git a/tests/integration/test_oidc_auth_fastapi.py b/tests/integration/test_oidc_auth_fastapi.py new file mode 100644 index 0000000..0bd02a6 --- /dev/null +++ b/tests/integration/test_oidc_auth_fastapi.py @@ -0,0 +1,240 @@ +from unittest.mock import patch + +import pytest +import requests +from dash import Dash, Input, Output, dcc, html + +from dash_auth_async import OIDCAuth, protected_callback + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +from starlette.responses import RedirectResponse + +_OAUTH_APP = "authlib.integrations.starlette_client.apps.StarletteOAuth2App" +_METADATA_URL = "https://idp2.com/oidc/2/.well-known/openid-configuration" + + +async def valid_authorize_redirect(self, request, redirect_uri, *args, **kwargs): + return RedirectResponse("/" + redirect_uri.split("/", maxsplit=3)[-1]) + + +async def invalid_authorize_redirect(self, request, redirect_uri, *args, **kwargs): + base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1] + return RedirectResponse( + f"{base_url}?error=Unauthorized&error_description=something went wrong" + ) + + +async def valid_authorize_access_token(self, request, *args, **kwargs): + return { + "userinfo": {"email": "a.b@mail.com", "groups": ["viewer", "editor"]}, + "refresh_token": "ABCDEF", + } + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +@patch(f"{_OAUTH_APP}.authorize_redirect", valid_authorize_redirect) +@patch(f"{_OAUTH_APP}.authorize_access_token", valid_authorize_access_token) +def test_oaf001_oidc_auth_login_flow_success(dash_br, dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div("static", id="output3"), + html.Div("static", id="output4"), + html.Div("not static", id="output5"), + ] + ) + + @app.callback(Output("output1", "children"), Input("input", "value")) + def update_output1(new_value): + return new_value + + @protected_callback( + Output("output2", "children"), + Input("input", "value"), + groups=["editor"], + check_type="one_of", + ) + def update_output2(new_value): + return new_value + + @protected_callback( + Output("output3", "children"), + Input("input", "value"), + groups=["admin"], + check_type="one_of", + ) + def update_output3(new_value): + return new_value + + @protected_callback( + Output("output4", "children"), + Input("input", "value"), + groups=["viewer"], + check_type="none_of", + ) + def update_output4(new_value): + return new_value + + @protected_callback( + Output("output5", "children"), + Input("input", "value"), + groups=["viewer", "editor"], + check_type="all_of", + ) + def update_output5(new_value): + return new_value + + oidc = OIDCAuth(app, secret_key="Test") + oidc.register_provider( + "oidc", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url=_METADATA_URL, + ) + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + path_prefix + + assert requests.get(base_url).status_code == 200 + + dash_br.driver.get(base_url) + dash_br.wait_for_text_to_equal("#output1", "initial value") + dash_br.wait_for_text_to_equal("#output2", "initial value") + dash_br.wait_for_text_to_equal("#output3", "static") + dash_br.wait_for_text_to_equal("#output4", "static") + dash_br.wait_for_text_to_equal("#output5", "initial value") + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +@patch(f"{_OAUTH_APP}.authorize_redirect", invalid_authorize_redirect) +def test_oaf002_oidc_auth_login_fail(dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(new_value): + return new_value + + oidc = OIDCAuth(app, public_routes=["/public"], secret_key="Test") + oidc.register_provider( + "oidc", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url=_METADATA_URL, + ) + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + path_prefix + + def test_unauthorized(url): + r = requests.get(url) + assert r.status_code == 401 + assert r.text == "Unauthorized: something went wrong" + + def test_authorized(url): + assert requests.get(url).status_code == 200 + + test_unauthorized(base_url) + test_authorized(base_url.rstrip("/") + "/public") + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"url_base_pathname": "/app/"}, + {"url_base_pathname": "/sub/app/"}, + { + "routes_pathname_prefix": "/app/", + "requests_pathname_prefix": "/app/", + }, + ], +) +@patch(f"{_OAUTH_APP}.authorize_redirect", valid_authorize_redirect) +@patch(f"{_OAUTH_APP}.authorize_access_token", valid_authorize_access_token) +def test_oaf003_oidc_auth_login_several_idp(dash_br, dash_thread_server, kwargs): + app = Dash(__name__, backend="fastapi", **kwargs) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + ] + ) + + @app.callback(Output("output1", "children"), Input("input", "value")) + def update_output1(new_value): + return new_value + + oidc = OIDCAuth(app, secret_key="Test") + oidc.register_provider( + "idp1", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url=_METADATA_URL, + ) + oidc.register_provider( + "idp2", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url=_METADATA_URL, + ) + + dash_thread_server(app) + path_prefix = ( + app.config.get("url_base_pathname", "") + or app.config.get("requests_pathname_prefix", "") + or app.config.get("routes_pathname_prefix", "") + ) + base_url = dash_thread_server.url + base_url_prefix = (base_url + path_prefix).rstrip("/") + assert requests.get(base_url).status_code == 400 + assert requests.get(base_url_prefix).status_code == 400 + + assert requests.get(base_url + "/oidc/idp1/login").status_code == 200 + assert requests.get(base_url + "/oidc/logout").status_code == 200 + assert requests.get(base_url).status_code == 400 + assert requests.get(base_url + "/oidc/idp2/login").status_code == 200 + + dash_br.driver.get(base_url + "/oidc/idp2/login") + dash_br.driver.get(base_url_prefix) + dash_br.wait_for_text_to_equal("#output1", "initial value") diff --git a/tests/integration/test_oidc_auth_fastapi_wiring.py b/tests/integration/test_oidc_auth_fastapi_wiring.py new file mode 100644 index 0000000..280dc0d --- /dev/null +++ b/tests/integration/test_oidc_auth_fastapi_wiring.py @@ -0,0 +1,85 @@ +"""OIDCAuth construction/wiring on a FastAPI-backed Dash app (no browser).""" + +import asyncio + +import pytest +from dash import Dash, dcc, html + +from dash_auth_async import OIDCAuth +from dash_auth_async.oidc_auth import get_oauth + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +_METADATA_URL = "https://idp2.com/oidc/2/.well-known/openid-configuration" + + +def _make_oidc_app(): + app = Dash(__name__, backend="fastapi") + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) + oidc = OIDCAuth(app, secret_key="Test") + oidc.register_provider( + "idp", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url=_METADATA_URL, + ) + return app, oidc + + +def test_fastapi_backend_uses_starlette_oauth_registry(): + from authlib.integrations.starlette_client import ( + OAuth as StarletteOAuth, + StarletteOAuth2App, + ) + + app, oidc = _make_oidc_app() + assert isinstance(oidc.oauth, StarletteOAuth) + assert app.server.state.dash_auth_oauth is oidc.oauth + assert isinstance(oidc.get_oauth_client("idp"), StarletteOAuth2App) + + +def test_oidc_routes_registered_with_translated_idp_placeholder(): + app, _ = _make_oidc_app() + paths = {route.path for route in app.server.routes if hasattr(route, "path")} + assert "/oidc/{idp}/login" in paths + assert "/oidc/{idp}/callback" in paths + assert "/oidc/logout" in paths + names = {route.name for route in app.server.routes if hasattr(route, "name")} + assert {"oidc_login", "oidc_logout", "oidc_callback"} <= names + + +def test_get_oauth_finds_state_registry(): + app, oidc = _make_oidc_app() + assert get_oauth(app) is oidc.oauth + + +def test_callback_unknown_idp_returns_400(): + from starlette.requests import Request + + from dash_auth_async.backends import _current_request_var + + _, oidc = _make_oidc_app() + scope = { + "type": "http", + "method": "GET", + "path": "/oidc/nope/callback", + "headers": [], + "query_string": b"", + } + request = Request(scope) + + async def run(): + # The merged async callback resolves the request from the ContextVar + # (set by the auth middleware in a live request) rather than a param. + token = _current_request_var.set(request) + try: + response = await oidc._callback_async("nope") + finally: + _current_request_var.reset(token) + assert response.status_code == 400 + assert b"not a valid registered idp" in response.body + + asyncio.run(run()) diff --git a/tests/integration/test_oidc__auth_quart_wiring.py b/tests/integration/test_oidc_auth_quart_wiring.py similarity index 94% rename from tests/integration/test_oidc__auth_quart_wiring.py rename to tests/integration/test_oidc_auth_quart_wiring.py index bc050ad..43fb509 100644 --- a/tests/integration/test_oidc__auth_quart_wiring.py +++ b/tests/integration/test_oidc_auth_quart_wiring.py @@ -3,7 +3,7 @@ import asyncio import pytest -from dash import Dash +from dash import Dash, dcc, html from dash_auth_async import OIDCAuth from dash_auth_async.oidc_auth import get_oauth @@ -16,6 +16,9 @@ def _make_oidc_app(): app = Dash(__name__, backend="quart") + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) oidc = OIDCAuth(app, secret_key="Test") oidc.register_provider( "idp", diff --git a/tests/integration/test_oidc_state_csrf.py b/tests/integration/test_oidc_state_csrf.py new file mode 100644 index 0000000..30bed0a --- /dev/null +++ b/tests/integration/test_oidc_state_csrf.py @@ -0,0 +1,63 @@ +"""OIDC OAuth state/CSRF validation, driven through real authlib (no mock IDP). + +The browser OIDC tests patch ``authorize_redirect``/``authorize_access_token``, +so authlib's anti-CSRF state check never actually runs there. These tests drive +the *real* authlib state path on a Flask backend via its test client to lock the +invariant against future authlib changes: ``/login`` stores the generated state +in the session, and ``/callback`` presented with a tampered or missing state must +be rejected with 401 (the ``except OAuthError`` branch in ``OIDCAuth.callback``), +never silently accepted. + +authlib validates state before any token-endpoint call, so registering the +provider with explicit endpoints keeps the whole flow offline. +""" + +from urllib.parse import parse_qs, urlparse + +from dash import Dash, html + +from dash_auth_async import OIDCAuth + +_AUTHORIZE_URL = "https://idp.example/authorize" +_TOKEN_URL = "https://idp.example/token" + + +def _make_client(): + app = Dash(__name__) + app.layout = html.Div("state-csrf") # Dash validates layout on first request + oidc = OIDCAuth(app, secret_key="state-csrf-secret") + oidc.register_provider( + "idp", + client_id="client-id", + client_secret="client-secret", + authorize_url=_AUTHORIZE_URL, + access_token_url=_TOKEN_URL, + ) + # The Flask test client persists the session cookie across requests, so the + # state stored at /login is presented back at /callback automatically. + return app.server.test_client() + + +def _login_and_capture_state(client) -> str: + resp = client.get("/oidc/idp/login") + assert resp.status_code == 302 + query = parse_qs(urlparse(resp.headers["Location"]).query) + # authlib generated a state and stored it in the session before redirecting. + assert "state" in query + return query["state"][0] + + +def test_oidc_callback_rejects_tampered_state(): + client = _make_client() + real_state = _login_and_capture_state(client) + + resp = client.get(f"/oidc/idp/callback?code=fake-code&state={real_state}-tampered") + assert resp.status_code == 401 + + +def test_oidc_callback_rejects_missing_state(): + client = _make_client() + _login_and_capture_state(client) + + resp = client.get("/oidc/idp/callback?code=fake-code") + assert resp.status_code == 401 diff --git a/tests/integration/test_protected_callback_websocket_integration_fastapi.py b/tests/integration/test_protected_callback_websocket_integration_fastapi.py new file mode 100644 index 0000000..f860e2b --- /dev/null +++ b/tests/integration/test_protected_callback_websocket_integration_fastapi.py @@ -0,0 +1,92 @@ +"""End-to-end integration tests for a *protected WebSocket* callback on FastAPI. + +FastAPI sibling of test_protected_callback_websocket_integration_quart.py. The +callback opts into the WebSocket transport (``websocket=True``) and streams +updates with ``set_props``; the final visible text is streamed, not returned, so +it can only arrive over the socket. The auth gate must wrap it correctly -- an +under-privileged user gets the fallback and the stream body never runs. +""" + +import asyncio + +import dash +import pytest +from dash import Dash, Input, Output, html, set_props + +from dash_auth_async import BasicAuth, protected_callback + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +TEST_USERS = {"hello": "world", "hello2": "wo:rld"} +USER_GROUPS = {"hello": ["admin"]} + + +def _build_app() -> Dash: + # websocket_callbacks=True makes the *client* open the socket; per-callback + # websocket=True then routes this callback over it. + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + app.layout = html.Div( + [ + html.Button("Start stream", id="ws-start"), + html.Div("idle", id="ws-out"), + ] + ) + + @protected_callback( + Output("ws-out", "children"), + Input("ws-start", "n_clicks"), + groups=["admin"], + missing_permissions_output="forbidden", + prevent_initial_call=True, + websocket=True, + ) + async def stream_ws(_n_clicks): + ws = getattr(dash.ctx, "websocket", None) + for i in range(1, 4): + if ws is not None and ws.is_shutdown: + return dash.no_update + set_props("ws-out", {"children": f"tick {i}/3"}) + await asyncio.sleep(0) + # Final visible text is streamed, not returned: proves the WebSocket push. + set_props("ws-out", {"children": "streamed"}) + return dash.no_update + + return app + + +def _login(dash_br, base_url, username, password): + dash_br.driver.get(base_url.replace("//", f"//{username}:{password}@")) + dash_br.driver.get(base_url) + + +def test_pcwf001_authorized_protected_websocket_callback_streams_set_props( + dash_br, dash_thread_server +): + """An authorised user sees the value pushed via ``set_props`` over the socket.""" + app = _build_app() + BasicAuth(app, TEST_USERS, user_groups=USER_GROUPS, secret_key="Test!") + + dash_thread_server(app) + base_url = dash_thread_server.url + _login(dash_br, base_url, "hello", "world") + + dash_br.wait_for_text_to_equal("#ws-out", "idle") + dash_br.find_element("#ws-start").click() + dash_br.wait_for_text_to_equal("#ws-out", "streamed") + + +def test_pcwf002_missing_permissions_protected_websocket_callback_emits_fallback( + dash_br, dash_thread_server +): + """An authenticated user without the group gets the fallback, never the stream.""" + app = _build_app() + BasicAuth(app, TEST_USERS, user_groups=USER_GROUPS, secret_key="Test!") + + dash_thread_server(app) + base_url = dash_thread_server.url + # "hello2" authenticates but has no groups -> admin gate rejects it. + _login(dash_br, base_url, "hello2", "wo:rld") + + dash_br.wait_for_text_to_equal("#ws-out", "idle") + dash_br.find_element("#ws-start").click() + dash_br.wait_for_text_to_equal("#ws-out", "forbidden") diff --git a/tests/integration/test_websocket_reconnect_on_login_fastapi.py b/tests/integration/test_websocket_reconnect_on_login_fastapi.py new file mode 100644 index 0000000..d2d06f2 --- /dev/null +++ b/tests/integration/test_websocket_reconnect_on_login_fastapi.py @@ -0,0 +1,103 @@ +"""Regression: logging in must close the same browser's stale pre-login socket. + +A WebSocket's session is frozen at the handshake cookie. A socket opened before +login is therefore permanently unauthenticated, and Dash's renderer only +reconnects on a socket *close* -- so without intervention the first protected +callback over that socket is rejected (4401) and silently dropped, and the user +must click a second time (the rejection is what finally triggers the reconnect). + +Design A closes that gap from the server: a pre-login ``dac_client`` cookie ties +an anonymous socket back to the browser, and the moment that browser +authenticates over HTTP the server closes its stale anonymous socket. The +renderer's existing auto-reconnect then dials a fresh, authenticated handshake +*before* the first click. + +This test drives the raw socket (no browser) and asserts that proactive close. +FastAPI sibling of the Quart variant. +""" + +import pytest +from dash import Dash, Input, Output, callback, html + +from dash_auth_async import BasicAuth + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") +requests = pytest.importorskip("requests", reason="requests is not installed") +websocket = pytest.importorskip( + "websocket", reason="websocket-client (the 'websocket' module) is not installed" +) + + +def _build_app() -> Dash: + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + app.layout = html.Div( + [html.Button("p", id="priv-in"), html.Div("idle", id="priv-out")] + ) + + @callback( + Output("priv-out", "children"), + Input("priv-in", "n_clicks"), + prevent_initial_call=True, + websocket=True, + ) + def private(_n): + return "TOP-SECRET-USER-DATA" + + BasicAuth(app, {"hello": "world"}, secret_key="Test!") + return app + + +def _socket_was_closed(conn, timeout=6) -> bool: + """Return True if the server closed ``conn`` within ``timeout`` seconds. + + websocket-client surfaces a server close as an empty ``recv()`` or a + ``WebSocketConnectionClosedException``; a still-open idle socket raises a + timeout instead. + """ + conn.settimeout(timeout) + try: + return conn.recv() in {"", b"", None} + except websocket.WebSocketConnectionClosedException: + return True + except Exception: + return False + + +def test_login_closes_stale_anonymous_socket(dash_thread_server): + app = _build_app() + dash_thread_server(app) + base = dash_thread_server.url + ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback" + + session = requests.Session() + + # First anonymous contact must mint a pre-login client-id cookie -- even on + # the BasicAuth 401 challenge -- so the socket can be tied to this browser. + session.get(base, timeout=8) + client_id = session.cookies.get("dac_client") + assert client_id, "expected a pre-login 'dac_client' cookie on first contact" + + # Open an anonymous socket carrying only dac_client (no session): this models + # the SharedWorker socket opened on a public page before login. + conn = websocket.create_connection( + ws_url, + header=[f"Origin: {base}", f"Cookie: dac_client={client_id}"], + timeout=8, + suppress_origin=True, + ) + try: + # The same browser now authenticates over HTTP (same dac_client cookie). + resp = session.get(base, auth=("hello", "world"), timeout=8) + assert resp.status_code == 200, resp.status_code + + # The server must proactively close the stale anonymous socket so the + # renderer reconnects authenticated before the user's first click. + assert _socket_was_closed(conn), ( + "stale pre-login socket was not closed after the same browser " + "logged in -- the first protected click would be dropped" + ) + finally: + try: + conn.close() + except Exception: + pass diff --git a/tests/integration/test_websocket_reconnect_on_login_quart.py b/tests/integration/test_websocket_reconnect_on_login_quart.py new file mode 100644 index 0000000..7b58380 --- /dev/null +++ b/tests/integration/test_websocket_reconnect_on_login_quart.py @@ -0,0 +1,87 @@ +"""Regression: logging in must close the same browser's stale pre-login socket. + +Quart sibling of test_websocket_reconnect_on_login_fastapi.py -- see that file +for the full rationale. A socket's session is frozen at its handshake cookie, so +a pre-login socket stays unauthenticated; Design A closes it the moment the same +browser (tied by the pre-login ``dac_client`` cookie) authenticates over HTTP, +so the renderer reconnects authenticated before the first click. +""" + +import pytest +from dash import Dash, Input, Output, callback, html + +from dash_auth_async import BasicAuth + +pytest.importorskip("quart", reason="Quart extra dependencies are not installed") +requests = pytest.importorskip("requests", reason="requests is not installed") +websocket = pytest.importorskip( + "websocket", reason="websocket-client (the 'websocket' module) is not installed" +) + + +def _build_app() -> Dash: + app = Dash(__name__, backend="quart", websocket_callbacks=True) + app.layout = html.Div( + [html.Button("p", id="priv-in"), html.Div("idle", id="priv-out")] + ) + + @callback( + Output("priv-out", "children"), + Input("priv-in", "n_clicks"), + prevent_initial_call=True, + websocket=True, + ) + def private(_n): + return "TOP-SECRET-USER-DATA" + + BasicAuth(app, {"hello": "world"}, secret_key="Test!") + return app + + +def _socket_was_closed(conn, timeout=6) -> bool: + """Return True if the server closed ``conn`` within ``timeout`` seconds.""" + conn.settimeout(timeout) + try: + return conn.recv() in {"", b"", None} + except websocket.WebSocketConnectionClosedException: + return True + except Exception: + return False + + +def test_login_closes_stale_anonymous_socket(dash_thread_server): + app = _build_app() + dash_thread_server(app) + base = dash_thread_server.url + ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback" + + session = requests.Session() + + # First anonymous contact must mint a pre-login client-id cookie -- even on + # the BasicAuth 401 challenge -- so the socket can be tied to this browser. + session.get(base, timeout=8) + client_id = session.cookies.get("dac_client") + assert client_id, "expected a pre-login 'dac_client' cookie on first contact" + + # Open an anonymous socket carrying only dac_client (no session). + conn = websocket.create_connection( + ws_url, + header=[f"Origin: {base}", f"Cookie: dac_client={client_id}"], + timeout=8, + suppress_origin=True, + ) + try: + # The same browser now authenticates over HTTP (same dac_client cookie). + resp = session.get(base, auth=("hello", "world"), timeout=8) + assert resp.status_code == 200, resp.status_code + + # The server must proactively close the stale anonymous socket. + assert _socket_was_closed(conn), ( + "stale pre-login socket was not closed after the same browser " + "logged in -- the first protected click would be dropped" + ) + finally: + try: + conn.close() + except Exception: + pass diff --git a/tests/integration/test_websocket_security_fastapi.py b/tests/integration/test_websocket_security_fastapi.py new file mode 100644 index 0000000..f50aecc --- /dev/null +++ b/tests/integration/test_websocket_security_fastapi.py @@ -0,0 +1,184 @@ +"""Security: the FastAPI websocket callback endpoint must enforce auth. + +Drives the raw socket with websocket-client (no browser) and asserts an +unauthenticated client cannot invoke a private callback, a public_callback still +streams unauthenticated, and an authenticated-but-under-privileged user gets the +fallback rather than the protected payload. FastAPI sibling of +test_websocket_security_quart.py. +""" + +import json + +import pytest +from dash import Dash, Input, Output, callback, html + +from dash_auth_async import BasicAuth, protected_callback, public_callback + +# Guard the optional, non-extra test deps *before* importing them. fastapi is +# checked first so the non-async matrix jobs skip cleanly without needing the +# socket client at all. +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") +requests = pytest.importorskip("requests", reason="requests is not installed") +websocket = pytest.importorskip( # websocket-client (synchronous) + "websocket", + reason="websocket-client (the 'websocket' module) is not installed", +) + + +def _build_app() -> Dash: + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + app.layout = html.Div( + [ + html.Button("p", id="priv-in"), + html.Div("idle", id="priv-out"), + html.Button("u", id="pub-in"), + html.Div("idle", id="pub-out"), + ] + ) + + @callback( + Output("priv-out", "children"), + Input("priv-in", "n_clicks"), + prevent_initial_call=True, + websocket=True, + ) + def private(_n): + return "TOP-SECRET-USER-DATA" + + @public_callback( + Output("pub-out", "children"), + Input("pub-in", "n_clicks"), + prevent_initial_call=True, + websocket=True, + ) + async def public(_n): + return "PUBLIC-OK" + + BasicAuth(app, {"hello": "world"}, secret_key="Test!") + return app + + +def _build_app_with_protected_admin_callback() -> Dash: + """App whose private callback is group-gated to ``admin`` over the socket.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + app.layout = html.Div( + [ + html.Button("p", id="priv-in"), + html.Div("idle", id="priv-out"), + ] + ) + + @protected_callback( + Output("priv-out", "children"), + Input("priv-in", "n_clicks"), + groups=["admin"], + missing_permissions_output="forbidden", + prevent_initial_call=True, + websocket=True, + ) + async def private(_n): + return "TOP-SECRET-ADMIN-DATA" + + BasicAuth( + app, + {"admin": "pw", "viewer": "pw"}, + user_groups={"admin": ["admin"]}, # "viewer" authenticates with no groups + secret_key="Test!", + ) + return app + + +def _login_cookie_header(base_url, username, password) -> str: + """Authenticate over HTTP and return the session ``Cookie`` header value. + + The auth hook runs ``is_authorized`` on this request, which stashes + ``session["user"]`` (with groups) and sets the session cookie -- the same + cookie the browser would send at the WS handshake. + """ + resp = requests.get(base_url, auth=(username, password), timeout=8) + assert resp.status_code == 200, resp.status_code + return "; ".join(f"{c.name}={c.value}" for c in resp.cookies) + + +def _send_callback_request(ws_url, origin, output, comp_id, in_id, cookie=None): + header = [f"Origin: {origin}"] + if cookie: + header.append(f"Cookie: {cookie}") + # suppress_origin: FastAPI/Uvicorn rejects a duplicate ``Origin`` header with + # HTTP 400 (Quart is lenient), and websocket-client injects its own by + # default -- suppress it so only our explicit Origin reaches the server. + conn = websocket.create_connection( + ws_url, header=header, timeout=8, suppress_origin=True + ) + try: + conn.send( + json.dumps( + { + "type": "callback_request", + "requestId": "1", + "rendererId": "r1", + "payload": { + "output": output, + "outputs": {"id": comp_id, "property": "children"}, + "inputs": [{"id": in_id, "property": "n_clicks", "value": 1}], + "changedPropIds": [f"{in_id}.n_clicks"], + "state": [], + }, + } + ) + ) + frames = [] + for _ in range(5): + try: + frames.append(str(conn.recv())) + except Exception: + break + return "".join(frames) + finally: + try: + conn.close() + except Exception: + pass + + +def test_unauthenticated_ws_cannot_invoke_private_callback(dash_thread_server): + app = _build_app() + dash_thread_server(app) + base = dash_thread_server.url + ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback" + + received = _send_callback_request( + ws_url, base, "priv-out.children", "priv-out", "priv-in" + ) + assert "TOP-SECRET-USER-DATA" not in received + + +def test_unauthenticated_ws_can_invoke_public_callback(dash_thread_server): + app = _build_app() + dash_thread_server(app) + base = dash_thread_server.url + ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback" + + received = _send_callback_request( + ws_url, base, "pub-out.children", "pub-out", "pub-in" + ) + assert "PUBLIC-OK" in received + + +def test_authenticated_wrong_group_ws_gets_fallback_not_secret(dash_thread_server): + """Authenticated-but-under-privileged over the raw socket: the group gate + renders ``missing_permissions_output`` and never leaks the admin payload. + """ + app = _build_app_with_protected_admin_callback() + dash_thread_server(app) + base = dash_thread_server.url + ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback" + + # "viewer" authenticates (cookie set) but lacks the "admin" group. + cookie = _login_cookie_header(base, "viewer", "pw") + + received = _send_callback_request( + ws_url, base, "priv-out.children", "priv-out", "priv-in", cookie=cookie + ) + assert "TOP-SECRET-ADMIN-DATA" not in received + assert "forbidden" in received diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index 84466b6..3d66d7e 100644 --- a/tests/unit/test_backends.py +++ b/tests/unit/test_backends.py @@ -1,9 +1,11 @@ from dash import Dash +from dash_auth_async import backends from dash_auth_async.backends import ( FlaskBackend, detect_backend, get_active_backend, + set_active_backend, ) @@ -16,6 +18,21 @@ def test_active_backend_defaults_to_flask(): assert isinstance(get_active_backend(), FlaskBackend) +def test_set_active_backend_overrides_default(): + sentinel = FlaskBackend() + set_active_backend(sentinel) + # The process-global helper returns exactly what Auth.__init__ registered, + # not a freshly detected backend — this is the cache public_routes reads. + assert get_active_backend() is sentinel + + +def test_default_backend_is_not_constructed_eagerly_at_import(): + # B2: the Flask fallback is built lazily inside get_active_backend(), not + # as a module-level _DEFAULT_BACKEND at import, so Flask isn't cemented as + # the default at module load in a non-Flask process. + assert not hasattr(backends, "_DEFAULT_BACKEND") + + def test_flask_backend_url_for_and_redirect(): app = Dash(__name__) server = app.server @@ -30,3 +47,28 @@ def target(): response = backend.redirect("/target") assert response.status_code == 302 assert response.headers["Location"] == "/target" + + +def test_flask_backend_new_method_defaults(): + import flask + + from dash_auth_async.backends import FlaskBackend + + backend = FlaskBackend() + + # coerce_response is pass-through on Flask + sentinel = ("body", 401, {"X": "y"}) + assert backend.coerce_response(sentinel) is sentinel + + # setup_session sets secret_key; session_configured reflects it + server = flask.Flask(__name__) + assert backend.session_configured(server) is False + backend.setup_session(server, "Test!") + assert server.secret_key == "Test!" + assert backend.session_configured(server) is True + + # make_oauth returns a flask_client OAuth bound to the server + from authlib.integrations.flask_client import OAuth as FlaskOAuth + + oauth = backend.make_oauth(server) + assert isinstance(oauth, FlaskOAuth) diff --git a/tests/unit/test_backends_fastapi.py b/tests/unit/test_backends_fastapi.py new file mode 100644 index 0000000..716458a --- /dev/null +++ b/tests/unit/test_backends_fastapi.py @@ -0,0 +1,369 @@ +import pytest +from dash import Dash + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +from starlette.requests import Request + +from dash_auth_async.backends import ( + FastAPIBackend, + _current_request_var, + detect_backend, + get_active_backend, + set_active_backend, +) + + +def _bare_request(path="/", session=None): + scope = { + "type": "http", + "method": "GET", + "path": path, + "headers": [], + "query_string": b"", + } + if session is not None: + scope["session"] = session + return Request(scope) + + +def test_detect_backend_fastapi(): + app = Dash(__name__, backend="fastapi") + assert isinstance(detect_backend(app.server), FastAPIBackend) + + +def test_active_backend_roundtrip(): + backend = FastAPIBackend() + set_active_backend(backend) + assert get_active_backend() is backend + + +def test_contextvar_set_reset_and_request_context(): + backend = FastAPIBackend() + assert backend.has_request_context() is False + + req = _bare_request() + token = _current_request_var.set(req) + try: + assert backend.has_request_context() is True + assert backend.request is req + finally: + _current_request_var.reset(token) + assert backend.has_request_context() is False + + +def test_session_without_middleware_raises_runtimeerror(): + backend = FastAPIBackend() + req = _bare_request() # no "session" in scope + token = _current_request_var.set(req) + try: + with pytest.raises(RuntimeError): + _ = backend.session + finally: + _current_request_var.reset(token) + + +def test_session_off_request_raises_runtimeerror(): + # request is None outside any request context. The scope-based check + # must surface RuntimeError (a caught, "not authenticated" signal), not + # an AttributeError on None — and never relies on a -O-stripped assert. + backend = FastAPIBackend() + assert backend.request is None + with pytest.raises(RuntimeError): + _ = backend.session + + +def test_session_present_returns_mapping(): + backend = FastAPIBackend() + req = _bare_request(session={"user": {"email": "a.b@mail.com"}}) + token = _current_request_var.set(req) + try: + assert backend.session["user"]["email"] == "a.b@mail.com" + finally: + _current_request_var.reset(token) + + +def test_coerce_response_tuple_str_and_response(): + from starlette.responses import Response as StarletteResponse + + backend = FastAPIBackend() + + # Tuples carry status/headers and stay plain text (e.g. the 401 challenge). + resp = backend.coerce_response( + ("Login Required", 401, {"WWW-Authenticate": 'Basic realm="x"'}) + ) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == 'Basic realm="x"' + assert resp.body == b"Login Required" + assert resp.media_type == "text/plain" + + # A bare string is treated as HTML, matching Flask/Quart str returns (e.g. + # the OIDC logout page) so the browser renders rather than shows the markup. + resp2 = backend.coerce_response("hello") + assert resp2.status_code == 200 + assert resp2.body == b"hello" + assert resp2.media_type == "text/html" + + passthrough = StarletteResponse(content="x", status_code=204) + assert backend.coerce_response(passthrough) is passthrough + + +def test_fastapi_backend_url_for_and_redirect(): + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + backend = FastAPIBackend() + + @app.get("/target", name="target") + async def target(): + return {"ok": True} + + @app.get("/probe") + async def probe(request: Request): + token = _current_request_var.set(request) + try: + return { + "url": backend.url_for("target"), + "https_url": backend.url_for("target", _external=True, _scheme="https"), + "redirect_loc": backend.redirect("/target").headers["location"], + "host": backend.current_host(), + } + finally: + _current_request_var.reset(token) + + client = TestClient(app) + data = client.get("/probe").json() + assert data["url"].endswith("/target") + assert data["https_url"].startswith("https://") + assert data["redirect_loc"] == "/target" + assert data["host"] # non-empty netloc + + +def _build_app_with_auth(decide, needs_body): + """A FastAPI app whose only middleware is the auth hook, plus an echo + route that proves the body survives middleware body-consumption.""" + from fastapi import FastAPI + + app = FastAPI() + backend = FastAPIBackend() + + @app.post("/_dash-update-component") + async def echo(request: Request): + body = await request.json() + return {"seen": body, "had_context": backend.has_request_context()} + + @app.get("/open") + async def open_route(): + return {"ok": True} + + backend.register_auth_hook(app, needs_body, decide) + return app + + +def test_auth_hook_allows_when_decide_returns_none(): + from fastapi.testclient import TestClient + + calls = [] + + def decide(path, body): + calls.append((path, body)) + + app = _build_app_with_auth( + decide, needs_body=lambda p: p == "/_dash-update-component" + ) + client = TestClient(app) + + r = client.post("/_dash-update-component", json={"output": "x", "inputs": []}) + assert r.status_code == 200 + # Body was replayed: the downstream route still parsed it. + assert r.json()["seen"] == {"output": "x", "inputs": []} + assert r.json()["had_context"] is True + # decide saw the parsed body for the callback route. + assert calls == [("/_dash-update-component", {"output": "x", "inputs": []})] + + +def test_auth_hook_short_circuits_with_tuple(): + from fastapi.testclient import TestClient + + def decide(path, body): + return ("Login Required", 401, {"WWW-Authenticate": 'Basic realm="x"'}) + + app = _build_app_with_auth(decide, needs_body=lambda p: False) + client = TestClient(app) + + r = client.get("/open") + assert r.status_code == 401 + assert r.headers["WWW-Authenticate"] == 'Basic realm="x"' + assert r.text == "Login Required" + + +def test_auth_hook_awaits_coroutine_results(): + from fastapi.testclient import TestClient + from starlette.responses import PlainTextResponse + + async def decide(path, body): + return PlainTextResponse("async-block", status_code=403) + + app = _build_app_with_auth(decide, needs_body=lambda p: False) + client = TestClient(app) + + r = client.get("/open") + assert r.status_code == 403 + assert r.text == "async-block" + + +def test_auth_hook_unparseable_body_is_treated_as_none(): + # Fail-closed path: malformed JSON on a needs_body route must reach decide + # as None (the `except Exception: body = None` branch), not raise a 500. + from fastapi.testclient import TestClient + + seen = [] + + def decide(path, body): + seen.append(body) + return ("Unauthorized", 401) if body is None else None + + app = _build_app_with_auth( + decide, needs_body=lambda p: p == "/_dash-update-component" + ) + client = TestClient(app) + + r = client.post( + "/_dash-update-component", + content="{ not valid json", + headers={"content-type": "application/json"}, + ) + assert seen == [None] + assert r.status_code == 401 + assert r.text == "Unauthorized" + + +def test_auth_hook_short_circuits_after_parsing_body(): + # The other short-circuit branch: decide returns non-None *when needs_body + # is True*, so the body is parsed first and then the request is blocked. + from fastapi.testclient import TestClient + + seen = [] + + def decide(path, body): + seen.append(body) + return ("Login Required", 401) + + app = _build_app_with_auth(decide, needs_body=lambda p: True) + client = TestClient(app) + + r = client.post("/_dash-update-component", json={"output": "x", "inputs": []}) + assert r.status_code == 401 + assert r.text == "Login Required" + # decide saw the parsed body — it ran after body parsing, not before. + assert seen == [{"output": "x", "inputs": []}] + + +def test_downstream_receive_emits_disconnect_after_body(): + # The replayed receive must deliver the cached body once, then signal + # http.disconnect — an app that polls receive() after the body (to detect + # disconnect) must not get the same body event forever. + import asyncio + + from fastapi import FastAPI + + app = FastAPI() + backend = FastAPIBackend() + backend.register_auth_hook( + app, needs_body=lambda p: True, decide=lambda path, body: None + ) + # add_middleware prepends; our auth middleware is the only/outermost one. + auth_middleware_cls = app.user_middleware[0].cls + + received = [] + + async def inner_app(scope, receive, send): + received.append(await receive()) # cached body + received.append(await receive()) # past the body → disconnect + + middleware = auth_middleware_cls(inner_app) + + scope = { + "type": "http", + "method": "POST", + "path": "/_dash-update-component", + "headers": [(b"content-type", b"application/json")], + "query_string": b"", + } + + async def receive(): + return { + "type": "http.request", + "body": b'{"output": "x", "inputs": []}', + "more_body": False, + } + + async def send(_message): + pass + + async def drive(): + await middleware(scope, receive, send) + + asyncio.run(drive()) + + assert received[0]["type"] == "http.request" + assert received[0]["body"] == b'{"output": "x", "inputs": []}' + assert received[1] == {"type": "http.disconnect"} + + +def test_setup_session_adds_session_middleware_once(): + from fastapi import FastAPI + from starlette.middleware.sessions import SessionMiddleware + + backend = FastAPIBackend() + app = FastAPI() + + assert backend.session_configured(app) is False + + backend.setup_session(app, "Test!") + assert backend.session_configured(app) is True + count = sum(1 for m in app.user_middleware if m.cls is SessionMiddleware) + assert count == 1 + + # Calling again must not add a second SessionMiddleware. + backend.setup_session(app, "Test!") + count = sum(1 for m in app.user_middleware if m.cls is SessionMiddleware) + assert count == 1 + + +def test_setup_session_wires_secure_session_to_https_only(): + from fastapi import FastAPI + from starlette.middleware.sessions import SessionMiddleware + + backend = FastAPIBackend() + + insecure = FastAPI() + backend.setup_session(insecure, "Test!") + sm = next(m for m in insecure.user_middleware if m.cls is SessionMiddleware) + assert sm.kwargs.get("https_only") is False + + secure = FastAPI() + backend.setup_session(secure, "Test!", secure_session=True) + sm = next(m for m in secure.user_middleware if m.cls is SessionMiddleware) + assert sm.kwargs.get("https_only") is True + + +def test_setup_session_noop_without_secret_key(): + from fastapi import FastAPI + + backend = FastAPIBackend() + app = FastAPI() + backend.setup_session(app, None) + assert backend.session_configured(app) is False + + +def test_config_store_read_roundtrip_via_state(): + from fastapi import FastAPI + + backend = FastAPIBackend() + app = FastAPI() # FastAPI has no .config, only .state + + assert backend.read_config(app, "PUBLIC_ROUTES", "fallback") == "fallback" + backend.store_config(app, "PUBLIC_ROUTES", ["/home"]) + assert backend.read_config(app, "PUBLIC_ROUTES") == ["/home"] diff --git a/tests/unit/test_group_protection_fastapi.py b/tests/unit/test_group_protection_fastapi.py new file mode 100644 index 0000000..b376716 --- /dev/null +++ b/tests/unit/test_group_protection_fastapi.py @@ -0,0 +1,78 @@ +import asyncio +import inspect + +import pytest + +pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed") + +from starlette.requests import Request + +from dash_auth_async import check_groups, list_groups, protected +from dash_auth_async.backends import ( + FastAPIBackend, + _current_request_var, + set_active_backend, +) + + +def _request_with_session(session): + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [], + "query_string": b"", + } + if session is not None: + scope["session"] = session + return Request(scope) + + +def test_gp_list_groups_fastapi(): + set_active_backend(FastAPIBackend()) + req = _request_with_session( + {"user": {"email": "a.b@mail.com", "groups": ["default"]}} + ) + token = _current_request_var.set(req) + try: + assert list_groups() == ["default"] + assert check_groups(["default"]) is True + assert check_groups(["other"]) is False + finally: + _current_request_var.reset(token) + + +def test_gp_no_session_returns_none(): + set_active_backend(FastAPIBackend()) + req = _request_with_session(None) # no SessionMiddleware -> session raises + token = _current_request_var.set(req) + try: + assert list_groups() is None + assert check_groups(["default"]) is None + finally: + _current_request_var.reset(token) + + +def test_gp_async_protected_unauthenticated_without_session_user(): + """An async ``protected`` wrapper with no logged-in user (empty session) + short-circuits to ``unauthenticated_output`` over the FastAPI backend, while + staying a coroutine so Dash keeps it on the async dispatch path. + """ + set_active_backend(FastAPIBackend()) + req = _request_with_session({}) # session available, but no "user" + token = _current_request_var.set(req) + try: + + async def func(): + return "success" + + wrapped = protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + )(func) + + assert inspect.iscoroutinefunction(wrapped) + assert asyncio.run(wrapped()) == "unauthenticated" + finally: + _current_request_var.reset(token) diff --git a/tests/unit/test_group_protection_quart.py b/tests/unit/test_group_protection_quart.py index f910346..a83039a 100644 --- a/tests/unit/test_group_protection_quart.py +++ b/tests/unit/test_group_protection_quart.py @@ -1,15 +1,15 @@ import asyncio +import inspect import pytest -from dash_auth_async import check_groups, list_groups +from dash_auth_async import check_groups, list_groups, protected pytest.importorskip("quart", reason="Quart extra dependencies are not installed") def test_gp004_list_groups_quart(): - from quart import Quart - from quart import session as quart_session + from quart import Quart, session as quart_session from dash_auth_async.backends import QuartBackend, set_active_backend @@ -27,3 +27,34 @@ async def run(): assert check_groups(["other"]) is False asyncio.run(run()) + + +def test_gp_async_protected_unauthenticated_without_session_user(): + """An async ``protected`` wrapper with no logged-in user short-circuits to + ``unauthenticated_output`` over the Quart backend, while staying a coroutine + so Dash keeps it on the async dispatch path. + """ + from quart import Quart + + from dash_auth_async.backends import QuartBackend, set_active_backend + + app = Quart(__name__) + app.secret_key = "Test!" + + async def func(): + return "success" + + async def run(): + async with app.test_request_context("/", method="GET"): # ty: ignore[invalid-context-manager] + # No session["user"] -> unauthenticated. + set_active_backend(QuartBackend()) + wrapped = protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + )(func) + + assert inspect.iscoroutinefunction(wrapped) + assert await wrapped() == "unauthenticated" + + asyncio.run(run()) diff --git a/tests/unit/test_protected_callback_async.py b/tests/unit/test_protected_callback_async.py index 5f9d0bb..f624490 100644 --- a/tests/unit/test_protected_callback_async.py +++ b/tests/unit/test_protected_callback_async.py @@ -80,6 +80,32 @@ async def func(): assert asyncio.run(wrapped()) == "forbidden" +def test_gp008b_async_protected_emits_unauthenticated_output_without_session_user(): + """An async target with no session user gets the unauthenticated output. + + The third gate branch (alongside authorized and missing-permissions): when + ``_current_user`` resolves to ``None`` the wrapper must short-circuit to + ``unauthenticated_output`` and never await the inner coroutine -- while still + staying a coroutine function so Dash keeps it on the async dispatch path. + """ + + async def func(): + return "success" + + app = Flask(__name__) + app.secret_key = "Test!" + with app.test_request_context("/", method="GET"): + # No session["user"] -> unauthenticated. + wrapped = protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + )(func) + + assert inspect.iscoroutinefunction(wrapped) + assert asyncio.run(wrapped()) == "unauthenticated" + + # --------------------------------------------------------------------------- # # Level 3: protected_callback registers on the correct Dash dispatch path # --------------------------------------------------------------------------- # diff --git a/tests/unit/test_public_routes.py b/tests/unit/test_public_routes.py new file mode 100644 index 0000000..0f14256 --- /dev/null +++ b/tests/unit/test_public_routes.py @@ -0,0 +1,45 @@ +"""Unit tests for the public-route/callback registration helpers.""" + +import pytest +from dash import Dash + +from dash_auth_async import backends +from dash_auth_async.public_routes import ( + add_public_routes, + get_public_callbacks, + get_public_routes, +) + + +def test_public_helpers_resolve_backend_from_app_server_not_global_fallback(): + """The public-route helpers must resolve the backend from ``app.server``, + not the process-global ``get_active_backend()`` (which falls back to + ``FlaskBackend``). On a FastAPI app that fallback's ``store_config`` would + write to the nonexistent ``server.config``; ``detect_backend(app.server)`` + routes through ``server.state`` instead. + + Regression guard for the FastAPI public-route registration path: with no + ``Auth`` having set the active backend, the global is the Flask fallback, + so a helper that trusted it would ``AttributeError`` here. + """ + pytest.importorskip( + "fastapi", reason="FastAPI extra dependencies are not installed" + ) + + # Leave the active backend unset -> get_active_backend() is the Flask + # fallback, the exact condition that exposed the bug. + backends._active_backend = None + + app = Dash(__name__, backend="fastapi") + + # Would raise AttributeError (no server.config) under the Flask fallback. + add_public_routes(app, ["/login"]) + + # Stored on (and read back from) the FastAPI server.state, not server.config. + assert get_public_routes(app) is app.server.state.PUBLIC_ROUTES + assert any( + rule.rule == "/login" for rule in get_public_routes(app).map.iter_rules() + ) + + # The callbacks reader returns its default through the same backend path. + assert get_public_callbacks(app) == [] diff --git a/tests/unit/test_websocket_auth.py b/tests/unit/test_websocket_auth.py index 6817c2e..213fb59 100644 --- a/tests/unit/test_websocket_auth.py +++ b/tests/unit/test_websocket_auth.py @@ -1,5 +1,6 @@ """Unit tests for the WebSocket auth primitives.""" +import asyncio import contextvars from concurrent.futures import ThreadPoolExecutor @@ -130,28 +131,130 @@ class _Server: """Weak-referenceable stand-in for an app.server (the registry key).""" +class _DashAppStub: + """Minimal Dash-app double exposing the idempotent ``_setup_server`` the hook + calls to migrate ``callback_map`` lazily on the first WS ``callback_request``. + """ + + def __init__(self) -> None: + self.setup_calls = 0 + + def _setup_server(self) -> None: + self.setup_calls += 1 + + class _RecordingAuth: """Auth double that records the calls the hook routes to it.""" def __init__(self) -> None: self.calls: list = [] + self.app = _DashAppStub() def authorize_ws(self, payload, user) -> bool: self.calls.append((payload, user)) return True +# --------------------------------------------------------------------------- # +# Connection registry: the websocket_connect hook tracks sockets so a later +# login can retire the browser's stale anonymous one. It must not accumulate -- +# authenticated sockets are never retired (so never tracked), and a browser's +# reconnects (one SharedWorker socket per browser) must not pile up. +# --------------------------------------------------------------------------- # +class _IdentityBackend: + """Backend double whose ``ws_identity`` returns a fixed (server, user).""" + + def __init__(self, user) -> None: + self._user = user + + def ws_identity(self, _ws): + return object(), self._user + + +class _FakeWS: + """Minimal socket double exposing the cookies and an awaitable close.""" + + def __init__(self, client_id) -> None: + self.cookies = {"dac_client": client_id} if client_id else {} + self.close_code = None + + async def close(self, code=None) -> None: + self.close_code = code + + +def _tracked_count(client_id) -> int: + """Sockets tracked for ``client_id``, agnostic to the registry's shape.""" + from dash_auth_async.websocket_auth import _WS_BY_CLIENT + + entry = _WS_BY_CLIENT.get(client_id) + if entry is None: + return 0 + return len(entry) if isinstance(entry, (set, list, dict)) else 1 + + +def _run_connect_hook(ws) -> None: + from dash_auth_async.websocket_auth import _ws_connect_hook + + async def _run(): + _ws_connect_hook(ws) + + asyncio.run(_run()) + + +def _use_backend(monkeypatch, user) -> None: + """Point the connect hook at a backend double whose identity returns ``user``.""" + monkeypatch.setattr( + "dash_auth_async.websocket_auth.get_active_backend", + lambda: _IdentityBackend(user), + ) + + +def test_authenticated_handshake_is_not_tracked(monkeypatch): + """A socket that handshakes already authenticated is never retired, so the + registry must not hold it (else authenticated sockets leak unboundedly). + """ + from dash_auth_async.websocket_auth import _WS_BY_CLIENT + + _WS_BY_CLIENT.clear() + _use_backend(monkeypatch, {"email": "a@b.c", "groups": []}) + + _run_connect_hook(_FakeWS("client-1")) + + assert _tracked_count("client-1") == 0 + + +def test_reconnect_does_not_accumulate_entries(monkeypatch): + """A browser has one SharedWorker socket; a reconnect replaces the prior + anonymous entry rather than stacking, so tracking stays at one per browser. + """ + from dash_auth_async.websocket_auth import _WS_BY_CLIENT + + _WS_BY_CLIENT.clear() + _use_backend(monkeypatch, None) # anonymous handshakes + + _run_connect_hook(_FakeWS("client-2")) + _run_connect_hook(_FakeWS("client-2")) + + assert _tracked_count("client-2") == 1 + + def test_ws_hook_resolves_auth_for_the_current_app(monkeypatch): """With two dash-auth-async apps in the process, the hook consults only the Auth registered for ``quart.current_app`` -- not some other app's Auth. """ quart = pytest.importorskip("quart") + from dash_auth_async.backends import QuartBackend, set_active_backend from dash_auth_async.websocket_auth import ( _AUTH_BY_SERVER, _WS_AUTH_USER, _ws_message_hook, ) + # The hook resolves identity via the active backend; this is the Quart path + # (ws_identity reads the quart.current_app/quart.session monkeypatched below). + # The autouse reset_active_backend fixture clears it after the test. + set_active_backend(QuartBackend()) + server_a, server_b = _Server(), _Server() auth_a, auth_b = _RecordingAuth(), _RecordingAuth() _AUTH_BY_SERVER[server_a] = auth_a @@ -174,6 +277,10 @@ def _get_current_object(self): # Only app B's Auth was consulted, with app B's session user. assert auth_b.calls == [({"output": "x.children"}, user)] assert auth_a.calls == [] + # The hook migrated only app B's callback_map (lazy, idempotent), and + # never touched app A's. + assert auth_b.app.setup_calls == 1 + assert auth_a.app.setup_calls == 0 # The resolved user was stashed for the worker. assert _WS_AUTH_USER.get() == user finally: diff --git a/uv.lock b/uv.lock index 9e56cdf..d90e9a7 100644 --- a/uv.lock +++ b/uv.lock @@ -596,7 +596,7 @@ testing = [ [[package]] name = "dash-auth-async" -version = "1.2.1" +version = "1.3.0" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -622,13 +622,19 @@ dev = [ { name = "asgiref" }, { name = "creosote" }, { name = "dash", extra = ["testing"] }, + { name = "fastapi" }, + { name = "httpx" }, + { name = "itsdangerous" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "python-multipart" }, { name = "requests" }, { name = "ruff" }, { name = "setuptools" }, + { name = "starlette" }, { name = "ty" }, + { name = "uvicorn" }, { name = "websocket-client" }, ] @@ -650,13 +656,19 @@ dev = [ { name = "asgiref", specifier = ">=3.11.1" }, { name = "creosote", specifier = ">=5.2.0" }, { name = "dash", extras = ["testing"], specifier = ">=4.2.0" }, + { name = "fastapi" }, + { name = "httpx", specifier = ">=0.23.0" }, + { name = "itsdangerous" }, { name = "pre-commit", specifier = ">=3.5.0" }, { name = "pytest", specifier = ">=8.3.5" }, { name = "pytest-cov", specifier = ">=7.1.0" }, + { name = "python-multipart" }, { name = "requests", extras = ["security"], specifier = ">=2.34.2" }, { name = "ruff", specifier = ">=0.15.16" }, { name = "setuptools", specifier = ">=79.0.1" }, + { name = "starlette" }, { name = "ty", specifier = ">=0.0.46" }, + { name = "uvicorn" }, { name = "websocket-client", specifier = ">=1.9.0" }, ] @@ -1767,6 +1779,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.32" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/42/55c32bb9b12693c092ad250a0e82edb5b31ddeda6eb772de5f308b3804ad/python_multipart-0.0.32.tar.gz", hash = "sha256:be54b7f3fa167bb83e4fcd936b887b708f4e57fe75911c02aebf53efaf8d938e", size = 46881, upload-time = "2026-06-04T16:18:58.647Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/04/e8135ebd1ad02c56ec633277529b2602ff99ff634be76cdba5744cf554fd/python_multipart-0.0.32-py3-none-any.whl", hash = "sha256:ff6d3f776f16878c894e52e107296ffc890e913c611b1a4ec6c44e2821fe2e23", size = 30042, upload-time = "2026-06-04T16:18:57.319Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3"