diff --git a/.gitignore b/.gitignore
index dab7335..918c131 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,12 @@ vv
.tox/
venv/
+# Local scratch / repro scripts — not part of the package, kept out of the
+# tracked tree so `ruff check .` (which respects .gitignore) stays green.
+scripts/
+.coverage
+.vscode/
+
node_modules
npm-debug.log
.DS_Store
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef3f3d2..d22d605 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,21 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+## [1.3.0] - 2026-06-19
+
+### Added
+- FastAPI (async) backend support for `BasicAuth` and `OIDCAuth` — run on Starlette/ASGI with `Dash(__name__, backend="fastapi")`. Auth is enforced by pure-ASGI middleware, sessions use Starlette `SessionMiddleware`, and public routes/callbacks are stored on the app's `server.state`
+- Authenticated WebSocket callbacks on the FastAPI backend: each `callback_request` is authorized via `Backend.ws_identity` and fails closed, while public callbacks still stream unauthenticated
+- WebSocket authentication now reconnects on login. When a browser authenticates over HTTP, its stale pre-login WebSocket is retired so the renderer reconnects with the authenticated session — eliminating the "first click is dropped / only works on the second click" behaviour for callbacks invoked over a socket opened before login. Applies to both the FastAPI and Quart backends
+
+### Changed
+- The WebSocket `callback_map` is migrated lazily on the first `callback_request` rather than at `Auth(...)` construction, so a global `@callback` registered after `Auth(...)` is still picked up and a WebSocket-first client no longer hits an empty map
+
+### Fixed
+- The public-route helpers resolve the backend from `app.server` instead of a process-global fallback, keeping routing correct when several apps share a process
+- `secure_session` is honoured through `setup_session`, and the FastAPI session lookup is hardened so it raises a clear error (rather than `KeyError`) under `python -O`
+- FastAPI OIDC views annotate their request parameter so Starlette injects the request, and the ASGI body-replay emits `http.disconnect` once the cached body is consumed
+
## [1.2.1] - 2026-06-17
### Fixed
diff --git a/README.md b/README.md
index ffe360e..2308d98 100644
--- a/README.md
+++ b/README.md
@@ -21,17 +21,20 @@ How this fork compares to upstream [`dash-auth`](https://github.com/plotly/dash-
| --- | :---: | :---: |
| Flask backend | ✅ | ✅ |
| Quart backend | ❌ | ✅ |
-| FastAPI backend | ❌ | 🚧 1 |
-| Custom backends | ❌ | ✅ 2 |
+| FastAPI backend | ❌ | ✅ |
+| Custom backends | ❌ | ✅ 1 |
| Protected / public callbacks | ✅ | ✅ |
-| Async callbacks | ❌ | ✅ 3 |
+| Async callbacks | ❌ | ✅ 2 |
| Authenticated WebSocket callbacks | ❌ | ✅ 3 |
-✅ supported · 🚧 on the roadmap · ❌ not supported
+✅ supported · ❌ not supported
+
+1 `detect_backend` resolves Flask/Quart/FastAPI automatically; any other server is supported by supplying your own `Backend` instance.
+
+2 Provided by the Quart and FastAPI backends.
+
+3 Provided by the Quart and FastAPI backends. WebSocket auth is a no-op on Flask, which has no WebSocket callback transport.
-1 A `dash-auth-async[fastapi]` extra is declared and a native FastAPI backend is on the roadmap. In the meantime you can support it by implementing the `Backend` ABC and passing `Auth(..., backend=MyBackend())`.
-2 `detect_backend` resolves Flask/Quart automatically; any other server is supported by supplying your own `Backend` instance.
-3 Provided by the Quart backend. WebSocket auth is a no-op on Flask, which has no WebSocket callback transport.
For local testing, install [uv](https://docs.astral.sh/uv/getting-started/installation/), then install the dev dependencies and run individual tests:
@@ -276,7 +279,22 @@ if __name__ == "__main__":
app.run(debug=True)
```
-### Quart (async) Backend
+### Async backends
+
+#### Known Limitations
+
+> **⚠️ WebSocket callbacks & auth:**
+> Do **not** enable `websocket_callbacks=True`
+> globally on an authenticated `use_pages` app. The global flag routes *every*
+> callback — including Dash's built-in page-routing callback — over the WebSocket,
+> which bypasses the HTTP `before_request` auth guard where the login challenge is
+> issued. Navigating to a protected page then hangs (the socket closes with `4401`
+> and reconnect-loops) instead of prompting for login — the prompt appears only after
+> a full page reload. Opt **individual** streaming callbacks into `websocket=True`
+> instead, so routing and login stay on HTTP.
+
+
+#### Quart (async) Backend
`dash-auth-async` supports [Dash's Quart backend](https://dash.plotly.com/) for fully async request handling.
Install the `quart` extra to pull in the required dependencies:
@@ -288,7 +306,7 @@ pip install dash-auth-async[quart]
Then pass `backend="quart"` when creating your Dash app. The auth setup is identical
to the Flask examples above — no code changes required beyond the backend flag.
-#### BasicAuth with Quart
+##### BasicAuth with Quart
```python
from dash import Dash
@@ -302,7 +320,7 @@ if __name__ == "__main__":
app.run(host="127.0.0.1", port=8050, debug=True)
```
-#### OIDCAuth with Quart
+##### OIDCAuth with Quart
```python
import os
@@ -332,6 +350,74 @@ if __name__ == "__main__":
> **Note:** The Quart backend requires Dash >= 4.2.0 and Python >= 3.10.
+#### FastAPI (async) Backend
+
+`dash-auth-async` supports [Dash's FastAPI backend](https://dash.plotly.com/) too.
+Install the `fastapi` extra to pull in the required dependencies:
+
+```
+pip install dash-auth-async[fastapi]
+```
+
+Then pass `backend="fastapi"` when creating your Dash app. `BasicAuth` and
+`OIDCAuth` work exactly as on Flask/Quart — no code changes beyond the backend flag.
+
+##### BasicAuth with FastAPI
+
+```python
+from dash import Dash
+from dash_auth_async import BasicAuth
+
+app = Dash(__name__, backend="fastapi")
+
+BasicAuth(
+ app,
+ {"admin": "admin", "viewer": "viewer123"},
+ secret_key="aStaticSecretKey!", # enables sessions (SessionMiddleware)
+)
+
+if __name__ == "__main__":
+ app.run(host="127.0.0.1", port=8050, debug=True)
+```
+
+##### OIDCAuth with FastAPI
+
+```python
+import os
+from dash import Dash, html
+from dash_auth_async import OIDCAuth
+
+app = Dash(__name__, backend="fastapi")
+
+app.layout = html.Div([
+ html.H2("OIDCAuth + FastAPI"),
+ html.A("Logout", href="/oidc/logout"),
+])
+
+auth = OIDCAuth(app, secret_key="aStaticSecretKey!")
+auth.register_provider(
+ "myidp",
+ client_id=os.environ["OIDC_CLIENT_ID"],
+ client_secret=os.environ["OIDC_CLIENT_SECRET"],
+ server_metadata_url=os.environ["OIDC_METADATA_URL"],
+ token_endpoint_auth_method="client_secret_post",
+ client_kwargs={"scope": "openid email profile"},
+)
+
+if __name__ == "__main__":
+ app.run(host="127.0.0.1", port=8050, debug=True)
+```
+
+Notes:
+
+- A `secret_key` installs Starlette's `SessionMiddleware` automatically. If you
+ add your own `SessionMiddleware`, `dash-auth-async` defers to it.
+- `Auth`/`OIDCAuth` must be constructed before the server starts serving
+ (Starlette forbids adding middleware after startup) — the normal usage pattern.
+- OIDC uses authlib's official `starlette_client`; no extra client module required.
+
+> The FastAPI backend requires Dash >= 4.2.0 and Python >= 3.10.
+
### User-group-based permissions
`dash_auth_async` provides a convenient way to secure parts of your app based on user groups.
diff --git a/dash_auth_async/backends.py b/dash_auth_async/backends.py
index 5278af7..6b2e50a 100644
--- a/dash_auth_async/backends.py
+++ b/dash_auth_async/backends.py
@@ -3,8 +3,10 @@
from __future__ import annotations
import inspect
+import re
from abc import ABC, abstractmethod
from collections.abc import Callable, MutableMapping
+from contextvars import ContextVar
from typing import Any
import flask
@@ -17,6 +19,29 @@
quart: Any = None
HAS_QUART = False
+try:
+ import fastapi
+ from starlette.middleware.sessions import SessionMiddleware
+ from starlette.requests import Request as StarletteRequest
+ from starlette.responses import (
+ HTMLResponse,
+ PlainTextResponse,
+ RedirectResponse,
+ Response as StarletteResponse,
+ )
+
+ HAS_FASTAPI = True
+except ImportError:
+ fastapi: Any = None
+ StarletteRequest: Any = None
+ SessionMiddleware: Any = None
+ HAS_FASTAPI = False
+
+
+# dash-auth-async owns its own request ContextVar, set in its own ASGI
+# middleware — independent of Dash's private get_current_request.
+_current_request_var: ContextVar = ContextVar("dash_auth_request", default=None)
+
class Backend(ABC):
"""Framework adapter isolating everything Flask/Quart-specific.
@@ -29,6 +54,11 @@ class Backend(ABC):
request: Any
session: MutableMapping
+ # Whether OIDC views must be async (awaited) on this backend. Lets
+ # OIDCAuth pick the sync vs async view set polymorphically instead of
+ # branching on isinstance(backend, ...).
+ is_async: bool = False
+
@abstractmethod
def has_request_context(self) -> bool:
"""Whether a request context is currently active."""
@@ -58,6 +88,149 @@ def url_for(self, endpoint: str, **values) -> str:
def redirect(self, location: str) -> Any:
"""Return a redirect response to the given location."""
+ # --- Operations that diverge on FastAPI; defaults reproduce the
+ # --- Flask/Quart behavior so those backends inherit unchanged. These
+ # --- are concrete interface defaults, not all of which use `self`, so
+ # --- PLR6301 (no-self-use) is suppressed where that applies.
+
+ def coerce_response(self, result: Any) -> Any: # noqa: PLR6301
+ """Convert an ``_authorize`` return into a real response.
+
+ Flask and Quart accept ``(body, status, headers)`` tuples, bare
+ strings, and framework responses natively, so the default is a
+ pass-through. FastAPI overrides this to build a Starlette response.
+
+ Returns:
+ The response value unchanged.
+ """
+ return result
+
+ def setup_session( # noqa: PLR6301
+ self, server, secret_key: str | None, secure_session: bool = False
+ ) -> None:
+ """Install session support on the server.
+
+ Flask/Quart store a ``secret_key`` attribute and harden the session
+ cookie via ``SESSION_COOKIE_*`` config. FastAPI overrides this to add
+ ``SessionMiddleware`` (and wires ``secure_session`` to ``https_only``).
+ """
+ if secret_key is not None:
+ server.secret_key = secret_key
+ if secure_session:
+ server.config["SESSION_COOKIE_SECURE"] = True
+ server.config["SESSION_COOKIE_HTTPONLY"] = True
+
+ def session_configured(self, server) -> bool: # noqa: PLR6301
+ """Whether the server can store a session.
+
+ Returns:
+ True if a session secret is configured.
+ """
+ return getattr(server, "secret_key", None) is not None
+
+ def current_host(self) -> str:
+ """Host (netloc) of the active request, for proxy host rewrites.
+
+ Returns:
+ The request host string.
+ """
+ return self.request.host
+
+ def current_path(self) -> str:
+ """Path of the active request.
+
+ Returns:
+ The request path string.
+ """
+ return self.request.path
+
+ def add_route( # noqa: PLR6301
+ self, server, rule: str, view_func, endpoint: str, methods
+ ) -> None:
+ """Register an OIDC route on the server."""
+ server.add_url_rule(
+ rule, endpoint=endpoint, view_func=view_func, methods=methods
+ )
+
+ def make_oauth(self, server) -> Any: # noqa: PLR6301
+ """Build the authlib OAuth registry for this backend.
+
+ Returns:
+ A flask_client ``OAuth`` registry bound to ``server``.
+ """
+ from authlib.integrations.flask_client import OAuth # noqa: PLC0415
+
+ return OAuth(server)
+
+ def get_oauth(self, server) -> Any: # noqa: PLR6301
+ """Retrieve the authlib OAuth registry stored by :meth:`make_oauth`.
+
+ Symmetric with ``make_oauth``: each backend knows where it put its own
+ registry, so the module-level ``get_oauth`` doesn't have to guess.
+
+ Returns:
+ The registry, or None if ``OIDCAuth`` has not run yet.
+ """
+ return getattr(server, "extensions", {}).get(
+ "authlib.integrations.flask_client"
+ )
+
+ def oauth_authorize_redirect( # noqa: PLR6301
+ self, client, redirect_uri: str, **kwargs
+ ) -> Any:
+ """Start the OAuth authorize-redirect on the authlib ``client``.
+
+ Encapsulates how each backend's authlib client is invoked: the Flask
+ client is synchronous and reads the request from a context global, so
+ the default just calls through. Quart/FastAPI override to await, and
+ FastAPI additionally passes the active request explicitly.
+
+ Returns:
+ The authorize-redirect response (awaitable on async backends).
+ """
+ return client.authorize_redirect(redirect_uri, **kwargs)
+
+ def oauth_authorize_access_token(self, client, **kwargs) -> Any: # noqa: PLR6301
+ """Exchange the OAuth callback for a token on the authlib ``client``.
+
+ Mirror of :meth:`oauth_authorize_redirect` for the callback leg.
+
+ Returns:
+ The token (awaitable on async backends).
+ """
+ return client.authorize_access_token(**kwargs)
+
+ def store_config(self, server, key: str, value: Any) -> None: # noqa: PLR6301
+ """Stash an app-scoped config value (public routes/callbacks).
+
+ Flask/Quart expose a dict-like ``server.config``. FastAPI has no
+ such attribute and overrides this to use ``server.state``.
+ """
+ server.config[key] = value
+
+ def read_config( # noqa: PLR6301
+ self, server, key: str, default: Any = None
+ ) -> Any:
+ """Read an app-scoped config value set by :meth:`store_config`.
+
+ Returns:
+ The stored value, or ``default`` when unset.
+ """
+ return server.config.get(key, default)
+
+ def ws_identity(self, ws) -> tuple[Any, dict | None]:
+ """Resolve ``(owning_server, session_user)`` for a WS callback_request.
+
+ WS-capable backends override this. The owning server is the key into
+ ``_AUTH_BY_SERVER``; the user is ``session["user"]`` or ``None``.
+ Flask has no WebSocket transport, so the default raises -- the WS hook's
+ fail-closed boundary turns that into a rejection if it is ever reached.
+
+ Raises:
+ NotImplementedError: on backends without WebSocket support.
+ """
+ raise NotImplementedError
+
class FlaskBackend(Backend):
"""Backend adapter for a Flask server."""
@@ -126,6 +299,8 @@ def redirect(self, location: str) -> Any: # noqa: PLR6301
class QuartBackend(Backend):
"""Backend adapter for a Quart (async) server."""
+ is_async = True
+
def __init__(self) -> None:
"""Create the Quart backend, requiring the optional ``quart`` extra.
@@ -171,8 +346,16 @@ def register_auth_hook(self, server, needs_body, decide) -> None: # noqa: PLR63
"""Register the before-request auth hook on a Quart server.
Awaits both the request body and the (possibly coroutine) decision so
- async auth logic is preserved.
+ async auth logic is preserved. Also mints the browser-stable
+ ``dac_client`` cookie on first contact and, when a browser authenticates,
+ retires its stale pre-login WebSocket so the renderer reconnects
+ authenticated (see websocket_auth) -- the Quart half of Design A.
"""
+ from .websocket_auth import ( # noqa: PLC0415 — avoid import cycle
+ WS_CLIENT_COOKIE,
+ close_anonymous_ws,
+ mint_ws_client_id,
+ )
@server.before_request
async def before_request_auth():
@@ -183,8 +366,30 @@ async def before_request_auth():
)
result = decide(quart.request.path, body)
if inspect.isawaitable(result):
- return await result
- return result
+ result = await result
+ if result is not None:
+ return result
+
+ # Authorized. If this browser is now logged in, retire any stale
+ # anonymous socket it opened before login so the renderer reconnects
+ # authenticated before the first click.
+ if quart.session.get("user") is not None:
+ close_anonymous_ws(quart.request.cookies.get(WS_CLIENT_COOKIE))
+ return None
+
+ @server.after_request
+ def set_client_cookie(response):
+ # Minted on every response that lacks it -- including the BasicAuth
+ # 401 challenge -- so it predates a later pre-login WS handshake.
+ if quart.request.cookies.get(WS_CLIENT_COOKIE) is None:
+ response.set_cookie(
+ WS_CLIENT_COOKIE,
+ mint_ws_client_id(),
+ httponly=True,
+ samesite="Lax",
+ path="/",
+ )
+ return response
def url_for(self, endpoint: str, **values) -> str: # noqa: PLR6301
"""Build a URL for a Quart endpoint.
@@ -202,12 +407,428 @@ def redirect(self, location: str) -> Any: # noqa: PLR6301
"""
return quart.redirect(location)
+ def make_oauth(self, server) -> Any: # noqa: PLR6301
+ """Build the authlib OAuth registry for the Quart backend.
+
+ Returns:
+ A custom Quart ``OAuth`` registry bound to ``server``.
+ """
+ # Imported lazily so flask-only installs never import quart/httpx.
+ from dash_auth_async import quart_client # noqa: PLC0415
+
+ return quart_client.OAuth(server)
+
+ def get_oauth(self, server) -> Any: # noqa: PLR6301
+ """Retrieve the Quart OAuth registry stored by :meth:`make_oauth`.
+
+ Returns:
+ The registry, or None if ``OIDCAuth`` has not run yet.
+ """
+ return getattr(server, "extensions", {}).get(
+ "authlib.integrations.quart_client"
+ )
+
+ async def oauth_authorize_redirect( # noqa: PLR6301
+ self, client, redirect_uri: str, **kwargs
+ ) -> Any:
+ """Await the Quart authlib client's authorize-redirect.
+
+ Returns:
+ The authorize-redirect response.
+ """
+ return await client.authorize_redirect(redirect_uri, **kwargs)
+
+ async def oauth_authorize_access_token(self, client, **kwargs) -> Any: # noqa: PLR6301
+ """Await the Quart authlib client's token exchange.
+
+ Returns:
+ The token.
+ """
+ return await client.authorize_access_token(**kwargs)
+
+ def ws_identity(self, ws) -> tuple[Any, dict | None]: # noqa: PLR6301
+ """Resolve the Quart app and session user for a WS callback_request.
+
+ Uses Quart's context globals (correct when several apps share a
+ process); ``ws`` is unused on this backend.
+
+ Returns:
+ ``(quart_app, session["user"] or None)``.
+ """
+ import quart # noqa: PLC0415 — quart is an optional dependency
+
+ # ``quart.current_app`` is a proxy; ``_get_current_object`` unwraps it to
+ # the real Quart app (the key in ``_AUTH_BY_SERVER``). Go through
+ # ``getattr`` because the attribute is absent from the proxy's type stub.
+ app = getattr(quart.current_app, "_get_current_object")()
+ return app, quart.session.get("user")
+
+
+class FastAPIBackend(Backend):
+ """Adapter for Dash's FastAPI backend (Dash 4.2+).
+
+ Unlike Flask/Quart there is no global request/session proxy: Starlette
+ passes the request explicitly. This backend resolves the active request
+ from a ContextVar set by its own ASGI auth middleware, and reads the
+ session off ``request.session`` (populated by SessionMiddleware).
+ """
+
+ is_async = True
+
+ def __init__(self) -> None:
+ """Create the FastAPI backend, requiring the optional ``fastapi`` extra.
+
+ Raises:
+ ImportError: if FastAPI is not installed.
+ """
+ if not HAS_FASTAPI:
+ raise ImportError(
+ "FastAPI is not installed. Please install it with "
+ "`pip install dash-auth-async[fastapi]` to use the FastAPI backend."
+ )
+
+ @property
+ def request(self) -> Any:
+ """The active Starlette request, resolved from the ContextVar.
+
+ Returns:
+ The current request, or None outside a request context.
+ """
+ return _current_request_var.get()
+
+ @property
+ def session(self) -> MutableMapping:
+ """The session mapping off the active request.
+
+ Returns:
+ The Starlette session mapping.
+
+ Raises:
+ RuntimeError: if SessionMiddleware is not installed.
+ """
+ # Starlette signals "no SessionMiddleware" via a bare `assert
+ # "session" in scope`, which `python -O` strips — the next line then
+ # raises KeyError instead. Check the scope directly so the
+ # RuntimeError translation holds under -O too, keeping the existing
+ # `except RuntimeError` guards working identically to the Flask path.
+ request = self.request
+ if "session" not in getattr(request, "scope", {}):
+ raise RuntimeError("Session is not available. Have you set a secret key?")
+ return request.session
+
+ def has_request_context(self) -> bool: # noqa: PLR6301
+ """Whether a request is currently bound to the ContextVar.
+
+ Returns:
+ True if a request context is active.
+ """
+ return _current_request_var.get() is not None
+
+ def url_for(self, endpoint: str, **values) -> str:
+ """Build an absolute URL for a Starlette endpoint.
+
+ Maps the Flask-style ``_external``/``_scheme`` kwargs used by
+ ``OIDCAuth._create_redirect_uri`` onto Starlette's ``url_for``.
+
+ Returns:
+ The URL string for ``endpoint``.
+ """
+ values.pop("_external", None) # Starlette url_for is always absolute
+ scheme = values.pop("_scheme", None)
+ url = self.request.url_for(endpoint, **values)
+ if scheme:
+ url = url.replace(scheme=scheme)
+ return str(url)
+
+ def redirect(self, location: str) -> Any: # noqa: PLR6301
+ """Build a Starlette redirect response to ``location``.
+
+ Returns:
+ A 302 ``RedirectResponse``.
+ """
+ return RedirectResponse(location, status_code=302)
+
+ def current_host(self) -> str:
+ """Host (netloc) of the active request.
+
+ Returns:
+ The request netloc string.
+ """
+ return self.request.url.netloc
+
+ def current_path(self) -> str:
+ """Path of the active Starlette request.
+
+ Returns:
+ The request path string.
+ """
+ return self.request.url.path
+
+ def coerce_response(self, result: Any) -> Any: # noqa: PLR6301
+ """Build a Starlette response from an ``_authorize``/view return value.
+
+ This is the single coercion boundary for the FastAPI path: every OIDC
+ view and the auth middleware funnel their return value through here, so
+ the framework-response knowledge lives in exactly one place.
+
+ A bare ``str`` becomes an ``HTMLResponse`` (matching Flask/Quart, which
+ render returned strings as ``text/html`` — e.g. the OIDC logout page),
+ while ``(body, status[, headers])`` tuples become a ``PlainTextResponse``
+ carrying the status/headers (e.g. the Basic-auth 401 challenge).
+
+ Returns:
+ A Starlette response (passthrough if already one).
+ """
+ if isinstance(result, StarletteResponse):
+ return result
+ if isinstance(result, tuple):
+ body, *rest = result
+ status = rest[0] if rest else 200
+ headers = rest[1] if len(rest) > 1 else None
+ return PlainTextResponse(body, status_code=status, headers=headers)
+ if isinstance(result, str):
+ return HTMLResponse(result)
+ return PlainTextResponse(str(result))
+
+ @staticmethod
+ def _has_session_middleware(server) -> bool:
+ return any(
+ getattr(m, "cls", None) is SessionMiddleware
+ for m in getattr(server, "user_middleware", [])
+ )
+
+ def setup_session(
+ self, server, secret_key: str | None, secure_session: bool = False
+ ) -> None:
+ """Install Starlette ``SessionMiddleware`` from ``secret_key``.
+
+ ``secure_session`` is wired to ``https_only`` so the parity with
+ Flask/Quart's ``SESSION_COOKIE_SECURE`` is honored rather than
+ silently dropped. Starlette always sets ``HttpOnly``. Defers to a
+ user-installed ``SessionMiddleware`` (opt-out/override) and is
+ idempotent — never adds a second instance.
+ """
+ if secret_key is None:
+ return
+ if self._has_session_middleware(server):
+ return
+ server.add_middleware(
+ SessionMiddleware, secret_key=secret_key, https_only=secure_session
+ )
+
+ def session_configured(self, server) -> bool:
+ """Whether a ``SessionMiddleware`` is installed on the server.
+
+ Returns:
+ True if session storage is available.
+ """
+ return self._has_session_middleware(server)
+
+ def store_config(self, server, key: str, value: Any) -> None: # noqa: PLR6301
+ """Stash an app-scoped config value on the FastAPI ``server.state``."""
+ setattr(server.state, key, value)
+
+ def read_config( # noqa: PLR6301
+ self, server, key: str, default: Any = None
+ ) -> Any:
+ """Read an app-scoped config value off the FastAPI ``server.state``.
+
+ Returns:
+ The stored value, or ``default`` when unset.
+ """
+ return getattr(server.state, key, default)
+
+ def add_route( # noqa: PLR6301
+ self, server, rule: str, view_func, endpoint: str, methods
+ ) -> None:
+ """Register an OIDC route, translating Flask ```` to ``{idp}``."""
+ fastapi_rule = re.sub(r"<([^>]+)>", r"{\1}", rule)
+ server.add_api_route(
+ fastapi_rule,
+ view_func,
+ methods=methods,
+ name=endpoint,
+ include_in_schema=False,
+ )
+
+ def make_oauth(self, server) -> Any: # noqa: PLR6301
+ """Build authlib's Starlette OAuth registry, stashed on ``server.state``.
+
+ Returns:
+ A ``starlette_client`` ``OAuth`` registry.
+ """
+ from authlib.integrations.starlette_client import ( # noqa: PLC0415
+ OAuth as StarletteOAuth,
+ )
+
+ oauth = StarletteOAuth()
+ # The Starlette registry doesn't attach to app.extensions, so stash
+ # it where get_oauth can find it.
+ server.state.dash_auth_oauth = oauth
+ return oauth
+
+ def get_oauth(self, server) -> Any: # noqa: PLR6301
+ """Retrieve the Starlette OAuth registry stashed on ``server.state``.
+
+ Returns:
+ The registry, or None if ``OIDCAuth`` has not run yet.
+ """
+ return getattr(getattr(server, "state", None), "dash_auth_oauth", None)
+
+ async def oauth_authorize_redirect(
+ self, client, redirect_uri: str, **kwargs
+ ) -> Any:
+ """Await the Starlette authlib client's authorize-redirect.
+
+ The Starlette OAuth client takes the request explicitly (no context
+ global), so the active request is resolved from the ContextVar and
+ passed through.
+
+ Returns:
+ The authorize-redirect response.
+ """
+ return await client.authorize_redirect(self.request, redirect_uri, **kwargs)
+
+ async def oauth_authorize_access_token(self, client, **kwargs) -> Any:
+ """Await the Starlette authlib client's token exchange.
+
+ Passes the ContextVar-resolved request, as the Starlette client
+ requires.
+
+ Returns:
+ The token.
+ """
+ return await client.authorize_access_token(self.request, **kwargs)
+
+ def ws_identity(self, ws) -> tuple[Any, dict | None]: # noqa: PLR6301
+ """Resolve the FastAPI app and session user for a WS callback_request.
+
+ Reads from the Starlette ``WebSocket``: ``ws.app`` is the owning server
+ and ``ws.session`` carries ``session["user"]`` (SessionMiddleware runs on
+ websocket scopes, so the handshake cookie is available). The
+ ``"session" in ws.scope`` guard mirrors the ``session`` property: with no
+ SessionMiddleware installed the user is ``None`` (fail-closed for
+ protected callbacks, still allows public) rather than raising.
+
+ Returns:
+ ``(fastapi_app, session["user"] or None)``.
+ """
+ user = ws.session.get("user") if "session" in ws.scope else None
+ return ws.app, user
+
+ def register_auth_hook(self, server, needs_body, decide) -> None:
+ """Register the before-request auth hook as pure-ASGI middleware.
+
+ Pure ASGI (not ``BaseHTTPMiddleware``) so the request ContextVar set
+ here is visible inside the Dash callback, which runs in the same
+ task/context via the inner DashMiddleware's ``copy_context()``.
+ """
+ backend = self
+
+ class _AuthMiddleware:
+ def __init__(self, app) -> None:
+ self.app = app
+
+ async def __call__(self, scope, receive, send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ request = StarletteRequest(scope, receive)
+ token = _current_request_var.set(request)
+
+ # Browser-stable client id, minted on first contact so it is
+ # already present at a later pre-login WS handshake. This ties an
+ # anonymous socket back to the browser that authenticates here,
+ # letting login retire the stale socket (see websocket_auth).
+ from .websocket_auth import ( # noqa: PLC0415 — avoid import cycle
+ WS_CLIENT_COOKIE,
+ close_anonymous_ws,
+ mint_ws_client_id,
+ )
+
+ client_id = request.cookies.get(WS_CLIENT_COOKIE)
+ set_client_cookie = client_id is None
+ if set_client_cookie:
+ client_id = mint_ws_client_id()
+
+ async def send(message, _send=send):
+ if set_client_cookie and message["type"] == "http.response.start":
+ message.setdefault("headers", []).append(
+ (
+ b"set-cookie",
+ f"{WS_CLIENT_COOKIE}={client_id}; Path=/; "
+ f"HttpOnly; SameSite=Lax".encode("latin-1"),
+ )
+ )
+ await _send(message)
+
+ try:
+ body = None
+ downstream_receive = receive
+ if needs_body(request.url.path):
+ # Consuming the body drains the receive stream; cache
+ # the bytes and replay them so DashMiddleware (inner)
+ # can re-parse the callback JSON.
+ raw = await request.body()
+ body_replayed = False
+
+ # Must be a coroutine to satisfy the ASGI `receive`
+ # interface, even though this replay never awaits.
+ async def downstream_receive(): # noqa: RUF029
+ nonlocal body_replayed
+ if body_replayed:
+ # Body already delivered; further reads see a
+ # disconnect, per the ASGI contract — not the
+ # same body event replayed forever (which would
+ # spin an app that polls receive() to detect
+ # client disconnect).
+ return {"type": "http.disconnect"}
+ body_replayed = True
+ return {
+ "type": "http.request",
+ "body": raw,
+ "more_body": False,
+ }
+
+ try:
+ body = await request.json()
+ except Exception: # unparseable == no body
+ body = None
+
+ result = decide(request.url.path, body)
+ if inspect.isawaitable(result):
+ result = await result
+
+ if result is not None:
+ response = backend.coerce_response(result)
+ await response(scope, receive, send)
+ return
+
+ # Authorized. If this browser is now logged in, retire any
+ # stale anonymous sockets it opened before login so the
+ # renderer reconnects authenticated before the first click.
+ try:
+ logged_in = backend.session.get("user") is not None
+ except RuntimeError:
+ logged_in = False # no SessionMiddleware -> nothing to retire
+ if logged_in:
+ close_anonymous_ws(client_id)
+
+ await self.app(scope, downstream_receive, send)
+ finally:
+ _current_request_var.reset(token)
+
+ server.add_middleware(_AuthMiddleware)
+
def detect_backend(server: Any) -> Backend:
- """Return the matching backend for a Flask or Quart server."""
+ """Return the matching backend for a Flask, Quart, or FastAPI server."""
if quart is not None:
if isinstance(server, quart.Quart):
return QuartBackend()
+ if HAS_FASTAPI and isinstance(server, fastapi.FastAPI):
+ return FastAPIBackend()
if isinstance(server, flask.Flask):
return FlaskBackend()
@@ -220,7 +841,6 @@ def detect_backend(server: Any) -> Backend:
# One backend per process, matching how Dash apps are deployed.
_active_backend: Backend | None = None
-_DEFAULT_BACKEND = FlaskBackend()
def set_active_backend(backend: Backend) -> None:
@@ -234,5 +854,16 @@ def set_active_backend(backend: Backend) -> None:
def get_active_backend() -> Backend:
- """Return the active backend, defaulting to Flask when none is set."""
- return _active_backend if _active_backend is not None else _DEFAULT_BACKEND
+ """Return the active backend registered by ``Auth.__init__``.
+
+ Falls back to a Flask backend for the legacy Flask-only path where no
+ ``Auth`` has registered one yet. The fallback is constructed lazily on
+ first use rather than at import, so Flask isn't cemented as the default
+ at module load. Note this still assumes a single backend per process: in
+ a non-Flask process the active backend must be set (which ``Auth.__init__``
+ does) before any request-context helper runs.
+ """
+ global _active_backend # noqa: PLW0603 — one backend per process, by design
+ if _active_backend is None:
+ _active_backend = FlaskBackend()
+ return _active_backend
diff --git a/dash_auth_async/basic_auth.py b/dash_auth_async/basic_auth.py
index 6bbf275..2a0077a 100644
--- a/dash_auth_async/basic_auth.py
+++ b/dash_auth_async/basic_auth.py
@@ -63,8 +63,6 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
else:
self._user_groups_dict = None
self._user_groups_func = user_groups # Callable or None after dict excluded
- if secret_key is not None:
- app.server.secret_key = secret_key
if self._auth_func is not None:
if username_password_list is not None:
@@ -86,6 +84,12 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
)
super().__init__(app, public_routes=public_routes)
+ # After super().__init__: self.backend now exists, and the auth
+ # middleware is registered, so SessionMiddleware (added here) lands
+ # outermost on FastAPI — making request.session available to both
+ # the auth layer and the Dash callback.
+ self.backend.setup_session(app.server, secret_key)
+
def is_authorized(self):
"""Return whether the request carries valid Basic credentials.
diff --git a/dash_auth_async/group_protection.py b/dash_auth_async/group_protection.py
index 73f9b11..c56ddea 100644
--- a/dash_auth_async/group_protection.py
+++ b/dash_auth_async/group_protection.py
@@ -39,7 +39,12 @@ def _current_user() -> dict | None:
backend = get_active_backend()
if backend.has_request_context():
# Normal HTTP path: read the user from the framework session.
- return backend.session.get("user")
+ try:
+ return backend.session.get("user")
+ except RuntimeError:
+ # Session unavailable (e.g. FastAPI without SessionMiddleware):
+ # treat as not authenticated rather than crashing.
+ return None
# WebSocket worker path: no framework context here, so read the user the
# websocket_message hook stashed for this dispatch.
return _WS_AUTH_USER.get()
diff --git a/dash_auth_async/oidc_auth.py b/dash_auth_async/oidc_auth.py
index ee0c420..959c273 100644
--- a/dash_auth_async/oidc_auth.py
+++ b/dash_auth_async/oidc_auth.py
@@ -14,7 +14,7 @@
from dash_auth_async.auth import Auth
from dash_auth_async.public_routes import get_url_base
-from .backends import QuartBackend
+from .backends import detect_backend
if TYPE_CHECKING:
from authlib.integrations.flask_client.apps import (
@@ -22,8 +22,7 @@
FlaskOAuth2App,
)
- from dash_auth_async.quart_client import OAuth as QuartOAuth
- from dash_auth_async.quart_client import QuartOAuth2App
+ from dash_auth_async.quart_client import OAuth as QuartOAuth, QuartOAuth2App
class OIDCAuth(Auth):
@@ -92,9 +91,10 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
Page seen by the user after logging out,
by default None which will default to a simple logged out message
secure_session: bool, optional
- Whether to ensure the session is secure, setting the flasck config
- SESSION_COOKIE_SECURE and SESSION_COOKIE_HTTPONLY to True,
- by default False
+ Whether to restrict the session cookie to HTTPS, by default False.
+ On Flask/Quart this sets SESSION_COOKIE_SECURE and
+ SESSION_COOKIE_HTTPONLY; on FastAPI it sets the Starlette
+ SessionMiddleware ``https_only`` flag (HttpOnly is always on).
Raises:
RuntimeError: if ``app.server.secret_key`` is not defined.
@@ -117,10 +117,9 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
self.idp_selection_route = idp_selection_route
self.logout_page = logout_page
- if secret_key is not None:
- app.server.secret_key = secret_key
+ self.backend.setup_session(app.server, secret_key, secure_session)
- if app.server.secret_key is None:
+ if not self.backend.session_configured(app.server):
raise RuntimeError("""
app.server.secret_key is missing.
Generate a secret key in your Python session
@@ -136,18 +135,7 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
that key in your code/via a secret.
""")
- if secure_session:
- app.server.config["SESSION_COOKIE_SECURE"] = True
- app.server.config["SESSION_COOKIE_HTTPONLY"] = True
-
- if isinstance(self.backend, QuartBackend):
- # Imported lazily so flask-only installs never import
- # quart/httpx (quart_client raises ImportError without them).
- from dash_auth_async import quart_client # noqa: PLC0415
-
- self.oauth: OAuth | quart_client.OAuth = quart_client.OAuth(app.server)
- else:
- self.oauth = OAuth(app.server)
+ self.oauth = self.backend.make_oauth(app.server)
# Check that the login and callback rules have an placeholder
if not re.findall(r"/(?=/|$)", login_route):
@@ -155,7 +143,7 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
if not re.findall(r"/(?=/|$)", callback_route):
raise Exception("The callback route must contain a placeholder.")
- if isinstance(self.backend, QuartBackend):
+ if self.backend.is_async:
login_view = self._login_request_async
logout_view = self._logout_async
callback_view = self._callback_async
@@ -164,23 +152,14 @@ def __init__( # noqa: PLR0913, PLR0917 — configuration constructor
logout_view = self.logout
callback_view = self.callback
- app.server.add_url_rule(
- login_route,
- endpoint="oidc_login",
- view_func=login_view,
- methods=["GET"],
+ self.backend.add_route(
+ app.server, login_route, login_view, "oidc_login", ["GET"]
)
- app.server.add_url_rule(
- logout_route,
- endpoint="oidc_logout",
- view_func=logout_view,
- methods=["GET"],
+ self.backend.add_route(
+ app.server, logout_route, logout_view, "oidc_logout", ["GET"]
)
- app.server.add_url_rule(
- callback_route,
- endpoint="oidc_callback",
- view_func=callback_view,
- methods=["GET"],
+ self.backend.add_route(
+ app.server, callback_route, callback_view, "oidc_callback", ["GET"]
)
def register_provider(self, idp_name: str, **kwargs):
@@ -283,7 +262,7 @@ def _create_redirect_uri(self, idp: str):
)
host = self.request.headers.get("X-Forwarded-Host")
if host:
- redirect_uri = redirect_uri.replace(self.request.host, host, 1)
+ redirect_uri = redirect_uri.replace(self.backend.current_host(), host, 1)
return redirect_uri
def login_request(self, idp: str | None = None):
@@ -298,7 +277,8 @@ def login_request(self, idp: str | None = None):
"""
# `idp` can be none here as login_request is called
# without arguments in the before_request hook
- if isinstance(self.backend, QuartBackend):
+ if self.backend.is_async:
+ # Returns a coroutine; the async before-request hook / route awaits it.
return self._login_request_async(idp)
idp, response = self._resolve_idp(idp)
@@ -308,28 +288,34 @@ def login_request(self, idp: str | None = None):
redirect_uri = self._create_redirect_uri(idp)
oauth_client = self.get_oauth_client(idp)
oauth_kwargs = self.get_oauth_kwargs(idp)
- return oauth_client.authorize_redirect(
+ return self.backend.oauth_authorize_redirect(
+ oauth_client,
redirect_uri,
**oauth_kwargs.get("authorize_redirect_kwargs", {}),
)
async def _login_request_async(self, idp: str | None = None):
- """Async login view for the Quart path.
+ """Async login view shared by the Quart and FastAPI paths.
+
+ Backend-agnostic: the backend supplies the (possibly request-injecting)
+ authlib call and coerces the result to a framework response.
Returns:
The authorize-redirect response.
"""
idp, response = self._resolve_idp(idp)
if response is not None:
- return response
+ return self.backend.coerce_response(response)
redirect_uri = self._create_redirect_uri(idp)
oauth_client = self.get_oauth_client(idp)
oauth_kwargs = self.get_oauth_kwargs(idp)
- return await oauth_client.authorize_redirect(
+ result = await self.backend.oauth_authorize_redirect(
+ oauth_client,
redirect_uri,
**oauth_kwargs.get("authorize_redirect_kwargs", {}),
)
+ return self.backend.coerce_response(result)
def logout(self): # pylint: disable=C0116
"""Logout the user.
@@ -352,12 +338,15 @@ def logout(self): # pylint: disable=C0116
return page
async def _logout_async(self):
- """Async logout view for the Quart path; the body is sync.
+ """Async logout view shared by the Quart and FastAPI paths.
+
+ The body is sync; the backend coerces the HTML page to a framework
+ response (passthrough on Quart, ``HTMLResponse`` on FastAPI).
Returns:
- The logged-out page content.
+ The logged-out page response.
"""
- return self.logout()
+ return self.backend.coerce_response(self.logout())
def callback(self, idp: str): # pylint: disable=C0116
"""Handle the OIDC dance and post-login actions.
@@ -371,7 +360,8 @@ def callback(self, idp: str): # pylint: disable=C0116
oauth_client = self.get_oauth_client(idp)
oauth_kwargs = self.get_oauth_kwargs(idp)
try:
- token = oauth_client.authorize_access_token(
+ token = self.backend.oauth_authorize_access_token(
+ oauth_client,
**oauth_kwargs.get("authorize_token_kwargs", {}),
)
except OAuthError as err:
@@ -381,25 +371,31 @@ def callback(self, idp: str): # pylint: disable=C0116
return self.after_logged_in(user, idp, token)
async def _callback_async(self, idp: str):
- """Async OIDC callback view for the Quart path.
+ """Async OIDC callback view shared by the Quart and FastAPI paths.
+
+ Backend-agnostic: the backend supplies the (possibly request-injecting)
+ token exchange and coerces every return through the single boundary.
Returns:
- The post-login redirect, or an error tuple on failure.
+ The post-login redirect, or an error response on failure.
"""
if idp not in self.oauth._registry:
- return f"'{idp}' is not a valid registered idp", 400
+ return self.backend.coerce_response(
+ (f"'{idp}' is not a valid registered idp", 400)
+ )
oauth_client = self.get_oauth_client(idp)
oauth_kwargs = self.get_oauth_kwargs(idp)
try:
- token = await oauth_client.authorize_access_token(
+ token = await self.backend.oauth_authorize_access_token(
+ oauth_client,
**oauth_kwargs.get("authorize_token_kwargs", {}),
)
except OAuthError as err:
- return str(err), 401
+ return self.backend.coerce_response((str(err), 401))
user = token.get("userinfo")
- return self.after_logged_in(user, idp, token)
+ return self.backend.coerce_response(self.after_logged_in(user, idp, token))
def after_logged_in(self, user: dict | None, idp: str, token: dict):
"""Run post-login actions after successful OIDC authentication.
@@ -444,7 +440,7 @@ def is_authorized(self): # pylint: disable=C0116
if x
]
).bind("")
- return map_adapter.test(self.request.path) or "user" in self.session
+ return map_adapter.test(self.backend.current_path()) or "user" in self.session
def get_oauth(app: dash.Dash | None = None) -> "OAuth | QuartOAuth":
@@ -463,14 +459,12 @@ def get_oauth(app: dash.Dash | None = None) -> "OAuth | QuartOAuth":
if app is None:
app = dash.get_app()
- extensions = getattr(app.server, "extensions", {})
- for extension_key in (
- "authlib.integrations.flask_client",
- "authlib.integrations.quart_client",
- ):
- oauth = extensions.get(extension_key)
- if oauth is not None:
- return oauth
+ # Retrieval is symmetric with storage: each backend knows where its own
+ # make_oauth stashed the registry (extensions vs server.state), so there's
+ # no need to probe every framework's location here.
+ oauth = detect_backend(app.server).get_oauth(app.server)
+ if oauth is not None:
+ return oauth
raise RuntimeError(
"OAuth object is not yet defined. `OIDCAuth(app, **kwargs)` needs "
diff --git a/dash_auth_async/public_routes.py b/dash_auth_async/public_routes.py
index 29363da..cf6a33a 100644
--- a/dash_auth_async/public_routes.py
+++ b/dash_auth_async/public_routes.py
@@ -9,6 +9,8 @@
from dash._callback import GLOBAL_CALLBACK_MAP # noqa: PLC2701
from werkzeug.routing import Map, MapAdapter, Rule
+from .backends import detect_backend
+
DASH_PUBLIC_ASSETS_EXTENSIONS = "js,css"
BASE_PUBLIC_ROUTES = [
f"/assets/.{ext}"
@@ -80,7 +82,7 @@ def add_public_routes(app: Dash, routes: list):
full_route = url_base.rstrip("/") + full_route
public_routes.map.add(Rule(full_route))
- app.server.config[PUBLIC_ROUTES] = public_routes
+ detect_backend(app.server).store_config(app.server, PUBLIC_ROUTES, public_routes)
def public_callback(*callback_args, **callback_kwargs):
@@ -107,10 +109,12 @@ def decorator(func):
)
try:
app = get_app()
- app.server.config[PUBLIC_CALLBACKS] = [
- *get_public_callbacks(app),
- callback_id,
- ]
+ backend = detect_backend(app.server)
+ backend.store_config(
+ app.server,
+ PUBLIC_CALLBACKS,
+ [*backend.read_config(app.server, PUBLIC_CALLBACKS, []), callback_id],
+ )
except Exception:
print(
"Could not set up the public callback as the Dash object "
@@ -131,7 +135,9 @@ def get_public_routes(app: Dash) -> MapAdapter:
Returns:
The MapAdapter holding the app's registered public routes.
"""
- return app.server.config.get(PUBLIC_ROUTES, Map([]).bind(""))
+ return detect_backend(app.server).read_config(
+ app.server, PUBLIC_ROUTES, Map([]).bind("")
+ )
def get_public_callbacks(app: Dash) -> list:
@@ -140,4 +146,4 @@ def get_public_callbacks(app: Dash) -> list:
Returns:
The list of whitelisted public callback ids.
"""
- return app.server.config.get(PUBLIC_CALLBACKS, [])
+ return detect_backend(app.server).read_config(app.server, PUBLIC_CALLBACKS, [])
diff --git a/dash_auth_async/version.py b/dash_auth_async/version.py
index e461518..13e7faf 100644
--- a/dash_auth_async/version.py
+++ b/dash_auth_async/version.py
@@ -1,3 +1,3 @@
"""Single source of truth for the package version."""
-__version__ = "1.2.1"
+__version__ = "1.3.0"
diff --git a/dash_auth_async/websocket_auth.py b/dash_auth_async/websocket_auth.py
index de8fdfa..6124ed4 100644
--- a/dash_auth_async/websocket_auth.py
+++ b/dash_auth_async/websocket_auth.py
@@ -8,13 +8,23 @@
from __future__ import annotations
+import asyncio
import contextvars
+import secrets
import threading
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
from typing import Any
from weakref import WeakKeyDictionary
+from .backends import get_active_backend
+
+# Name of the pre-login, browser-stable cookie that ties an anonymous WebSocket
+# back to the browser that later authenticates over HTTP. Minted on first
+# contact by the auth hook (see the backends' before-request middleware) so it
+# is already present at a pre-login WS handshake.
+WS_CLIENT_COOKIE = "dac_client"
+
# The authenticated user (session["user"] dict) for the callback currently being
# dispatched over a WebSocket. Set by the websocket_message hook in the WS
# context and propagated into Dash's callback worker by the context-copying
@@ -47,6 +57,83 @@ def submit(self, fn, /, *args, **kwargs):
_hook_registered = False
+class _WSEntry:
+ """A registered anonymous WebSocket plus the loop it runs on."""
+
+ __slots__ = ("loop", "ws")
+
+ def __init__(self, ws: Any, loop: Any) -> None:
+ self.ws = ws
+ self.loop = loop
+
+
+# client-id cookie -> the browser's current anonymous (pre-login) socket.
+#
+# Only *anonymous* sockets are tracked: a login retires those and only those, so
+# an already-authenticated socket would only ever bloat the map. And only one
+# per browser -- a browser drives a single shared SharedWorker socket, so a
+# reconnect replaces the prior entry rather than stacking. Together these bound
+# the map to (at most) one entry per distinct pre-login browser.
+_WS_BY_CLIENT: dict[str, _WSEntry] = {}
+_ws_registry_lock = threading.Lock()
+
+
+def mint_ws_client_id() -> str:
+ """Return a fresh, opaque browser-stable client id for the cookie."""
+ return secrets.token_urlsafe(16)
+
+
+def register_anonymous_ws(client_id: str | None, ws: Any, loop: Any) -> None:
+ """Track a browser's current pre-login socket, replacing any stale prior.
+
+ No-op without a client id (e.g. a backend that never minted the cookie):
+ such a socket can't be correlated to a later login and falls back to the
+ reactive 4401-then-reconnect path.
+ """
+ if not client_id:
+ return
+ with _ws_registry_lock:
+ _WS_BY_CLIENT[client_id] = _WSEntry(ws, loop)
+
+
+def close_anonymous_ws(client_id: str | None) -> None:
+ """Close the browser's pre-login socket so the renderer reconnects authed.
+
+ Called when ``client_id`` authenticates over HTTP. The close code (``4408``)
+ is outside the renderer's no-reconnect set (``1000``/``4001``), so the
+ SharedWorker dials a fresh handshake that carries the now-present session
+ cookie -- before the user's first click rather than after a sacrificed one.
+ Popping the entry also deregisters it, keeping the map from leaking.
+ """
+ if not client_id:
+ return
+ with _ws_registry_lock:
+ entry = _WS_BY_CLIENT.pop(client_id, None)
+ if entry is None:
+ return
+ _schedule_ws_close(entry.loop, entry.ws)
+
+
+def _schedule_ws_close(loop: Any, ws: Any) -> None:
+ """Schedule ``ws.close`` onto the socket's own event loop, thread-safely.
+
+ The login runs in a separate HTTP task (and possibly thread); a socket can
+ only be closed from its own loop, so we hop onto it via
+ ``call_soon_threadsafe`` rather than awaiting the close inline.
+ """
+
+ async def _safe_close() -> None:
+ try:
+ await ws.close(code=4408)
+ except Exception: # pylint: disable=broad-exception-caught
+ pass # already closing/closed -- nothing to recover
+
+ try:
+ loop.call_soon_threadsafe(lambda: asyncio.ensure_future(_safe_close()))
+ except Exception: # pylint: disable=broad-exception-caught
+ pass # loop gone (socket already torn down) -- nothing to close
+
+
def _ws_message_hook(ws: Any, message: Any):
"""Global Dash websocket_message hook: authorize each callback_request.
@@ -61,38 +148,43 @@ def _ws_message_hook(ws: Any, message: Any):
if not isinstance(message, dict) or message.get("type") != "callback_request":
return True
try:
- return _authorize_ws_message(message)
+ return _authorize_ws_message(ws, message)
except Exception: # pylint: disable=broad-exception-caught
# Fail closed on any unexpected error.
return (4401, "Unauthorized")
-def _authorize_ws_message(message: dict) -> bool | tuple[int, str]:
- """Authorize one WebSocket ``callback_request`` for the current Quart app.
+def _authorize_ws_message(ws: Any, message: dict) -> bool | tuple[int, str]:
+ """Authorize one WebSocket ``callback_request`` for the owning app.
- Resolves the owning app via ``quart.current_app`` so it is correct when
- several apps share the process; inert for apps that do not use
- dash-auth-async.
+ The owning server and the session user are resolved through the active
+ backend's :meth:`Backend.ws_identity` (Quart uses its context globals,
+ FastAPI reads ``ws.app``/``ws.session``), keeping this module
+ framework-agnostic. Inert for apps that do not use dash-auth-async.
Returns:
``True`` to allow, or a ``(code, reason)`` tuple to reject the socket.
"""
- import quart # noqa: PLC0415 — quart is an optional dependency
-
- # ``quart.current_app`` is a proxy; ``_get_current_object`` unwraps it to
- # the real Quart app (the key in ``_AUTH_BY_SERVER``). The attribute is
- # present at runtime but absent from the proxy's type stub, so go through
- # ``getattr`` to keep the static type checker happy.
- current_app: Any = quart.current_app
- app = getattr(current_app, "_get_current_object")()
- auth = _AUTH_BY_SERVER.get(app)
+ backend = get_active_backend()
+ server, user = backend.ws_identity(ws)
+ auth = _AUTH_BY_SERVER.get(server)
if auth is None:
# Not a dash-auth-async app: nothing to enforce. Safe because the
- # registry entry is created by the developer's ``Auth(app, ...)``
- # call, not by the client -- an attacker cannot evict their own app.
+ # registry entry is created by the developer's ``Auth(app, ...)`` call,
+ # not by the client -- an attacker cannot evict their own app.
return True
+ # Migrate Dash's GLOBAL_CALLBACK_MAP into app.callback_map before Dash
+ # validates this request against it -- validation runs *after* the
+ # websocket_message hooks return (see _fastapi.py / _quart.py). On FastAPI
+ # our auth middleware can short-circuit the page/readiness GET with a 401
+ # before Dash's own _setup_server before-hook ever runs, so a WS-first
+ # client would otherwise hit an empty map ("Callback function not found").
+ # Doing it here is lazy (fires on the first real WS message, after every
+ # module-level @callback is registered -- so unlike a construction-time
+ # call it never freezes the map early) and idempotent (a self-guarded flag
+ # makes every call after the first a single boolean check).
+ auth.app._setup_server()
payload = message.get("payload", {}) or {}
- user = quart.session.get("user")
if auth.authorize_ws(payload, user):
# Load-bearing invariant: this hook runs before every callback_request
# is submitted to the executor, so the context-copying executor always
@@ -103,15 +195,51 @@ def _authorize_ws_message(message: dict) -> bool | tuple[int, str]:
return (4401, "Unauthorized")
+def _ws_connect_hook(ws: Any):
+ """Global Dash websocket_connect hook: register the socket by client id.
+
+ Runs at the handshake (before accept) for every socket. Tracks only sockets
+ that handshake *anonymously*, keyed by browser (the ``dac_client`` cookie),
+ so a later login can retire them; an already-authenticated socket needs no
+ retiring and is left untracked. Always allows the connection -- public pages
+ legitimately stream over an unauthenticated socket, so this hook must never
+ reject.
+
+ Returns:
+ ``True`` to allow the connection (bookkeeping never blocks a socket).
+ """
+ try:
+ _track_anonymous_ws(ws)
+ except Exception: # pylint: disable=broad-exception-caught
+ pass # never let bookkeeping block a connection
+ return True
+
+
+def _track_anonymous_ws(ws: Any) -> None:
+ """Register ``ws`` if its handshake was anonymous (see ``_ws_connect_hook``)."""
+ backend = get_active_backend()
+ _, user = backend.ws_identity(ws)
+ if user is not None:
+ return # already authenticated -- nothing to retire later
+ client_id = getattr(ws, "cookies", {}).get(WS_CLIENT_COOKIE)
+ # Resolve the concrete socket: Quart hands us a context-local proxy that
+ # can't be closed from the later (different-task) login, so we must keep the
+ # underlying object. Starlette's WebSocket has no such proxy and is stored
+ # as-is.
+ real_ws = ws._get_current_object() if hasattr(ws, "_get_current_object") else ws
+ register_anonymous_ws(client_id, real_ws, asyncio.get_running_loop())
+
+
def _ensure_hook_registered() -> None:
- """Register the global websocket_message hook exactly once per process."""
- global _hook_registered # noqa: PLW0603 — register the hook once per process
+ """Register the global websocket hooks exactly once per process."""
+ global _hook_registered # noqa: PLW0603 — register the hooks once per process
with _hook_lock:
if _hook_registered:
return
from dash import hooks # noqa: PLC0415 — lazy import to avoid an import cycle
hooks.websocket_message()(_ws_message_hook)
+ hooks.websocket_connect()(_ws_connect_hook)
_hook_registered = True
@@ -119,8 +247,13 @@ def enable_ws_auth(auth: Any, app: Any) -> None:
"""Wire WebSocket auth for a dash-auth-async app.
No-op on backends without WebSocket support (e.g. Flask). For WS-capable
- backends it records the app->Auth mapping, installs the context-copying
+ backends it records the server->Auth mapping, installs the context-copying
executor (before any dispatch), and registers the global hook once.
+
+ Note Dash's ``callback_map`` is *not* populated here. It is migrated lazily
+ on the first WS ``callback_request`` by the message hook (see
+ ``_authorize_ws_message``), so a global ``@callback`` registered after
+ ``Auth(...)`` is still picked up and the server is never set up twice.
"""
backend = getattr(app, "backend", None)
if backend is None or not getattr(backend, "websocket_capability", False):
diff --git a/examples/websocket_auth_quart/README.md b/examples/websocket_auth_quart/README.md
new file mode 100644
index 0000000..1ec910c
--- /dev/null
+++ b/examples/websocket_auth_quart/README.md
@@ -0,0 +1,42 @@
+# Quart WebSockets + public/private auth example
+
+A minimal multi-page Dash app on the **Quart** backend that combines:
+
+- **WebSocket streaming callbacks** (`Dash(backend="quart", websocket_callbacks=True)`)
+- **Public vs authenticated pages** via `dash-auth-async` `BasicAuth`
+
+## Pages
+
+| Route | Access | What it shows |
+|------------|---------------|----------------------------------------|
+| `/` | public | Landing page + navigation |
+| `/live` | public | Live counter/clock streamed over WS |
+| `/private` | authenticated | Simulated 0→100% progress task over WS |
+
+## Run
+
+```bash
+pip install "dash-auth-async[quart]" # or: uv sync, from the repo root
+python examples/websocket_auth_quart/app.py
+```
+
+Open . The `/private` page prompts for login:
+
+- `admin` / `admin`
+- `viewer` / `viewer123`
+
+(Use `127.0.0.1` rather than `localhost`.)
+
+## Note on WebSocket-layer authentication
+
+Authentication is enforced at the **HTTP page level**: `/private` is a normal
+HTTP GET that is auth-checked, so an anonymous user cannot load the page or start
+its stream. This is what the public/private split demonstrates.
+
+However, `dash-auth-async`'s auth hook is registered as `@server.before_request`,
+which does **not** fire for WebSocket connections (Quart uses separate
+`before_websocket` / `websocket_connect` hooks). The WebSocket callback route
+(`/_dash-ws-callback`) is therefore **not independently auth-gated** — a
+hand-crafted raw WS connection would bypass the check. Closing this gap properly
+would mean adding auth via Dash's `websocket_connect` hook in the library, which
+is out of scope for this example.
diff --git a/examples/websocket_auth_quart/app.py b/examples/websocket_auth_quart/app.py
new file mode 100644
index 0000000..983e5f5
--- /dev/null
+++ b/examples/websocket_auth_quart/app.py
@@ -0,0 +1,53 @@
+"""Example: Dash + Quart backend with WebSocket streaming and public/private auth.
+
+Run:
+ python examples/websocket_auth_quart/app.py
+
+Then open http://127.0.0.1:8050/ in a browser.
+
+Credentials for the private page:
+ admin / admin
+ viewer / viewer123
+"""
+
+from dash import Dash, dcc, html, page_container
+
+from dash_auth_async import BasicAuth
+
+app = Dash(
+ __name__,
+ backend="quart",
+ use_pages=True,
+ websocket_callbacks=True,
+ suppress_callback_exceptions=True,
+)
+
+app.layout = html.Div(
+ [
+ html.Div(
+ [
+ dcc.Link("Home", href="/"),
+ dcc.Link("Live (public)", href="/live"),
+ dcc.Link("Private", href="/private"),
+ ],
+ style={
+ "display": "flex",
+ "gap": "1rem",
+ "background": "#eee",
+ "padding": "0.5rem 1rem",
+ },
+ ),
+ page_container,
+ ],
+ style={"display": "flex", "flexDirection": "column", "fontFamily": "sans-serif"},
+)
+
+BasicAuth(
+ app,
+ {"admin": "admin", "viewer": "viewer123"},
+ secret_key="example-secret-not-for-production",
+ public_routes=["/", "/live"],
+)
+
+if __name__ == "__main__":
+ app.run(host="127.0.0.1", port=8050, debug=True)
diff --git a/examples/websocket_auth_quart/pages/__init__.py b/examples/websocket_auth_quart/pages/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/websocket_auth_quart/pages/home.py b/examples/websocket_auth_quart/pages/home.py
new file mode 100644
index 0000000..239d32f
--- /dev/null
+++ b/examples/websocket_auth_quart/pages/home.py
@@ -0,0 +1,23 @@
+"""Public landing page for the WebSocket + auth example."""
+
+from dash import html, register_page
+
+register_page(__name__, path="/", name="Home")
+
+layout = html.Div(
+ [
+ html.H1("Quart + WebSockets + dash-auth-async"),
+ html.P(
+ "This example runs on the Dash Quart backend with WebSocket "
+ "streaming callbacks, and demonstrates public vs authenticated "
+ "pages using dash-auth-async BasicAuth."
+ ),
+ html.Ul(
+ [
+ html.Li("'/live' is public — anyone can watch the live counter."),
+ html.Li("'/private' requires login (try admin / admin)."),
+ ]
+ ),
+ ],
+ style={"padding": "1rem"},
+)
diff --git a/examples/websocket_auth_quart/pages/live.py b/examples/websocket_auth_quart/pages/live.py
new file mode 100644
index 0000000..5365541
--- /dev/null
+++ b/examples/websocket_auth_quart/pages/live.py
@@ -0,0 +1,49 @@
+"""Public page: a live counter/clock streamed over the WebSocket.
+
+Uses public_callback so the callback is whitelisted by dash-auth-async even if a
+leg is served over HTTP. When routed over the WebSocket the whitelist is simply
+unused. The streaming pattern (async + set_props + is_shutdown) follows
+https://dash.plotly.com/websocket-callbacks.
+"""
+
+import asyncio
+from datetime import datetime
+
+import dash
+from dash import Input, Output, ctx, html, register_page, set_props
+
+from dash_auth_async import public_callback
+
+register_page(__name__, path="/live", name="Live")
+
+layout = html.Div(
+ [
+ html.H1("Live counter / clock (public)"),
+ html.P("Anyone can view this page and start the stream — no login required."),
+ html.Button("Start stream", id="counter-start"),
+ html.Div(
+ "Press start.",
+ id="counter-out",
+ style={"marginTop": "1rem", "fontSize": "1.5rem"},
+ ),
+ ],
+ style={"padding": "1rem"},
+)
+
+
+@public_callback(
+ Output("counter-out", "children"),
+ Input("counter-start", "n_clicks"),
+ prevent_initial_call=True,
+)
+async def stream_counter(_n_clicks):
+ ws = getattr(ctx, "websocket", None)
+ for i in range(1, 11):
+ if ws is not None and ws.is_shutdown:
+ return dash.no_update
+ set_props(
+ "counter-out",
+ {"children": f"Tick {i}/10 — {datetime.now():%H:%M:%S}"},
+ )
+ await asyncio.sleep(1)
+ return "Done."
diff --git a/examples/websocket_auth_quart/pages/private.py b/examples/websocket_auth_quart/pages/private.py
new file mode 100644
index 0000000..4d6e679
--- /dev/null
+++ b/examples/websocket_auth_quart/pages/private.py
@@ -0,0 +1,44 @@
+"""Private page: a simulated long-running task streamed over the WebSocket.
+
+This page is NOT in public_routes, so loading it requires BasicAuth login. The
+callback is a plain @callback: the page is only reachable when authenticated and
+no user-group gating is required. Streaming follows the same async + set_props +
+is_shutdown pattern as the public page.
+"""
+
+import asyncio
+
+import dash
+from dash import Input, Output, callback, ctx, html, register_page, set_props
+
+register_page(__name__, path="/private", name="Private")
+
+layout = html.Div(
+ [
+ html.H1("Simulated task (private)"),
+ html.P("You are logged in — this page sits behind BasicAuth."),
+ html.Button("Run task", id="task-start"),
+ html.Div(
+ html.Progress(id="task-bar", value="0", max="100"),
+ style={"marginTop": "1rem"},
+ ),
+ html.Div("Idle.", id="task-status", style={"marginTop": "1rem"}),
+ ],
+ style={"padding": "1rem"},
+)
+
+
+@callback(
+ Output("task-status", "children"),
+ Input("task-start", "n_clicks"),
+ prevent_initial_call=True,
+)
+async def run_task(_n_clicks):
+ ws = getattr(ctx, "websocket", None)
+ for pct in range(0, 101, 10):
+ if ws is not None and ws.is_shutdown:
+ return dash.no_update
+ set_props("task-bar", {"value": str(pct)})
+ set_props("task-status", {"children": f"Working… {pct}%"})
+ await asyncio.sleep(0.5)
+ return "Complete!"
diff --git a/pyproject.toml b/pyproject.toml
index 9ab916b..6e6266e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "dash-auth-async"
-version = "1.2.1"
+version = "1.3.0"
description = "Dash Authorization Package."
readme = "README.md"
requires-python = ">=3.10"
@@ -66,7 +66,13 @@ dev = [
"ruff>=0.15.16",
"pytest-cov>=7.1.0",
"websocket-client>=1.9.0",
- "requests[security]>=2.34.2"
+ "requests[security]>=2.34.2",
+ "fastapi",
+ "starlette",
+ "uvicorn",
+ "httpx>=0.23.0",
+ "itsdangerous",
+ "python-multipart"
]
[tool.uv]
@@ -96,6 +102,13 @@ preview = true
select = ["FAST","I","D","DOC","PL","UP","PERF","RUF"]
[tool.ruff.lint.per-file-ignores]
+# Examples are runnable demos, not shipped library code (setuptools only
+# packages dash_auth_async*). Like tests, they read better without forced
+# docstring coverage on every demo page/callback.
+"examples/**" = [
+ "D", # demo scripts don't need module/function docstrings
+ "DOC", # ...nor docstring sections (Returns, etc.)
+]
# Tests are self-describing and have different idioms than library code:
"tests/**" = [
"D", # don't require docstrings on every test/fixture/module
@@ -114,5 +127,10 @@ select = ["FAST","I","D","DOC","PL","UP","PERF","RUF"]
quote-style = "double"
indent-style = "space"
+[tool.ruff.lint.isort]
+# Keep aliased imports on the same line as their siblings so a single
+# optional-dependency import guard stays one statement (avoids PLW0717).
+combine-as-imports = true
+
[tool.ruff.lint.pydocstyle]
convention = "google"
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 9440ac8..9a4338b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -38,6 +38,26 @@ def _stop_quart_gracefully(runner) -> bool:
return not runner.thread.is_alive()
+def _stop_fastapi_gracefully(runner) -> bool:
+ """Shut down a fastapi-backend dash test server via uvicorn's should_exit.
+
+ Dash's FastAPI backend stores the uvicorn Server on the Dash app as
+ ``_uvicorn_server`` when run threaded (_fastapi.py:384). Setting
+ ``should_exit`` lets uvicorn's serve loop return so the thread exits
+ cleanly instead of being killed mid-flight.
+
+ Returns True if the server thread exited; False means fall back to the
+ original kill-based stop.
+ """
+ dash_app = getattr(runner, "_app", None)
+ server = getattr(dash_app, "_uvicorn_server", None)
+ if server is None:
+ return False
+ server.should_exit = True
+ runner.thread.join(timeout=runner.stop_timeout)
+ return not runner.thread.is_alive()
+
+
_original_init = _runners.BaseDashRunner.__init__
@@ -61,15 +81,15 @@ def _init_with_ipv4_host(
_original_stop = _runners.ThreadedRunner.stop
-def _stop_with_graceful_quart(self: Any) -> Any:
- if _stop_quart_gracefully(self):
+def _stop_with_graceful_async(self: Any) -> Any:
+ if _stop_quart_gracefully(self) or _stop_fastapi_gracefully(self):
self._app = None
self.started = False
return
return _original_stop(self)
-_runners.ThreadedRunner.stop = _stop_with_graceful_quart # type: ignore
+_runners.ThreadedRunner.stop = _stop_with_graceful_async # type: ignore
@pytest.fixture(autouse=True)
diff --git a/tests/integration/test_basic_auth_integration_auth_func.py b/tests/integration/test_basic_auth_integration_auth_func.py
index 4631ece..7e8722b 100644
--- a/tests/integration/test_basic_auth_integration_auth_func.py
+++ b/tests/integration/test_basic_auth_integration_auth_func.py
@@ -30,7 +30,7 @@ def auth_function(username, password):
},
],
)
-def test_ba002_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div(
[dcc.Input(id="input", value="initial value"), html.Div(id="output")]
@@ -104,7 +104,7 @@ def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs):
},
],
)
-def test_ba003_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
with pytest.raises(ValueError):
both_dict_and_func(dash_br, dash_thread_server, **kwargs)
with pytest.raises(ValueError):
diff --git a/tests/integration/test_basic_auth_integration_auth_func_fastapi.py b/tests/integration/test_basic_auth_integration_auth_func_fastapi.py
new file mode 100644
index 0000000..5f34036
--- /dev/null
+++ b/tests/integration/test_basic_auth_integration_auth_func_fastapi.py
@@ -0,0 +1,100 @@
+import pytest
+import requests
+from dash import Dash, Input, Output, dcc, html
+
+from dash_auth_async import basic_auth
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+TEST_USERS = {
+ "valid": [["hello", "world"], ["hello2", "wo:rld"]],
+ "invalid": [["hello", "password"]],
+}
+
+
+def auth_function(username, password):
+ return [username, password] in TEST_USERS["valid"]
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("input", "value"))
+ def update_output(new_value):
+ return new_value
+
+ basic_auth.BasicAuth(app, auth_func=auth_function)
+
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url + path_prefix
+
+ def test_failed_views(url):
+ assert requests.get(url).status_code == 401
+ assert requests.get(url.strip("/") + "/_dash-layout").status_code == 401
+
+ test_failed_views(base_url)
+
+ for user, password in TEST_USERS["invalid"]:
+ test_failed_views(base_url.replace("//", f"//{user}:{password}@"))
+
+ for user, password in TEST_USERS["valid"]:
+ dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@"))
+ dash_br.driver.get(base_url)
+ dash_br.wait_for_text_to_equal("#output", "initial value")
+
+
+def both_dict_and_func(dash_br, dash_thread_server, **kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+ basic_auth.BasicAuth(app, TEST_USERS["valid"], auth_func=auth_function)
+ return True
+
+
+def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+ basic_auth.BasicAuth(app)
+ return True
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+ with pytest.raises(ValueError):
+ both_dict_and_func(dash_br, dash_thread_server, **kwargs)
+ with pytest.raises(ValueError):
+ both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs)
diff --git a/tests/integration/test_basic_auth_integration_auth_func_quart.py b/tests/integration/test_basic_auth_integration_auth_func_quart.py
index ac87026..d2c493a 100644
--- a/tests/integration/test_basic_auth_integration_auth_func_quart.py
+++ b/tests/integration/test_basic_auth_integration_auth_func_quart.py
@@ -32,7 +32,7 @@ def auth_function(username, password):
},
],
)
-def test_ba002_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+def test_ba004_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, backend="quart", **kwargs)
app.layout = html.Div(
[dcc.Input(id="input", value="initial value"), html.Div(id="output")]
@@ -106,7 +106,7 @@ def both_no_auth_func_or_dict(dash_br, dash_thread_server, **kwargs):
},
],
)
-def test_ba003_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+def test_ba005_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
with pytest.raises(ValueError):
both_dict_and_func(dash_br, dash_thread_server, **kwargs)
with pytest.raises(ValueError):
diff --git a/tests/integration/test_basic_auth_integration_fastapi.py b/tests/integration/test_basic_auth_integration_fastapi.py
new file mode 100644
index 0000000..b8c419b
--- /dev/null
+++ b/tests/integration/test_basic_auth_integration_fastapi.py
@@ -0,0 +1,120 @@
+import pytest
+import requests
+from dash import Dash, Input, Output, dcc, html
+
+from dash_auth_async import BasicAuth, add_public_routes, protected
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+TEST_USERS = {
+ "valid": [["hello", "world"], ["hello2", "wo:rld"]],
+ "invalid": [["hello", "password"]],
+}
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("input", "value"))
+ def update_output(new_value):
+ return new_value
+
+ BasicAuth(app, TEST_USERS["valid"], public_routes=["/home"])
+ add_public_routes(app, ["/user//public"])
+
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url + path_prefix
+
+ def test_failed_views(url):
+ assert requests.get(url).status_code == 401
+
+ def test_successful_views(url):
+ assert requests.get(url.rstrip("/") + "/_dash-layout").status_code == 200
+ assert requests.get(url.rstrip("/") + "/home").status_code == 200
+ assert requests.get(url.rstrip("/") + "/user/john123/public").status_code == 200
+
+ test_failed_views(base_url)
+ test_successful_views(base_url)
+
+ for user, password in TEST_USERS["invalid"]:
+ test_failed_views(base_url.replace("//", f"//{user}:{password}@"))
+ test_successful_views(base_url.replace("//", f"//{user}:{password}@"))
+
+ for user, password in TEST_USERS["valid"]:
+ dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@"))
+ dash_br.driver.get(base_url)
+ dash_br.wait_for_text_to_equal("#output", "initial value")
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+def test_ba002_basic_auth_groups(dash_br, dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+
+ @app.callback(
+ Output("output", "children"),
+ Input("input", "value"),
+ groups=["admin"],
+ )
+ @protected(
+ unauthenticated_output="unauthenticated",
+ missing_permissions_output="forbidden",
+ groups=["admin"],
+ )
+ def update_output(new_value):
+ return new_value
+
+ BasicAuth(
+ app,
+ TEST_USERS["valid"],
+ public_routes=["/home"],
+ user_groups={"hello": ["admin"]},
+ secret_key="Test!",
+ )
+
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url + path_prefix
+
+ for user, password in TEST_USERS["valid"]:
+ dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@"))
+ dash_br.driver.get(base_url)
+ expected = "initial value" if user == "hello" else "forbidden"
+ dash_br.wait_for_text_to_equal("#output", expected)
diff --git a/tests/integration/test_oidc_auth_fastapi.py b/tests/integration/test_oidc_auth_fastapi.py
new file mode 100644
index 0000000..0bd02a6
--- /dev/null
+++ b/tests/integration/test_oidc_auth_fastapi.py
@@ -0,0 +1,240 @@
+from unittest.mock import patch
+
+import pytest
+import requests
+from dash import Dash, Input, Output, dcc, html
+
+from dash_auth_async import OIDCAuth, protected_callback
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+from starlette.responses import RedirectResponse
+
+_OAUTH_APP = "authlib.integrations.starlette_client.apps.StarletteOAuth2App"
+_METADATA_URL = "https://idp2.com/oidc/2/.well-known/openid-configuration"
+
+
+async def valid_authorize_redirect(self, request, redirect_uri, *args, **kwargs):
+ return RedirectResponse("/" + redirect_uri.split("/", maxsplit=3)[-1])
+
+
+async def invalid_authorize_redirect(self, request, redirect_uri, *args, **kwargs):
+ base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1]
+ return RedirectResponse(
+ f"{base_url}?error=Unauthorized&error_description=something went wrong"
+ )
+
+
+async def valid_authorize_access_token(self, request, *args, **kwargs):
+ return {
+ "userinfo": {"email": "a.b@mail.com", "groups": ["viewer", "editor"]},
+ "refresh_token": "ABCDEF",
+ }
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+@patch(f"{_OAUTH_APP}.authorize_redirect", valid_authorize_redirect)
+@patch(f"{_OAUTH_APP}.authorize_access_token", valid_authorize_access_token)
+def test_oaf001_oidc_auth_login_flow_success(dash_br, dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [
+ dcc.Input(id="input", value="initial value"),
+ html.Div(id="output1"),
+ html.Div(id="output2"),
+ html.Div("static", id="output3"),
+ html.Div("static", id="output4"),
+ html.Div("not static", id="output5"),
+ ]
+ )
+
+ @app.callback(Output("output1", "children"), Input("input", "value"))
+ def update_output1(new_value):
+ return new_value
+
+ @protected_callback(
+ Output("output2", "children"),
+ Input("input", "value"),
+ groups=["editor"],
+ check_type="one_of",
+ )
+ def update_output2(new_value):
+ return new_value
+
+ @protected_callback(
+ Output("output3", "children"),
+ Input("input", "value"),
+ groups=["admin"],
+ check_type="one_of",
+ )
+ def update_output3(new_value):
+ return new_value
+
+ @protected_callback(
+ Output("output4", "children"),
+ Input("input", "value"),
+ groups=["viewer"],
+ check_type="none_of",
+ )
+ def update_output4(new_value):
+ return new_value
+
+ @protected_callback(
+ Output("output5", "children"),
+ Input("input", "value"),
+ groups=["viewer", "editor"],
+ check_type="all_of",
+ )
+ def update_output5(new_value):
+ return new_value
+
+ oidc = OIDCAuth(app, secret_key="Test")
+ oidc.register_provider(
+ "oidc",
+ token_endpoint_auth_method="client_secret_post",
+ client_id="",
+ client_secret="",
+ server_metadata_url=_METADATA_URL,
+ )
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url + path_prefix
+
+ assert requests.get(base_url).status_code == 200
+
+ dash_br.driver.get(base_url)
+ dash_br.wait_for_text_to_equal("#output1", "initial value")
+ dash_br.wait_for_text_to_equal("#output2", "initial value")
+ dash_br.wait_for_text_to_equal("#output3", "static")
+ dash_br.wait_for_text_to_equal("#output4", "static")
+ dash_br.wait_for_text_to_equal("#output5", "initial value")
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+@patch(f"{_OAUTH_APP}.authorize_redirect", invalid_authorize_redirect)
+def test_oaf002_oidc_auth_login_fail(dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+
+ @app.callback(Output("output", "children"), Input("input", "value"))
+ def update_output(new_value):
+ return new_value
+
+ oidc = OIDCAuth(app, public_routes=["/public"], secret_key="Test")
+ oidc.register_provider(
+ "oidc",
+ token_endpoint_auth_method="client_secret_post",
+ client_id="",
+ client_secret="",
+ server_metadata_url=_METADATA_URL,
+ )
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url + path_prefix
+
+ def test_unauthorized(url):
+ r = requests.get(url)
+ assert r.status_code == 401
+ assert r.text == "Unauthorized: something went wrong"
+
+ def test_authorized(url):
+ assert requests.get(url).status_code == 200
+
+ test_unauthorized(base_url)
+ test_authorized(base_url.rstrip("/") + "/public")
+
+
+@pytest.mark.parametrize(
+ "kwargs",
+ [
+ {},
+ {"url_base_pathname": "/app/"},
+ {"url_base_pathname": "/sub/app/"},
+ {
+ "routes_pathname_prefix": "/app/",
+ "requests_pathname_prefix": "/app/",
+ },
+ ],
+)
+@patch(f"{_OAUTH_APP}.authorize_redirect", valid_authorize_redirect)
+@patch(f"{_OAUTH_APP}.authorize_access_token", valid_authorize_access_token)
+def test_oaf003_oidc_auth_login_several_idp(dash_br, dash_thread_server, kwargs):
+ app = Dash(__name__, backend="fastapi", **kwargs)
+ app.layout = html.Div(
+ [
+ dcc.Input(id="input", value="initial value"),
+ html.Div(id="output1"),
+ ]
+ )
+
+ @app.callback(Output("output1", "children"), Input("input", "value"))
+ def update_output1(new_value):
+ return new_value
+
+ oidc = OIDCAuth(app, secret_key="Test")
+ oidc.register_provider(
+ "idp1",
+ token_endpoint_auth_method="client_secret_post",
+ client_id="",
+ client_secret="",
+ server_metadata_url=_METADATA_URL,
+ )
+ oidc.register_provider(
+ "idp2",
+ token_endpoint_auth_method="client_secret_post",
+ client_id="",
+ client_secret="",
+ server_metadata_url=_METADATA_URL,
+ )
+
+ dash_thread_server(app)
+ path_prefix = (
+ app.config.get("url_base_pathname", "")
+ or app.config.get("requests_pathname_prefix", "")
+ or app.config.get("routes_pathname_prefix", "")
+ )
+ base_url = dash_thread_server.url
+ base_url_prefix = (base_url + path_prefix).rstrip("/")
+ assert requests.get(base_url).status_code == 400
+ assert requests.get(base_url_prefix).status_code == 400
+
+ assert requests.get(base_url + "/oidc/idp1/login").status_code == 200
+ assert requests.get(base_url + "/oidc/logout").status_code == 200
+ assert requests.get(base_url).status_code == 400
+ assert requests.get(base_url + "/oidc/idp2/login").status_code == 200
+
+ dash_br.driver.get(base_url + "/oidc/idp2/login")
+ dash_br.driver.get(base_url_prefix)
+ dash_br.wait_for_text_to_equal("#output1", "initial value")
diff --git a/tests/integration/test_oidc_auth_fastapi_wiring.py b/tests/integration/test_oidc_auth_fastapi_wiring.py
new file mode 100644
index 0000000..280dc0d
--- /dev/null
+++ b/tests/integration/test_oidc_auth_fastapi_wiring.py
@@ -0,0 +1,85 @@
+"""OIDCAuth construction/wiring on a FastAPI-backed Dash app (no browser)."""
+
+import asyncio
+
+import pytest
+from dash import Dash, dcc, html
+
+from dash_auth_async import OIDCAuth
+from dash_auth_async.oidc_auth import get_oauth
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+_METADATA_URL = "https://idp2.com/oidc/2/.well-known/openid-configuration"
+
+
+def _make_oidc_app():
+ app = Dash(__name__, backend="fastapi")
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
+ oidc = OIDCAuth(app, secret_key="Test")
+ oidc.register_provider(
+ "idp",
+ token_endpoint_auth_method="client_secret_post",
+ client_id="",
+ client_secret="",
+ server_metadata_url=_METADATA_URL,
+ )
+ return app, oidc
+
+
+def test_fastapi_backend_uses_starlette_oauth_registry():
+ from authlib.integrations.starlette_client import (
+ OAuth as StarletteOAuth,
+ StarletteOAuth2App,
+ )
+
+ app, oidc = _make_oidc_app()
+ assert isinstance(oidc.oauth, StarletteOAuth)
+ assert app.server.state.dash_auth_oauth is oidc.oauth
+ assert isinstance(oidc.get_oauth_client("idp"), StarletteOAuth2App)
+
+
+def test_oidc_routes_registered_with_translated_idp_placeholder():
+ app, _ = _make_oidc_app()
+ paths = {route.path for route in app.server.routes if hasattr(route, "path")}
+ assert "/oidc/{idp}/login" in paths
+ assert "/oidc/{idp}/callback" in paths
+ assert "/oidc/logout" in paths
+ names = {route.name for route in app.server.routes if hasattr(route, "name")}
+ assert {"oidc_login", "oidc_logout", "oidc_callback"} <= names
+
+
+def test_get_oauth_finds_state_registry():
+ app, oidc = _make_oidc_app()
+ assert get_oauth(app) is oidc.oauth
+
+
+def test_callback_unknown_idp_returns_400():
+ from starlette.requests import Request
+
+ from dash_auth_async.backends import _current_request_var
+
+ _, oidc = _make_oidc_app()
+ scope = {
+ "type": "http",
+ "method": "GET",
+ "path": "/oidc/nope/callback",
+ "headers": [],
+ "query_string": b"",
+ }
+ request = Request(scope)
+
+ async def run():
+ # The merged async callback resolves the request from the ContextVar
+ # (set by the auth middleware in a live request) rather than a param.
+ token = _current_request_var.set(request)
+ try:
+ response = await oidc._callback_async("nope")
+ finally:
+ _current_request_var.reset(token)
+ assert response.status_code == 400
+ assert b"not a valid registered idp" in response.body
+
+ asyncio.run(run())
diff --git a/tests/integration/test_oidc__auth_quart_wiring.py b/tests/integration/test_oidc_auth_quart_wiring.py
similarity index 94%
rename from tests/integration/test_oidc__auth_quart_wiring.py
rename to tests/integration/test_oidc_auth_quart_wiring.py
index bc050ad..43fb509 100644
--- a/tests/integration/test_oidc__auth_quart_wiring.py
+++ b/tests/integration/test_oidc_auth_quart_wiring.py
@@ -3,7 +3,7 @@
import asyncio
import pytest
-from dash import Dash
+from dash import Dash, dcc, html
from dash_auth_async import OIDCAuth
from dash_auth_async.oidc_auth import get_oauth
@@ -16,6 +16,9 @@
def _make_oidc_app():
app = Dash(__name__, backend="quart")
+ app.layout = html.Div(
+ [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+ )
oidc = OIDCAuth(app, secret_key="Test")
oidc.register_provider(
"idp",
diff --git a/tests/integration/test_oidc_state_csrf.py b/tests/integration/test_oidc_state_csrf.py
new file mode 100644
index 0000000..30bed0a
--- /dev/null
+++ b/tests/integration/test_oidc_state_csrf.py
@@ -0,0 +1,63 @@
+"""OIDC OAuth state/CSRF validation, driven through real authlib (no mock IDP).
+
+The browser OIDC tests patch ``authorize_redirect``/``authorize_access_token``,
+so authlib's anti-CSRF state check never actually runs there. These tests drive
+the *real* authlib state path on a Flask backend via its test client to lock the
+invariant against future authlib changes: ``/login`` stores the generated state
+in the session, and ``/callback`` presented with a tampered or missing state must
+be rejected with 401 (the ``except OAuthError`` branch in ``OIDCAuth.callback``),
+never silently accepted.
+
+authlib validates state before any token-endpoint call, so registering the
+provider with explicit endpoints keeps the whole flow offline.
+"""
+
+from urllib.parse import parse_qs, urlparse
+
+from dash import Dash, html
+
+from dash_auth_async import OIDCAuth
+
+_AUTHORIZE_URL = "https://idp.example/authorize"
+_TOKEN_URL = "https://idp.example/token"
+
+
+def _make_client():
+ app = Dash(__name__)
+ app.layout = html.Div("state-csrf") # Dash validates layout on first request
+ oidc = OIDCAuth(app, secret_key="state-csrf-secret")
+ oidc.register_provider(
+ "idp",
+ client_id="client-id",
+ client_secret="client-secret",
+ authorize_url=_AUTHORIZE_URL,
+ access_token_url=_TOKEN_URL,
+ )
+ # The Flask test client persists the session cookie across requests, so the
+ # state stored at /login is presented back at /callback automatically.
+ return app.server.test_client()
+
+
+def _login_and_capture_state(client) -> str:
+ resp = client.get("/oidc/idp/login")
+ assert resp.status_code == 302
+ query = parse_qs(urlparse(resp.headers["Location"]).query)
+ # authlib generated a state and stored it in the session before redirecting.
+ assert "state" in query
+ return query["state"][0]
+
+
+def test_oidc_callback_rejects_tampered_state():
+ client = _make_client()
+ real_state = _login_and_capture_state(client)
+
+ resp = client.get(f"/oidc/idp/callback?code=fake-code&state={real_state}-tampered")
+ assert resp.status_code == 401
+
+
+def test_oidc_callback_rejects_missing_state():
+ client = _make_client()
+ _login_and_capture_state(client)
+
+ resp = client.get("/oidc/idp/callback?code=fake-code")
+ assert resp.status_code == 401
diff --git a/tests/integration/test_protected_callback_websocket_integration_fastapi.py b/tests/integration/test_protected_callback_websocket_integration_fastapi.py
new file mode 100644
index 0000000..f860e2b
--- /dev/null
+++ b/tests/integration/test_protected_callback_websocket_integration_fastapi.py
@@ -0,0 +1,92 @@
+"""End-to-end integration tests for a *protected WebSocket* callback on FastAPI.
+
+FastAPI sibling of test_protected_callback_websocket_integration_quart.py. The
+callback opts into the WebSocket transport (``websocket=True``) and streams
+updates with ``set_props``; the final visible text is streamed, not returned, so
+it can only arrive over the socket. The auth gate must wrap it correctly -- an
+under-privileged user gets the fallback and the stream body never runs.
+"""
+
+import asyncio
+
+import dash
+import pytest
+from dash import Dash, Input, Output, html, set_props
+
+from dash_auth_async import BasicAuth, protected_callback
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+TEST_USERS = {"hello": "world", "hello2": "wo:rld"}
+USER_GROUPS = {"hello": ["admin"]}
+
+
+def _build_app() -> Dash:
+ # websocket_callbacks=True makes the *client* open the socket; per-callback
+ # websocket=True then routes this callback over it.
+ app = Dash(__name__, backend="fastapi", websocket_callbacks=True)
+ app.layout = html.Div(
+ [
+ html.Button("Start stream", id="ws-start"),
+ html.Div("idle", id="ws-out"),
+ ]
+ )
+
+ @protected_callback(
+ Output("ws-out", "children"),
+ Input("ws-start", "n_clicks"),
+ groups=["admin"],
+ missing_permissions_output="forbidden",
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ async def stream_ws(_n_clicks):
+ ws = getattr(dash.ctx, "websocket", None)
+ for i in range(1, 4):
+ if ws is not None and ws.is_shutdown:
+ return dash.no_update
+ set_props("ws-out", {"children": f"tick {i}/3"})
+ await asyncio.sleep(0)
+ # Final visible text is streamed, not returned: proves the WebSocket push.
+ set_props("ws-out", {"children": "streamed"})
+ return dash.no_update
+
+ return app
+
+
+def _login(dash_br, base_url, username, password):
+ dash_br.driver.get(base_url.replace("//", f"//{username}:{password}@"))
+ dash_br.driver.get(base_url)
+
+
+def test_pcwf001_authorized_protected_websocket_callback_streams_set_props(
+ dash_br, dash_thread_server
+):
+ """An authorised user sees the value pushed via ``set_props`` over the socket."""
+ app = _build_app()
+ BasicAuth(app, TEST_USERS, user_groups=USER_GROUPS, secret_key="Test!")
+
+ dash_thread_server(app)
+ base_url = dash_thread_server.url
+ _login(dash_br, base_url, "hello", "world")
+
+ dash_br.wait_for_text_to_equal("#ws-out", "idle")
+ dash_br.find_element("#ws-start").click()
+ dash_br.wait_for_text_to_equal("#ws-out", "streamed")
+
+
+def test_pcwf002_missing_permissions_protected_websocket_callback_emits_fallback(
+ dash_br, dash_thread_server
+):
+ """An authenticated user without the group gets the fallback, never the stream."""
+ app = _build_app()
+ BasicAuth(app, TEST_USERS, user_groups=USER_GROUPS, secret_key="Test!")
+
+ dash_thread_server(app)
+ base_url = dash_thread_server.url
+ # "hello2" authenticates but has no groups -> admin gate rejects it.
+ _login(dash_br, base_url, "hello2", "wo:rld")
+
+ dash_br.wait_for_text_to_equal("#ws-out", "idle")
+ dash_br.find_element("#ws-start").click()
+ dash_br.wait_for_text_to_equal("#ws-out", "forbidden")
diff --git a/tests/integration/test_websocket_reconnect_on_login_fastapi.py b/tests/integration/test_websocket_reconnect_on_login_fastapi.py
new file mode 100644
index 0000000..d2d06f2
--- /dev/null
+++ b/tests/integration/test_websocket_reconnect_on_login_fastapi.py
@@ -0,0 +1,103 @@
+"""Regression: logging in must close the same browser's stale pre-login socket.
+
+A WebSocket's session is frozen at the handshake cookie. A socket opened before
+login is therefore permanently unauthenticated, and Dash's renderer only
+reconnects on a socket *close* -- so without intervention the first protected
+callback over that socket is rejected (4401) and silently dropped, and the user
+must click a second time (the rejection is what finally triggers the reconnect).
+
+Design A closes that gap from the server: a pre-login ``dac_client`` cookie ties
+an anonymous socket back to the browser, and the moment that browser
+authenticates over HTTP the server closes its stale anonymous socket. The
+renderer's existing auto-reconnect then dials a fresh, authenticated handshake
+*before* the first click.
+
+This test drives the raw socket (no browser) and asserts that proactive close.
+FastAPI sibling of the Quart variant.
+"""
+
+import pytest
+from dash import Dash, Input, Output, callback, html
+
+from dash_auth_async import BasicAuth
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+requests = pytest.importorskip("requests", reason="requests is not installed")
+websocket = pytest.importorskip(
+ "websocket", reason="websocket-client (the 'websocket' module) is not installed"
+)
+
+
+def _build_app() -> Dash:
+ app = Dash(__name__, backend="fastapi", websocket_callbacks=True)
+ app.layout = html.Div(
+ [html.Button("p", id="priv-in"), html.Div("idle", id="priv-out")]
+ )
+
+ @callback(
+ Output("priv-out", "children"),
+ Input("priv-in", "n_clicks"),
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ def private(_n):
+ return "TOP-SECRET-USER-DATA"
+
+ BasicAuth(app, {"hello": "world"}, secret_key="Test!")
+ return app
+
+
+def _socket_was_closed(conn, timeout=6) -> bool:
+ """Return True if the server closed ``conn`` within ``timeout`` seconds.
+
+ websocket-client surfaces a server close as an empty ``recv()`` or a
+ ``WebSocketConnectionClosedException``; a still-open idle socket raises a
+ timeout instead.
+ """
+ conn.settimeout(timeout)
+ try:
+ return conn.recv() in {"", b"", None}
+ except websocket.WebSocketConnectionClosedException:
+ return True
+ except Exception:
+ return False
+
+
+def test_login_closes_stale_anonymous_socket(dash_thread_server):
+ app = _build_app()
+ dash_thread_server(app)
+ base = dash_thread_server.url
+ ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback"
+
+ session = requests.Session()
+
+ # First anonymous contact must mint a pre-login client-id cookie -- even on
+ # the BasicAuth 401 challenge -- so the socket can be tied to this browser.
+ session.get(base, timeout=8)
+ client_id = session.cookies.get("dac_client")
+ assert client_id, "expected a pre-login 'dac_client' cookie on first contact"
+
+ # Open an anonymous socket carrying only dac_client (no session): this models
+ # the SharedWorker socket opened on a public page before login.
+ conn = websocket.create_connection(
+ ws_url,
+ header=[f"Origin: {base}", f"Cookie: dac_client={client_id}"],
+ timeout=8,
+ suppress_origin=True,
+ )
+ try:
+ # The same browser now authenticates over HTTP (same dac_client cookie).
+ resp = session.get(base, auth=("hello", "world"), timeout=8)
+ assert resp.status_code == 200, resp.status_code
+
+ # The server must proactively close the stale anonymous socket so the
+ # renderer reconnects authenticated before the user's first click.
+ assert _socket_was_closed(conn), (
+ "stale pre-login socket was not closed after the same browser "
+ "logged in -- the first protected click would be dropped"
+ )
+ finally:
+ try:
+ conn.close()
+ except Exception:
+ pass
diff --git a/tests/integration/test_websocket_reconnect_on_login_quart.py b/tests/integration/test_websocket_reconnect_on_login_quart.py
new file mode 100644
index 0000000..7b58380
--- /dev/null
+++ b/tests/integration/test_websocket_reconnect_on_login_quart.py
@@ -0,0 +1,87 @@
+"""Regression: logging in must close the same browser's stale pre-login socket.
+
+Quart sibling of test_websocket_reconnect_on_login_fastapi.py -- see that file
+for the full rationale. A socket's session is frozen at its handshake cookie, so
+a pre-login socket stays unauthenticated; Design A closes it the moment the same
+browser (tied by the pre-login ``dac_client`` cookie) authenticates over HTTP,
+so the renderer reconnects authenticated before the first click.
+"""
+
+import pytest
+from dash import Dash, Input, Output, callback, html
+
+from dash_auth_async import BasicAuth
+
+pytest.importorskip("quart", reason="Quart extra dependencies are not installed")
+requests = pytest.importorskip("requests", reason="requests is not installed")
+websocket = pytest.importorskip(
+ "websocket", reason="websocket-client (the 'websocket' module) is not installed"
+)
+
+
+def _build_app() -> Dash:
+ app = Dash(__name__, backend="quart", websocket_callbacks=True)
+ app.layout = html.Div(
+ [html.Button("p", id="priv-in"), html.Div("idle", id="priv-out")]
+ )
+
+ @callback(
+ Output("priv-out", "children"),
+ Input("priv-in", "n_clicks"),
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ def private(_n):
+ return "TOP-SECRET-USER-DATA"
+
+ BasicAuth(app, {"hello": "world"}, secret_key="Test!")
+ return app
+
+
+def _socket_was_closed(conn, timeout=6) -> bool:
+ """Return True if the server closed ``conn`` within ``timeout`` seconds."""
+ conn.settimeout(timeout)
+ try:
+ return conn.recv() in {"", b"", None}
+ except websocket.WebSocketConnectionClosedException:
+ return True
+ except Exception:
+ return False
+
+
+def test_login_closes_stale_anonymous_socket(dash_thread_server):
+ app = _build_app()
+ dash_thread_server(app)
+ base = dash_thread_server.url
+ ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback"
+
+ session = requests.Session()
+
+ # First anonymous contact must mint a pre-login client-id cookie -- even on
+ # the BasicAuth 401 challenge -- so the socket can be tied to this browser.
+ session.get(base, timeout=8)
+ client_id = session.cookies.get("dac_client")
+ assert client_id, "expected a pre-login 'dac_client' cookie on first contact"
+
+ # Open an anonymous socket carrying only dac_client (no session).
+ conn = websocket.create_connection(
+ ws_url,
+ header=[f"Origin: {base}", f"Cookie: dac_client={client_id}"],
+ timeout=8,
+ suppress_origin=True,
+ )
+ try:
+ # The same browser now authenticates over HTTP (same dac_client cookie).
+ resp = session.get(base, auth=("hello", "world"), timeout=8)
+ assert resp.status_code == 200, resp.status_code
+
+ # The server must proactively close the stale anonymous socket.
+ assert _socket_was_closed(conn), (
+ "stale pre-login socket was not closed after the same browser "
+ "logged in -- the first protected click would be dropped"
+ )
+ finally:
+ try:
+ conn.close()
+ except Exception:
+ pass
diff --git a/tests/integration/test_websocket_security_fastapi.py b/tests/integration/test_websocket_security_fastapi.py
new file mode 100644
index 0000000..f50aecc
--- /dev/null
+++ b/tests/integration/test_websocket_security_fastapi.py
@@ -0,0 +1,184 @@
+"""Security: the FastAPI websocket callback endpoint must enforce auth.
+
+Drives the raw socket with websocket-client (no browser) and asserts an
+unauthenticated client cannot invoke a private callback, a public_callback still
+streams unauthenticated, and an authenticated-but-under-privileged user gets the
+fallback rather than the protected payload. FastAPI sibling of
+test_websocket_security_quart.py.
+"""
+
+import json
+
+import pytest
+from dash import Dash, Input, Output, callback, html
+
+from dash_auth_async import BasicAuth, protected_callback, public_callback
+
+# Guard the optional, non-extra test deps *before* importing them. fastapi is
+# checked first so the non-async matrix jobs skip cleanly without needing the
+# socket client at all.
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+requests = pytest.importorskip("requests", reason="requests is not installed")
+websocket = pytest.importorskip( # websocket-client (synchronous)
+ "websocket",
+ reason="websocket-client (the 'websocket' module) is not installed",
+)
+
+
+def _build_app() -> Dash:
+ app = Dash(__name__, backend="fastapi", websocket_callbacks=True)
+ app.layout = html.Div(
+ [
+ html.Button("p", id="priv-in"),
+ html.Div("idle", id="priv-out"),
+ html.Button("u", id="pub-in"),
+ html.Div("idle", id="pub-out"),
+ ]
+ )
+
+ @callback(
+ Output("priv-out", "children"),
+ Input("priv-in", "n_clicks"),
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ def private(_n):
+ return "TOP-SECRET-USER-DATA"
+
+ @public_callback(
+ Output("pub-out", "children"),
+ Input("pub-in", "n_clicks"),
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ async def public(_n):
+ return "PUBLIC-OK"
+
+ BasicAuth(app, {"hello": "world"}, secret_key="Test!")
+ return app
+
+
+def _build_app_with_protected_admin_callback() -> Dash:
+ """App whose private callback is group-gated to ``admin`` over the socket."""
+ app = Dash(__name__, backend="fastapi", websocket_callbacks=True)
+ app.layout = html.Div(
+ [
+ html.Button("p", id="priv-in"),
+ html.Div("idle", id="priv-out"),
+ ]
+ )
+
+ @protected_callback(
+ Output("priv-out", "children"),
+ Input("priv-in", "n_clicks"),
+ groups=["admin"],
+ missing_permissions_output="forbidden",
+ prevent_initial_call=True,
+ websocket=True,
+ )
+ async def private(_n):
+ return "TOP-SECRET-ADMIN-DATA"
+
+ BasicAuth(
+ app,
+ {"admin": "pw", "viewer": "pw"},
+ user_groups={"admin": ["admin"]}, # "viewer" authenticates with no groups
+ secret_key="Test!",
+ )
+ return app
+
+
+def _login_cookie_header(base_url, username, password) -> str:
+ """Authenticate over HTTP and return the session ``Cookie`` header value.
+
+ The auth hook runs ``is_authorized`` on this request, which stashes
+ ``session["user"]`` (with groups) and sets the session cookie -- the same
+ cookie the browser would send at the WS handshake.
+ """
+ resp = requests.get(base_url, auth=(username, password), timeout=8)
+ assert resp.status_code == 200, resp.status_code
+ return "; ".join(f"{c.name}={c.value}" for c in resp.cookies)
+
+
+def _send_callback_request(ws_url, origin, output, comp_id, in_id, cookie=None):
+ header = [f"Origin: {origin}"]
+ if cookie:
+ header.append(f"Cookie: {cookie}")
+ # suppress_origin: FastAPI/Uvicorn rejects a duplicate ``Origin`` header with
+ # HTTP 400 (Quart is lenient), and websocket-client injects its own by
+ # default -- suppress it so only our explicit Origin reaches the server.
+ conn = websocket.create_connection(
+ ws_url, header=header, timeout=8, suppress_origin=True
+ )
+ try:
+ conn.send(
+ json.dumps(
+ {
+ "type": "callback_request",
+ "requestId": "1",
+ "rendererId": "r1",
+ "payload": {
+ "output": output,
+ "outputs": {"id": comp_id, "property": "children"},
+ "inputs": [{"id": in_id, "property": "n_clicks", "value": 1}],
+ "changedPropIds": [f"{in_id}.n_clicks"],
+ "state": [],
+ },
+ }
+ )
+ )
+ frames = []
+ for _ in range(5):
+ try:
+ frames.append(str(conn.recv()))
+ except Exception:
+ break
+ return "".join(frames)
+ finally:
+ try:
+ conn.close()
+ except Exception:
+ pass
+
+
+def test_unauthenticated_ws_cannot_invoke_private_callback(dash_thread_server):
+ app = _build_app()
+ dash_thread_server(app)
+ base = dash_thread_server.url
+ ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback"
+
+ received = _send_callback_request(
+ ws_url, base, "priv-out.children", "priv-out", "priv-in"
+ )
+ assert "TOP-SECRET-USER-DATA" not in received
+
+
+def test_unauthenticated_ws_can_invoke_public_callback(dash_thread_server):
+ app = _build_app()
+ dash_thread_server(app)
+ base = dash_thread_server.url
+ ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback"
+
+ received = _send_callback_request(
+ ws_url, base, "pub-out.children", "pub-out", "pub-in"
+ )
+ assert "PUBLIC-OK" in received
+
+
+def test_authenticated_wrong_group_ws_gets_fallback_not_secret(dash_thread_server):
+ """Authenticated-but-under-privileged over the raw socket: the group gate
+ renders ``missing_permissions_output`` and never leaks the admin payload.
+ """
+ app = _build_app_with_protected_admin_callback()
+ dash_thread_server(app)
+ base = dash_thread_server.url
+ ws_url = base.replace("http://", "ws://") + "/_dash-ws-callback"
+
+ # "viewer" authenticates (cookie set) but lacks the "admin" group.
+ cookie = _login_cookie_header(base, "viewer", "pw")
+
+ received = _send_callback_request(
+ ws_url, base, "priv-out.children", "priv-out", "priv-in", cookie=cookie
+ )
+ assert "TOP-SECRET-ADMIN-DATA" not in received
+ assert "forbidden" in received
diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py
index 84466b6..3d66d7e 100644
--- a/tests/unit/test_backends.py
+++ b/tests/unit/test_backends.py
@@ -1,9 +1,11 @@
from dash import Dash
+from dash_auth_async import backends
from dash_auth_async.backends import (
FlaskBackend,
detect_backend,
get_active_backend,
+ set_active_backend,
)
@@ -16,6 +18,21 @@ def test_active_backend_defaults_to_flask():
assert isinstance(get_active_backend(), FlaskBackend)
+def test_set_active_backend_overrides_default():
+ sentinel = FlaskBackend()
+ set_active_backend(sentinel)
+ # The process-global helper returns exactly what Auth.__init__ registered,
+ # not a freshly detected backend — this is the cache public_routes reads.
+ assert get_active_backend() is sentinel
+
+
+def test_default_backend_is_not_constructed_eagerly_at_import():
+ # B2: the Flask fallback is built lazily inside get_active_backend(), not
+ # as a module-level _DEFAULT_BACKEND at import, so Flask isn't cemented as
+ # the default at module load in a non-Flask process.
+ assert not hasattr(backends, "_DEFAULT_BACKEND")
+
+
def test_flask_backend_url_for_and_redirect():
app = Dash(__name__)
server = app.server
@@ -30,3 +47,28 @@ def target():
response = backend.redirect("/target")
assert response.status_code == 302
assert response.headers["Location"] == "/target"
+
+
+def test_flask_backend_new_method_defaults():
+ import flask
+
+ from dash_auth_async.backends import FlaskBackend
+
+ backend = FlaskBackend()
+
+ # coerce_response is pass-through on Flask
+ sentinel = ("body", 401, {"X": "y"})
+ assert backend.coerce_response(sentinel) is sentinel
+
+ # setup_session sets secret_key; session_configured reflects it
+ server = flask.Flask(__name__)
+ assert backend.session_configured(server) is False
+ backend.setup_session(server, "Test!")
+ assert server.secret_key == "Test!"
+ assert backend.session_configured(server) is True
+
+ # make_oauth returns a flask_client OAuth bound to the server
+ from authlib.integrations.flask_client import OAuth as FlaskOAuth
+
+ oauth = backend.make_oauth(server)
+ assert isinstance(oauth, FlaskOAuth)
diff --git a/tests/unit/test_backends_fastapi.py b/tests/unit/test_backends_fastapi.py
new file mode 100644
index 0000000..716458a
--- /dev/null
+++ b/tests/unit/test_backends_fastapi.py
@@ -0,0 +1,369 @@
+import pytest
+from dash import Dash
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+from starlette.requests import Request
+
+from dash_auth_async.backends import (
+ FastAPIBackend,
+ _current_request_var,
+ detect_backend,
+ get_active_backend,
+ set_active_backend,
+)
+
+
+def _bare_request(path="/", session=None):
+ scope = {
+ "type": "http",
+ "method": "GET",
+ "path": path,
+ "headers": [],
+ "query_string": b"",
+ }
+ if session is not None:
+ scope["session"] = session
+ return Request(scope)
+
+
+def test_detect_backend_fastapi():
+ app = Dash(__name__, backend="fastapi")
+ assert isinstance(detect_backend(app.server), FastAPIBackend)
+
+
+def test_active_backend_roundtrip():
+ backend = FastAPIBackend()
+ set_active_backend(backend)
+ assert get_active_backend() is backend
+
+
+def test_contextvar_set_reset_and_request_context():
+ backend = FastAPIBackend()
+ assert backend.has_request_context() is False
+
+ req = _bare_request()
+ token = _current_request_var.set(req)
+ try:
+ assert backend.has_request_context() is True
+ assert backend.request is req
+ finally:
+ _current_request_var.reset(token)
+ assert backend.has_request_context() is False
+
+
+def test_session_without_middleware_raises_runtimeerror():
+ backend = FastAPIBackend()
+ req = _bare_request() # no "session" in scope
+ token = _current_request_var.set(req)
+ try:
+ with pytest.raises(RuntimeError):
+ _ = backend.session
+ finally:
+ _current_request_var.reset(token)
+
+
+def test_session_off_request_raises_runtimeerror():
+ # request is None outside any request context. The scope-based check
+ # must surface RuntimeError (a caught, "not authenticated" signal), not
+ # an AttributeError on None — and never relies on a -O-stripped assert.
+ backend = FastAPIBackend()
+ assert backend.request is None
+ with pytest.raises(RuntimeError):
+ _ = backend.session
+
+
+def test_session_present_returns_mapping():
+ backend = FastAPIBackend()
+ req = _bare_request(session={"user": {"email": "a.b@mail.com"}})
+ token = _current_request_var.set(req)
+ try:
+ assert backend.session["user"]["email"] == "a.b@mail.com"
+ finally:
+ _current_request_var.reset(token)
+
+
+def test_coerce_response_tuple_str_and_response():
+ from starlette.responses import Response as StarletteResponse
+
+ backend = FastAPIBackend()
+
+ # Tuples carry status/headers and stay plain text (e.g. the 401 challenge).
+ resp = backend.coerce_response(
+ ("Login Required", 401, {"WWW-Authenticate": 'Basic realm="x"'})
+ )
+ assert resp.status_code == 401
+ assert resp.headers["WWW-Authenticate"] == 'Basic realm="x"'
+ assert resp.body == b"Login Required"
+ assert resp.media_type == "text/plain"
+
+ # A bare string is treated as HTML, matching Flask/Quart str returns (e.g.
+ # the OIDC logout page) so the browser renders rather than shows the markup.
+ resp2 = backend.coerce_response("hello")
+ assert resp2.status_code == 200
+ assert resp2.body == b"hello"
+ assert resp2.media_type == "text/html"
+
+ passthrough = StarletteResponse(content="x", status_code=204)
+ assert backend.coerce_response(passthrough) is passthrough
+
+
+def test_fastapi_backend_url_for_and_redirect():
+ from fastapi import FastAPI
+ from fastapi.testclient import TestClient
+
+ app = FastAPI()
+ backend = FastAPIBackend()
+
+ @app.get("/target", name="target")
+ async def target():
+ return {"ok": True}
+
+ @app.get("/probe")
+ async def probe(request: Request):
+ token = _current_request_var.set(request)
+ try:
+ return {
+ "url": backend.url_for("target"),
+ "https_url": backend.url_for("target", _external=True, _scheme="https"),
+ "redirect_loc": backend.redirect("/target").headers["location"],
+ "host": backend.current_host(),
+ }
+ finally:
+ _current_request_var.reset(token)
+
+ client = TestClient(app)
+ data = client.get("/probe").json()
+ assert data["url"].endswith("/target")
+ assert data["https_url"].startswith("https://")
+ assert data["redirect_loc"] == "/target"
+ assert data["host"] # non-empty netloc
+
+
+def _build_app_with_auth(decide, needs_body):
+ """A FastAPI app whose only middleware is the auth hook, plus an echo
+ route that proves the body survives middleware body-consumption."""
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ backend = FastAPIBackend()
+
+ @app.post("/_dash-update-component")
+ async def echo(request: Request):
+ body = await request.json()
+ return {"seen": body, "had_context": backend.has_request_context()}
+
+ @app.get("/open")
+ async def open_route():
+ return {"ok": True}
+
+ backend.register_auth_hook(app, needs_body, decide)
+ return app
+
+
+def test_auth_hook_allows_when_decide_returns_none():
+ from fastapi.testclient import TestClient
+
+ calls = []
+
+ def decide(path, body):
+ calls.append((path, body))
+
+ app = _build_app_with_auth(
+ decide, needs_body=lambda p: p == "/_dash-update-component"
+ )
+ client = TestClient(app)
+
+ r = client.post("/_dash-update-component", json={"output": "x", "inputs": []})
+ assert r.status_code == 200
+ # Body was replayed: the downstream route still parsed it.
+ assert r.json()["seen"] == {"output": "x", "inputs": []}
+ assert r.json()["had_context"] is True
+ # decide saw the parsed body for the callback route.
+ assert calls == [("/_dash-update-component", {"output": "x", "inputs": []})]
+
+
+def test_auth_hook_short_circuits_with_tuple():
+ from fastapi.testclient import TestClient
+
+ def decide(path, body):
+ return ("Login Required", 401, {"WWW-Authenticate": 'Basic realm="x"'})
+
+ app = _build_app_with_auth(decide, needs_body=lambda p: False)
+ client = TestClient(app)
+
+ r = client.get("/open")
+ assert r.status_code == 401
+ assert r.headers["WWW-Authenticate"] == 'Basic realm="x"'
+ assert r.text == "Login Required"
+
+
+def test_auth_hook_awaits_coroutine_results():
+ from fastapi.testclient import TestClient
+ from starlette.responses import PlainTextResponse
+
+ async def decide(path, body):
+ return PlainTextResponse("async-block", status_code=403)
+
+ app = _build_app_with_auth(decide, needs_body=lambda p: False)
+ client = TestClient(app)
+
+ r = client.get("/open")
+ assert r.status_code == 403
+ assert r.text == "async-block"
+
+
+def test_auth_hook_unparseable_body_is_treated_as_none():
+ # Fail-closed path: malformed JSON on a needs_body route must reach decide
+ # as None (the `except Exception: body = None` branch), not raise a 500.
+ from fastapi.testclient import TestClient
+
+ seen = []
+
+ def decide(path, body):
+ seen.append(body)
+ return ("Unauthorized", 401) if body is None else None
+
+ app = _build_app_with_auth(
+ decide, needs_body=lambda p: p == "/_dash-update-component"
+ )
+ client = TestClient(app)
+
+ r = client.post(
+ "/_dash-update-component",
+ content="{ not valid json",
+ headers={"content-type": "application/json"},
+ )
+ assert seen == [None]
+ assert r.status_code == 401
+ assert r.text == "Unauthorized"
+
+
+def test_auth_hook_short_circuits_after_parsing_body():
+ # The other short-circuit branch: decide returns non-None *when needs_body
+ # is True*, so the body is parsed first and then the request is blocked.
+ from fastapi.testclient import TestClient
+
+ seen = []
+
+ def decide(path, body):
+ seen.append(body)
+ return ("Login Required", 401)
+
+ app = _build_app_with_auth(decide, needs_body=lambda p: True)
+ client = TestClient(app)
+
+ r = client.post("/_dash-update-component", json={"output": "x", "inputs": []})
+ assert r.status_code == 401
+ assert r.text == "Login Required"
+ # decide saw the parsed body — it ran after body parsing, not before.
+ assert seen == [{"output": "x", "inputs": []}]
+
+
+def test_downstream_receive_emits_disconnect_after_body():
+ # The replayed receive must deliver the cached body once, then signal
+ # http.disconnect — an app that polls receive() after the body (to detect
+ # disconnect) must not get the same body event forever.
+ import asyncio
+
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ backend = FastAPIBackend()
+ backend.register_auth_hook(
+ app, needs_body=lambda p: True, decide=lambda path, body: None
+ )
+ # add_middleware prepends; our auth middleware is the only/outermost one.
+ auth_middleware_cls = app.user_middleware[0].cls
+
+ received = []
+
+ async def inner_app(scope, receive, send):
+ received.append(await receive()) # cached body
+ received.append(await receive()) # past the body → disconnect
+
+ middleware = auth_middleware_cls(inner_app)
+
+ scope = {
+ "type": "http",
+ "method": "POST",
+ "path": "/_dash-update-component",
+ "headers": [(b"content-type", b"application/json")],
+ "query_string": b"",
+ }
+
+ async def receive():
+ return {
+ "type": "http.request",
+ "body": b'{"output": "x", "inputs": []}',
+ "more_body": False,
+ }
+
+ async def send(_message):
+ pass
+
+ async def drive():
+ await middleware(scope, receive, send)
+
+ asyncio.run(drive())
+
+ assert received[0]["type"] == "http.request"
+ assert received[0]["body"] == b'{"output": "x", "inputs": []}'
+ assert received[1] == {"type": "http.disconnect"}
+
+
+def test_setup_session_adds_session_middleware_once():
+ from fastapi import FastAPI
+ from starlette.middleware.sessions import SessionMiddleware
+
+ backend = FastAPIBackend()
+ app = FastAPI()
+
+ assert backend.session_configured(app) is False
+
+ backend.setup_session(app, "Test!")
+ assert backend.session_configured(app) is True
+ count = sum(1 for m in app.user_middleware if m.cls is SessionMiddleware)
+ assert count == 1
+
+ # Calling again must not add a second SessionMiddleware.
+ backend.setup_session(app, "Test!")
+ count = sum(1 for m in app.user_middleware if m.cls is SessionMiddleware)
+ assert count == 1
+
+
+def test_setup_session_wires_secure_session_to_https_only():
+ from fastapi import FastAPI
+ from starlette.middleware.sessions import SessionMiddleware
+
+ backend = FastAPIBackend()
+
+ insecure = FastAPI()
+ backend.setup_session(insecure, "Test!")
+ sm = next(m for m in insecure.user_middleware if m.cls is SessionMiddleware)
+ assert sm.kwargs.get("https_only") is False
+
+ secure = FastAPI()
+ backend.setup_session(secure, "Test!", secure_session=True)
+ sm = next(m for m in secure.user_middleware if m.cls is SessionMiddleware)
+ assert sm.kwargs.get("https_only") is True
+
+
+def test_setup_session_noop_without_secret_key():
+ from fastapi import FastAPI
+
+ backend = FastAPIBackend()
+ app = FastAPI()
+ backend.setup_session(app, None)
+ assert backend.session_configured(app) is False
+
+
+def test_config_store_read_roundtrip_via_state():
+ from fastapi import FastAPI
+
+ backend = FastAPIBackend()
+ app = FastAPI() # FastAPI has no .config, only .state
+
+ assert backend.read_config(app, "PUBLIC_ROUTES", "fallback") == "fallback"
+ backend.store_config(app, "PUBLIC_ROUTES", ["/home"])
+ assert backend.read_config(app, "PUBLIC_ROUTES") == ["/home"]
diff --git a/tests/unit/test_group_protection_fastapi.py b/tests/unit/test_group_protection_fastapi.py
new file mode 100644
index 0000000..b376716
--- /dev/null
+++ b/tests/unit/test_group_protection_fastapi.py
@@ -0,0 +1,78 @@
+import asyncio
+import inspect
+
+import pytest
+
+pytest.importorskip("fastapi", reason="FastAPI extra dependencies are not installed")
+
+from starlette.requests import Request
+
+from dash_auth_async import check_groups, list_groups, protected
+from dash_auth_async.backends import (
+ FastAPIBackend,
+ _current_request_var,
+ set_active_backend,
+)
+
+
+def _request_with_session(session):
+ scope = {
+ "type": "http",
+ "method": "GET",
+ "path": "/",
+ "headers": [],
+ "query_string": b"",
+ }
+ if session is not None:
+ scope["session"] = session
+ return Request(scope)
+
+
+def test_gp_list_groups_fastapi():
+ set_active_backend(FastAPIBackend())
+ req = _request_with_session(
+ {"user": {"email": "a.b@mail.com", "groups": ["default"]}}
+ )
+ token = _current_request_var.set(req)
+ try:
+ assert list_groups() == ["default"]
+ assert check_groups(["default"]) is True
+ assert check_groups(["other"]) is False
+ finally:
+ _current_request_var.reset(token)
+
+
+def test_gp_no_session_returns_none():
+ set_active_backend(FastAPIBackend())
+ req = _request_with_session(None) # no SessionMiddleware -> session raises
+ token = _current_request_var.set(req)
+ try:
+ assert list_groups() is None
+ assert check_groups(["default"]) is None
+ finally:
+ _current_request_var.reset(token)
+
+
+def test_gp_async_protected_unauthenticated_without_session_user():
+ """An async ``protected`` wrapper with no logged-in user (empty session)
+ short-circuits to ``unauthenticated_output`` over the FastAPI backend, while
+ staying a coroutine so Dash keeps it on the async dispatch path.
+ """
+ set_active_backend(FastAPIBackend())
+ req = _request_with_session({}) # session available, but no "user"
+ token = _current_request_var.set(req)
+ try:
+
+ async def func():
+ return "success"
+
+ wrapped = protected(
+ unauthenticated_output="unauthenticated",
+ missing_permissions_output="forbidden",
+ groups=["admin"],
+ )(func)
+
+ assert inspect.iscoroutinefunction(wrapped)
+ assert asyncio.run(wrapped()) == "unauthenticated"
+ finally:
+ _current_request_var.reset(token)
diff --git a/tests/unit/test_group_protection_quart.py b/tests/unit/test_group_protection_quart.py
index f910346..a83039a 100644
--- a/tests/unit/test_group_protection_quart.py
+++ b/tests/unit/test_group_protection_quart.py
@@ -1,15 +1,15 @@
import asyncio
+import inspect
import pytest
-from dash_auth_async import check_groups, list_groups
+from dash_auth_async import check_groups, list_groups, protected
pytest.importorskip("quart", reason="Quart extra dependencies are not installed")
def test_gp004_list_groups_quart():
- from quart import Quart
- from quart import session as quart_session
+ from quart import Quart, session as quart_session
from dash_auth_async.backends import QuartBackend, set_active_backend
@@ -27,3 +27,34 @@ async def run():
assert check_groups(["other"]) is False
asyncio.run(run())
+
+
+def test_gp_async_protected_unauthenticated_without_session_user():
+ """An async ``protected`` wrapper with no logged-in user short-circuits to
+ ``unauthenticated_output`` over the Quart backend, while staying a coroutine
+ so Dash keeps it on the async dispatch path.
+ """
+ from quart import Quart
+
+ from dash_auth_async.backends import QuartBackend, set_active_backend
+
+ app = Quart(__name__)
+ app.secret_key = "Test!"
+
+ async def func():
+ return "success"
+
+ async def run():
+ async with app.test_request_context("/", method="GET"): # ty: ignore[invalid-context-manager]
+ # No session["user"] -> unauthenticated.
+ set_active_backend(QuartBackend())
+ wrapped = protected(
+ unauthenticated_output="unauthenticated",
+ missing_permissions_output="forbidden",
+ groups=["admin"],
+ )(func)
+
+ assert inspect.iscoroutinefunction(wrapped)
+ assert await wrapped() == "unauthenticated"
+
+ asyncio.run(run())
diff --git a/tests/unit/test_protected_callback_async.py b/tests/unit/test_protected_callback_async.py
index 5f9d0bb..f624490 100644
--- a/tests/unit/test_protected_callback_async.py
+++ b/tests/unit/test_protected_callback_async.py
@@ -80,6 +80,32 @@ async def func():
assert asyncio.run(wrapped()) == "forbidden"
+def test_gp008b_async_protected_emits_unauthenticated_output_without_session_user():
+ """An async target with no session user gets the unauthenticated output.
+
+ The third gate branch (alongside authorized and missing-permissions): when
+ ``_current_user`` resolves to ``None`` the wrapper must short-circuit to
+ ``unauthenticated_output`` and never await the inner coroutine -- while still
+ staying a coroutine function so Dash keeps it on the async dispatch path.
+ """
+
+ async def func():
+ return "success"
+
+ app = Flask(__name__)
+ app.secret_key = "Test!"
+ with app.test_request_context("/", method="GET"):
+ # No session["user"] -> unauthenticated.
+ wrapped = protected(
+ unauthenticated_output="unauthenticated",
+ missing_permissions_output="forbidden",
+ groups=["admin"],
+ )(func)
+
+ assert inspect.iscoroutinefunction(wrapped)
+ assert asyncio.run(wrapped()) == "unauthenticated"
+
+
# --------------------------------------------------------------------------- #
# Level 3: protected_callback registers on the correct Dash dispatch path
# --------------------------------------------------------------------------- #
diff --git a/tests/unit/test_public_routes.py b/tests/unit/test_public_routes.py
new file mode 100644
index 0000000..0f14256
--- /dev/null
+++ b/tests/unit/test_public_routes.py
@@ -0,0 +1,45 @@
+"""Unit tests for the public-route/callback registration helpers."""
+
+import pytest
+from dash import Dash
+
+from dash_auth_async import backends
+from dash_auth_async.public_routes import (
+ add_public_routes,
+ get_public_callbacks,
+ get_public_routes,
+)
+
+
+def test_public_helpers_resolve_backend_from_app_server_not_global_fallback():
+ """The public-route helpers must resolve the backend from ``app.server``,
+ not the process-global ``get_active_backend()`` (which falls back to
+ ``FlaskBackend``). On a FastAPI app that fallback's ``store_config`` would
+ write to the nonexistent ``server.config``; ``detect_backend(app.server)``
+ routes through ``server.state`` instead.
+
+ Regression guard for the FastAPI public-route registration path: with no
+ ``Auth`` having set the active backend, the global is the Flask fallback,
+ so a helper that trusted it would ``AttributeError`` here.
+ """
+ pytest.importorskip(
+ "fastapi", reason="FastAPI extra dependencies are not installed"
+ )
+
+ # Leave the active backend unset -> get_active_backend() is the Flask
+ # fallback, the exact condition that exposed the bug.
+ backends._active_backend = None
+
+ app = Dash(__name__, backend="fastapi")
+
+ # Would raise AttributeError (no server.config) under the Flask fallback.
+ add_public_routes(app, ["/login"])
+
+ # Stored on (and read back from) the FastAPI server.state, not server.config.
+ assert get_public_routes(app) is app.server.state.PUBLIC_ROUTES
+ assert any(
+ rule.rule == "/login" for rule in get_public_routes(app).map.iter_rules()
+ )
+
+ # The callbacks reader returns its default through the same backend path.
+ assert get_public_callbacks(app) == []
diff --git a/tests/unit/test_websocket_auth.py b/tests/unit/test_websocket_auth.py
index 6817c2e..213fb59 100644
--- a/tests/unit/test_websocket_auth.py
+++ b/tests/unit/test_websocket_auth.py
@@ -1,5 +1,6 @@
"""Unit tests for the WebSocket auth primitives."""
+import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor
@@ -130,28 +131,130 @@ class _Server:
"""Weak-referenceable stand-in for an app.server (the registry key)."""
+class _DashAppStub:
+ """Minimal Dash-app double exposing the idempotent ``_setup_server`` the hook
+ calls to migrate ``callback_map`` lazily on the first WS ``callback_request``.
+ """
+
+ def __init__(self) -> None:
+ self.setup_calls = 0
+
+ def _setup_server(self) -> None:
+ self.setup_calls += 1
+
+
class _RecordingAuth:
"""Auth double that records the calls the hook routes to it."""
def __init__(self) -> None:
self.calls: list = []
+ self.app = _DashAppStub()
def authorize_ws(self, payload, user) -> bool:
self.calls.append((payload, user))
return True
+# --------------------------------------------------------------------------- #
+# Connection registry: the websocket_connect hook tracks sockets so a later
+# login can retire the browser's stale anonymous one. It must not accumulate --
+# authenticated sockets are never retired (so never tracked), and a browser's
+# reconnects (one SharedWorker socket per browser) must not pile up.
+# --------------------------------------------------------------------------- #
+class _IdentityBackend:
+ """Backend double whose ``ws_identity`` returns a fixed (server, user)."""
+
+ def __init__(self, user) -> None:
+ self._user = user
+
+ def ws_identity(self, _ws):
+ return object(), self._user
+
+
+class _FakeWS:
+ """Minimal socket double exposing the cookies and an awaitable close."""
+
+ def __init__(self, client_id) -> None:
+ self.cookies = {"dac_client": client_id} if client_id else {}
+ self.close_code = None
+
+ async def close(self, code=None) -> None:
+ self.close_code = code
+
+
+def _tracked_count(client_id) -> int:
+ """Sockets tracked for ``client_id``, agnostic to the registry's shape."""
+ from dash_auth_async.websocket_auth import _WS_BY_CLIENT
+
+ entry = _WS_BY_CLIENT.get(client_id)
+ if entry is None:
+ return 0
+ return len(entry) if isinstance(entry, (set, list, dict)) else 1
+
+
+def _run_connect_hook(ws) -> None:
+ from dash_auth_async.websocket_auth import _ws_connect_hook
+
+ async def _run():
+ _ws_connect_hook(ws)
+
+ asyncio.run(_run())
+
+
+def _use_backend(monkeypatch, user) -> None:
+ """Point the connect hook at a backend double whose identity returns ``user``."""
+ monkeypatch.setattr(
+ "dash_auth_async.websocket_auth.get_active_backend",
+ lambda: _IdentityBackend(user),
+ )
+
+
+def test_authenticated_handshake_is_not_tracked(monkeypatch):
+ """A socket that handshakes already authenticated is never retired, so the
+ registry must not hold it (else authenticated sockets leak unboundedly).
+ """
+ from dash_auth_async.websocket_auth import _WS_BY_CLIENT
+
+ _WS_BY_CLIENT.clear()
+ _use_backend(monkeypatch, {"email": "a@b.c", "groups": []})
+
+ _run_connect_hook(_FakeWS("client-1"))
+
+ assert _tracked_count("client-1") == 0
+
+
+def test_reconnect_does_not_accumulate_entries(monkeypatch):
+ """A browser has one SharedWorker socket; a reconnect replaces the prior
+ anonymous entry rather than stacking, so tracking stays at one per browser.
+ """
+ from dash_auth_async.websocket_auth import _WS_BY_CLIENT
+
+ _WS_BY_CLIENT.clear()
+ _use_backend(monkeypatch, None) # anonymous handshakes
+
+ _run_connect_hook(_FakeWS("client-2"))
+ _run_connect_hook(_FakeWS("client-2"))
+
+ assert _tracked_count("client-2") == 1
+
+
def test_ws_hook_resolves_auth_for_the_current_app(monkeypatch):
"""With two dash-auth-async apps in the process, the hook consults only the
Auth registered for ``quart.current_app`` -- not some other app's Auth.
"""
quart = pytest.importorskip("quart")
+ from dash_auth_async.backends import QuartBackend, set_active_backend
from dash_auth_async.websocket_auth import (
_AUTH_BY_SERVER,
_WS_AUTH_USER,
_ws_message_hook,
)
+ # The hook resolves identity via the active backend; this is the Quart path
+ # (ws_identity reads the quart.current_app/quart.session monkeypatched below).
+ # The autouse reset_active_backend fixture clears it after the test.
+ set_active_backend(QuartBackend())
+
server_a, server_b = _Server(), _Server()
auth_a, auth_b = _RecordingAuth(), _RecordingAuth()
_AUTH_BY_SERVER[server_a] = auth_a
@@ -174,6 +277,10 @@ def _get_current_object(self):
# Only app B's Auth was consulted, with app B's session user.
assert auth_b.calls == [({"output": "x.children"}, user)]
assert auth_a.calls == []
+ # The hook migrated only app B's callback_map (lazy, idempotent), and
+ # never touched app A's.
+ assert auth_b.app.setup_calls == 1
+ assert auth_a.app.setup_calls == 0
# The resolved user was stashed for the worker.
assert _WS_AUTH_USER.get() == user
finally:
diff --git a/uv.lock b/uv.lock
index 9e56cdf..d90e9a7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -596,7 +596,7 @@ testing = [
[[package]]
name = "dash-auth-async"
-version = "1.2.1"
+version = "1.3.0"
source = { editable = "." }
dependencies = [
{ name = "authlib" },
@@ -622,13 +622,19 @@ dev = [
{ name = "asgiref" },
{ name = "creosote" },
{ name = "dash", extra = ["testing"] },
+ { name = "fastapi" },
+ { name = "httpx" },
+ { name = "itsdangerous" },
{ name = "pre-commit" },
{ name = "pytest" },
{ name = "pytest-cov" },
+ { name = "python-multipart" },
{ name = "requests" },
{ name = "ruff" },
{ name = "setuptools" },
+ { name = "starlette" },
{ name = "ty" },
+ { name = "uvicorn" },
{ name = "websocket-client" },
]
@@ -650,13 +656,19 @@ dev = [
{ name = "asgiref", specifier = ">=3.11.1" },
{ name = "creosote", specifier = ">=5.2.0" },
{ name = "dash", extras = ["testing"], specifier = ">=4.2.0" },
+ { name = "fastapi" },
+ { name = "httpx", specifier = ">=0.23.0" },
+ { name = "itsdangerous" },
{ name = "pre-commit", specifier = ">=3.5.0" },
{ name = "pytest", specifier = ">=8.3.5" },
{ name = "pytest-cov", specifier = ">=7.1.0" },
+ { name = "python-multipart" },
{ name = "requests", extras = ["security"], specifier = ">=2.34.2" },
{ name = "ruff", specifier = ">=0.15.16" },
{ name = "setuptools", specifier = ">=79.0.1" },
+ { name = "starlette" },
{ name = "ty", specifier = ">=0.0.46" },
+ { name = "uvicorn" },
{ name = "websocket-client", specifier = ">=1.9.0" },
]
@@ -1767,6 +1779,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" },
]
+[[package]]
+name = "python-multipart"
+version = "0.0.32"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5b/42/55c32bb9b12693c092ad250a0e82edb5b31ddeda6eb772de5f308b3804ad/python_multipart-0.0.32.tar.gz", hash = "sha256:be54b7f3fa167bb83e4fcd936b887b708f4e57fe75911c02aebf53efaf8d938e", size = 46881, upload-time = "2026-06-04T16:18:58.647Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e1/04/e8135ebd1ad02c56ec633277529b2602ff99ff634be76cdba5744cf554fd/python_multipart-0.0.32-py3-none-any.whl", hash = "sha256:ff6d3f776f16878c894e52e107296ffc890e913c611b1a4ec6c44e2821fe2e23", size = 30042, upload-time = "2026-06-04T16:18:57.319Z" },
+]
+
[[package]]
name = "pyyaml"
version = "6.0.3"