diff --git a/src/api/auth.py b/src/api/auth.py index f8b2b53..d2247cc 100644 --- a/src/api/auth.py +++ b/src/api/auth.py @@ -12,7 +12,6 @@ ) from src.auth.manager import UserManager, get_user_manager from src.auth.schemas import ( - AccessTokenResponse, MessageResponse, TokenResponse, UserCreate, @@ -100,7 +99,7 @@ async def refresh_jwt( strategy=Depends(get_jwt_strategy), refresh_manager: RefreshTokenManager = Depends(get_refresh_token_manager), audit_service: AuditService = Depends(get_audit_service), -) -> AccessTokenResponse: +) -> TokenResponse: user_agent, ip = extract_client_info(request) user_id = await refresh_manager.verify_refresh_token(refresh_token) @@ -126,6 +125,8 @@ async def refresh_jwt( raise BusinessException(ErrorCode.USER_INACTIVE, "User inactive") access_token = await strategy.write_token(user) + new_refresh_token = await refresh_manager.create_refresh_token(user.id, user_agent) + await refresh_manager.revoke_token(refresh_token) await audit_service.log( action=AuditAction.REFRESH, @@ -135,7 +136,7 @@ async def refresh_jwt( ip=ip, ) - return AccessTokenResponse(access_token=access_token, token_type="Bearer") + return TokenResponse(access_token=access_token, refresh_token=new_refresh_token) @router.post("/jwt/logout") diff --git a/src/auth/__init__.py b/src/auth/__init__.py index 8fc62a6..20f2201 100644 --- a/src/auth/__init__.py +++ b/src/auth/__init__.py @@ -81,7 +81,6 @@ async def current_superuser( require_roles, ) from .schemas import ( # noqa: E402 - AccessTokenResponse, MessageResponse, TokenResponse, UserCreate, @@ -106,6 +105,5 @@ async def current_superuser( "UserCreate", "UserUpdate", "TokenResponse", - "AccessTokenResponse", "MessageResponse", ] diff --git a/src/auth/schemas.py b/src/auth/schemas.py index 26b1e3e..3ab9211 100644 --- a/src/auth/schemas.py +++ b/src/auth/schemas.py @@ -22,10 +22,5 @@ class TokenResponse(BaseModel): token_type: Literal["Bearer"] = "Bearer" -class AccessTokenResponse(BaseModel): - access_token: str - token_type: Literal["Bearer"] = "Bearer" - - class MessageResponse(BaseModel): detail: str diff --git a/tests/integration/test_auth_router.py b/tests/integration/test_auth_router.py index 8a3032b..89fe0e7 100644 --- a/tests/integration/test_auth_router.py +++ b/tests/integration/test_auth_router.py @@ -57,7 +57,30 @@ async def test_refresh_token_success(test_client, test_user): data = response.json() assert "access_token" in data assert data["token_type"] == "Bearer" - assert "refresh_token" not in data + assert "refresh_token" in data + + +async def test_refresh_token_rotation(test_client, test_user): + login_response = await test_client.post( + "/auth/jwt/login", + data={"username": "test@example.com", "password": "testpassword123"}, + ) + old_refresh_token = login_response.json()["refresh_token"] + + refresh_response = await test_client.post( + "/auth/jwt/refresh", params={"refresh_token": old_refresh_token} + ) + new_refresh_token = refresh_response.json()["refresh_token"] + + old_token_response = await test_client.post( + "/auth/jwt/refresh", params={"refresh_token": old_refresh_token} + ) + assert old_token_response.status_code == 401 + + new_token_response = await test_client.post( + "/auth/jwt/refresh", params={"refresh_token": new_refresh_token} + ) + assert new_token_response.status_code == 200 async def test_refresh_token_invalid(test_client):