diff --git a/src/auth/__init__.py b/src/auth/__init__.py index 00e8c39..c395994 100644 --- a/src/auth/__init__.py +++ b/src/auth/__init__.py @@ -87,6 +87,7 @@ async def current_superuser( UserRead, UserUpdate, ) +from .service import AuthenticationService, get_auth_service # noqa: E402 __all__ = [ "fastapi_users", @@ -106,4 +107,6 @@ async def current_superuser( "UserUpdate", "TokenResponse", "MessageResponse", + "AuthenticationService", + "get_auth_service", ] diff --git a/src/auth/service.py b/src/auth/service.py new file mode 100644 index 0000000..777ec19 --- /dev/null +++ b/src/auth/service.py @@ -0,0 +1,140 @@ +from fastapi import Depends, Request +from fastapi.security import OAuth2PasswordRequestForm +from fastapi_users.authentication import JWTStrategy + +from src.audit import AuditService, get_audit_service +from src.audit.schemas import AuditAction, AuditResult +from src.http.utils import extract_client_info +from src.auth.backend import ( + RefreshTokenManager, + get_jwt_strategy, + get_refresh_token_manager, +) +from src.auth.manager import UserManager, get_user_manager +from src.auth.schemas import TokenResponse +from src.shared.errors import ErrorCode +from src.exceptions import BusinessException + + +class AuthenticationService: + def __init__( + self, + user_manager: UserManager, + strategy: JWTStrategy, + refresh_manager: RefreshTokenManager, + audit_service: AuditService, + ): + self.user_manager = user_manager + self.strategy = strategy + self.refresh_manager = refresh_manager + self.audit_service = audit_service + + async def login( + self, credentials: OAuth2PasswordRequestForm, request: Request + ) -> TokenResponse: + user_agent, ip = extract_client_info(request) + + user = await self.user_manager.authenticate(credentials) + if not user or not user.is_active: + await self.audit_service.log( + action=AuditAction.LOGIN, + result=AuditResult.FAILURE, + user_agent=user_agent, + ip=ip, + extra={"username": credentials.username}, + ) + raise BusinessException( + ErrorCode.AUTH_INVALID_CREDENTIALS, "Invalid credentials" + ) + + access_token = await self.strategy.write_token(user) + refresh_token = await self.refresh_manager.create_refresh_token( + user.id, user_agent + ) + + await self.audit_service.log( + action=AuditAction.LOGIN, + result=AuditResult.SUCCESS, + actor_id=user.id, + user_agent=user_agent, + ip=ip, + ) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + token_type="Bearer", + ) + + async def refresh(self, refresh_token: str, request: Request) -> TokenResponse: + user_agent, ip = extract_client_info(request) + + user_id = await self.refresh_manager.verify_refresh_token(refresh_token) + + if not user_id: + await self.audit_service.log( + action=AuditAction.REFRESH, + result=AuditResult.FAILURE, + user_agent=user_agent, + ip=ip, + ) + raise BusinessException(ErrorCode.AUTH_TOKEN_INVALID, "Invalid token") + + user = await self.user_manager.get(user_id) + if not user or not user.is_active: + await self.audit_service.log( + action=AuditAction.REFRESH, + result=AuditResult.FAILURE, + actor_id=user_id, + user_agent=user_agent, + ip=ip, + ) + raise BusinessException(ErrorCode.USER_INACTIVE, "User inactive") + + access_token = await self.strategy.write_token(user) + new_refresh_token = await self.refresh_manager.create_refresh_token( + user.id, user_agent + ) + await self.refresh_manager.revoke_token(refresh_token) + + await self.audit_service.log( + action=AuditAction.REFRESH, + result=AuditResult.SUCCESS, + actor_id=user.id, + user_agent=user_agent, + ip=ip, + ) + + return TokenResponse(access_token=access_token, refresh_token=new_refresh_token) + + async def logout(self, refresh_token: str, request: Request) -> None: + user_agent, ip = extract_client_info(request) + user_id = await self.refresh_manager.verify_refresh_token(refresh_token) + + if not user_id: + await self.audit_service.log( + action=AuditAction.LOGOUT, + result=AuditResult.FAILURE, + user_agent=user_agent, + ip=ip, + ) + raise BusinessException(ErrorCode.AUTH_TOKEN_INVALID, "Invalid token") + + await self.refresh_manager.revoke_token(refresh_token) + + await self.audit_service.log( + action=AuditAction.LOGOUT, + result=AuditResult.SUCCESS, + actor_id=user_id, + user_agent=user_agent, + ip=ip, + ) + + +async def get_auth_service( + user_manager: UserManager = Depends(get_user_manager), + strategy: JWTStrategy = Depends(get_jwt_strategy), + refresh_manager: RefreshTokenManager = Depends(get_refresh_token_manager), + audit_service: AuditService = Depends(get_audit_service), +) -> AuthenticationService: + return AuthenticationService(user_manager, strategy, refresh_manager, audit_service) diff --git a/src/http/routers/auth.py b/src/http/routers/auth.py index e38291e..12e5841 100644 --- a/src/http/routers/auth.py +++ b/src/http/routers/auth.py @@ -1,16 +1,8 @@ from fastapi import APIRouter, Depends, Request from fastapi.security import OAuth2PasswordRequestForm -from src.audit import AuditService, get_audit_service -from src.audit.schemas import AuditAction, AuditResult -from src.http.utils import extract_client_info from src.auth import fastapi_users -from src.auth.backend import ( - RefreshTokenManager, - get_jwt_strategy, - get_refresh_token_manager, -) -from src.auth.manager import UserManager, get_user_manager +from src.auth.service import AuthenticationService, get_auth_service from src.auth.schemas import ( MessageResponse, TokenResponse, @@ -18,8 +10,6 @@ UserRead, UserUpdate, ) -from src.shared.errors import ErrorCode -from src.exceptions import BusinessException router = APIRouter() @@ -52,120 +42,25 @@ async def login( request: Request, credentials: OAuth2PasswordRequestForm = Depends(), - user_manager: UserManager = Depends(get_user_manager), - strategy=Depends(get_jwt_strategy), - refresh_manager: RefreshTokenManager = Depends(get_refresh_token_manager), - audit_service: AuditService = Depends(get_audit_service), + auth_service: AuthenticationService = Depends(get_auth_service), ) -> TokenResponse: - user_agent, ip = extract_client_info(request) - - user = await user_manager.authenticate(credentials) - if not user or not user.is_active: - await audit_service.log( - action=AuditAction.LOGIN, - result=AuditResult.FAILURE, - user_agent=user_agent, - ip=ip, - extra={"username": credentials.username}, - ) - raise BusinessException( - ErrorCode.AUTH_INVALID_CREDENTIALS, "Invalid credentials" - ) - - access_token = await strategy.write_token(user) - - refresh_token = await refresh_manager.create_refresh_token(user.id, user_agent) - - await audit_service.log( - action=AuditAction.LOGIN, - result=AuditResult.SUCCESS, - actor_id=user.id, - user_agent=user_agent, - ip=ip, - ) - - return TokenResponse( - access_token=access_token, - refresh_token=refresh_token, - token_type="Bearer", - ) + return await auth_service.login(credentials, request) @router.post("/jwt/refresh") async def refresh_jwt( request: Request, refresh_token: str, - user_manager: UserManager = Depends(get_user_manager), - strategy=Depends(get_jwt_strategy), - refresh_manager: RefreshTokenManager = Depends(get_refresh_token_manager), - audit_service: AuditService = Depends(get_audit_service), + auth_service: AuthenticationService = Depends(get_auth_service), ) -> TokenResponse: - user_agent, ip = extract_client_info(request) - - user_id = await refresh_manager.verify_refresh_token(refresh_token) - - if not user_id: - await audit_service.log( - action=AuditAction.REFRESH, - result=AuditResult.FAILURE, - user_agent=user_agent, - ip=ip, - ) - raise BusinessException(ErrorCode.AUTH_TOKEN_INVALID, "Invalid token") - - user = await user_manager.get(user_id) - if not user or not user.is_active: - await audit_service.log( - action=AuditAction.REFRESH, - result=AuditResult.FAILURE, - actor_id=user_id, - user_agent=user_agent, - ip=ip, - ) - raise BusinessException(ErrorCode.USER_INACTIVE, "User inactive") - - access_token = await strategy.write_token(user) - new_refresh_token = await refresh_manager.create_refresh_token(user.id, user_agent) - await refresh_manager.revoke_token(refresh_token) - - await audit_service.log( - action=AuditAction.REFRESH, - result=AuditResult.SUCCESS, - actor_id=user.id, - user_agent=user_agent, - ip=ip, - ) - - return TokenResponse(access_token=access_token, refresh_token=new_refresh_token) + return await auth_service.refresh(refresh_token, request) @router.post("/jwt/logout") async def logout( request: Request, refresh_token: str, - refresh_manager: RefreshTokenManager = Depends(get_refresh_token_manager), - audit_service: AuditService = Depends(get_audit_service), + auth_service: AuthenticationService = Depends(get_auth_service), ) -> MessageResponse: - user_agent, ip = extract_client_info(request) - user_id = await refresh_manager.verify_refresh_token(refresh_token) - - if not user_id: - await audit_service.log( - action=AuditAction.LOGOUT, - result=AuditResult.FAILURE, - user_agent=user_agent, - ip=ip, - ) - raise BusinessException(ErrorCode.AUTH_TOKEN_INVALID, "Invalid token") - - await refresh_manager.revoke_token(refresh_token) - - await audit_service.log( - action=AuditAction.LOGOUT, - result=AuditResult.SUCCESS, - actor_id=user_id, - user_agent=user_agent, - ip=ip, - ) - + await auth_service.logout(refresh_token, request) return MessageResponse(detail="Successfully logged out") diff --git a/tests/auth/test_service.py b/tests/auth/test_service.py new file mode 100644 index 0000000..2a086a2 --- /dev/null +++ b/tests/auth/test_service.py @@ -0,0 +1,288 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from fastapi import Request +from fastapi.security import OAuth2PasswordRequestForm + +from src.auth.service import AuthenticationService +from src.auth.schemas import TokenResponse +from src.audit.schemas import AuditAction, AuditResult +from src.shared.errors import ErrorCode +from src.exceptions import BusinessException + + +@pytest.fixture +def mock_user_manager(): + manager = AsyncMock() + return manager + + +@pytest.fixture +def mock_jwt_strategy(): + strategy = AsyncMock() + return strategy + + +@pytest.fixture +def mock_refresh_manager(): + manager = AsyncMock() + return manager + + +@pytest.fixture +def mock_audit_service(): + service = AsyncMock() + return service + + +@pytest.fixture +def mock_request(): + request = MagicMock(spec=Request) + request.headers = {"user-agent": "test-agent"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + return request + + +@pytest.fixture +def auth_service( + mock_user_manager, mock_jwt_strategy, mock_refresh_manager, mock_audit_service +): + return AuthenticationService( + user_manager=mock_user_manager, + strategy=mock_jwt_strategy, + refresh_manager=mock_refresh_manager, + audit_service=mock_audit_service, + ) + + +@pytest.fixture +def mock_user(): + user = MagicMock() + user.id = 1 + user.is_active = True + return user + + +@pytest.fixture +def mock_credentials(): + credentials = MagicMock(spec=OAuth2PasswordRequestForm) + credentials.username = "testuser" + credentials.password = "testpass" + return credentials + + +@pytest.mark.asyncio +async def test_login_success( + auth_service, + mock_user_manager, + mock_jwt_strategy, + mock_refresh_manager, + mock_audit_service, + mock_request, + mock_credentials, + mock_user, +): + mock_user_manager.authenticate.return_value = mock_user + mock_jwt_strategy.write_token.return_value = "access_token_123" + mock_refresh_manager.create_refresh_token.return_value = "refresh_token_123" + + result = await auth_service.login(mock_credentials, mock_request) + + assert isinstance(result, TokenResponse) + assert result.access_token == "access_token_123" + assert result.refresh_token == "refresh_token_123" + assert result.token_type == "Bearer" + + mock_user_manager.authenticate.assert_called_once_with(mock_credentials) + mock_jwt_strategy.write_token.assert_called_once_with(mock_user) + mock_refresh_manager.create_refresh_token.assert_called_once_with(1, "test-agent") + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.LOGIN + assert call_args["result"] == AuditResult.SUCCESS + assert call_args["actor_id"] == 1 + + +@pytest.mark.asyncio +async def test_login_invalid_credentials( + auth_service, + mock_user_manager, + mock_audit_service, + mock_request, + mock_credentials, +): + mock_user_manager.authenticate.return_value = None + + with pytest.raises(BusinessException) as exc_info: + await auth_service.login(mock_credentials, mock_request) + + assert exc_info.value.code == ErrorCode.AUTH_INVALID_CREDENTIALS + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.LOGIN + assert call_args["result"] == AuditResult.FAILURE + assert call_args["extra"]["username"] == "testuser" + + +@pytest.mark.asyncio +async def test_login_inactive_user( + auth_service, + mock_user_manager, + mock_audit_service, + mock_request, + mock_credentials, + mock_user, +): + mock_user.is_active = False + mock_user_manager.authenticate.return_value = mock_user + + with pytest.raises(BusinessException) as exc_info: + await auth_service.login(mock_credentials, mock_request) + + assert exc_info.value.code == ErrorCode.AUTH_INVALID_CREDENTIALS + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.LOGIN + assert call_args["result"] == AuditResult.FAILURE + + +@pytest.mark.asyncio +async def test_refresh_success( + auth_service, + mock_user_manager, + mock_jwt_strategy, + mock_refresh_manager, + mock_audit_service, + mock_request, + mock_user, +): + mock_refresh_manager.verify_refresh_token.return_value = 1 + mock_user_manager.get.return_value = mock_user + mock_jwt_strategy.write_token.return_value = "new_access_token" + mock_refresh_manager.create_refresh_token.return_value = "new_refresh_token" + + result = await auth_service.refresh("old_refresh_token", mock_request) + + assert isinstance(result, TokenResponse) + assert result.access_token == "new_access_token" + assert result.refresh_token == "new_refresh_token" + + mock_refresh_manager.verify_refresh_token.assert_called_once_with( + "old_refresh_token" + ) + mock_user_manager.get.assert_called_once_with(1) + mock_jwt_strategy.write_token.assert_called_once_with(mock_user) + mock_refresh_manager.revoke_token.assert_called_once_with("old_refresh_token") + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.REFRESH + assert call_args["result"] == AuditResult.SUCCESS + + +@pytest.mark.asyncio +async def test_refresh_invalid_token( + auth_service, + mock_refresh_manager, + mock_audit_service, + mock_request, +): + mock_refresh_manager.verify_refresh_token.return_value = None + + with pytest.raises(BusinessException) as exc_info: + await auth_service.refresh("invalid_token", mock_request) + + assert exc_info.value.code == ErrorCode.AUTH_TOKEN_INVALID + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.REFRESH + assert call_args["result"] == AuditResult.FAILURE + + +@pytest.mark.asyncio +async def test_refresh_inactive_user( + auth_service, + mock_user_manager, + mock_refresh_manager, + mock_audit_service, + mock_request, + mock_user, +): + mock_refresh_manager.verify_refresh_token.return_value = 1 + mock_user.is_active = False + mock_user_manager.get.return_value = mock_user + + with pytest.raises(BusinessException) as exc_info: + await auth_service.refresh("valid_token", mock_request) + + assert exc_info.value.code == ErrorCode.USER_INACTIVE + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.REFRESH + assert call_args["result"] == AuditResult.FAILURE + assert call_args["actor_id"] == 1 + + +@pytest.mark.asyncio +async def test_refresh_user_not_found( + auth_service, + mock_user_manager, + mock_refresh_manager, + mock_audit_service, + mock_request, +): + mock_refresh_manager.verify_refresh_token.return_value = 1 + mock_user_manager.get.return_value = None + + with pytest.raises(BusinessException) as exc_info: + await auth_service.refresh("valid_token", mock_request) + + assert exc_info.value.code == ErrorCode.USER_INACTIVE + + +@pytest.mark.asyncio +async def test_logout_success( + auth_service, + mock_refresh_manager, + mock_audit_service, + mock_request, +): + mock_refresh_manager.verify_refresh_token.return_value = 1 + + result = await auth_service.logout("refresh_token", mock_request) + + assert result is None + + mock_refresh_manager.verify_refresh_token.assert_called_once_with("refresh_token") + mock_refresh_manager.revoke_token.assert_called_once_with("refresh_token") + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.LOGOUT + assert call_args["result"] == AuditResult.SUCCESS + assert call_args["actor_id"] == 1 + + +@pytest.mark.asyncio +async def test_logout_invalid_token( + auth_service, + mock_refresh_manager, + mock_audit_service, + mock_request, +): + mock_refresh_manager.verify_refresh_token.return_value = None + + with pytest.raises(BusinessException) as exc_info: + await auth_service.logout("invalid_token", mock_request) + + assert exc_info.value.code == ErrorCode.AUTH_TOKEN_INVALID + + assert mock_audit_service.log.call_count == 1 + call_args = mock_audit_service.log.call_args.kwargs + assert call_args["action"] == AuditAction.LOGOUT + assert call_args["result"] == AuditResult.FAILURE