diff --git a/src/auth/__init__.py b/src/auth/__init__.py index c395994..4d16f31 100644 --- a/src/auth/__init__.py +++ b/src/auth/__init__.py @@ -1,93 +1,22 @@ -from fastapi import Depends, Request -from fastapi_users import FastAPIUsers - -from src.audit import AuditService, get_audit_service -from src.audit.dependencies import generate_request_id -from src.audit.schemas import AuditAction, AuditResult -from src.http.utils import extract_client_info -from src.auth.models import User - -from .backend import auth_backend -from .manager import get_user_manager - -fastapi_users = FastAPIUsers[User, int]( - get_user_manager=get_user_manager, - auth_backends=[auth_backend], -) - -_current_user_base = fastapi_users.current_user(active=True) -_current_superuser_base = fastapi_users.current_user(active=True, superuser=True) - - -async def current_user( - request: Request, - user: User = Depends(_current_user_base), - audit_service: AuditService = Depends(get_audit_service), - request_id: str = Depends(generate_request_id), -) -> User: - user_agent, ip = extract_client_info(request) - - await audit_service.log( - action=AuditAction.PROTECTED_RESOURCE_ACCESS, - result=AuditResult.SUCCESS, - actor_id=user.id, - request_id=request_id, - user_agent=user_agent, - ip=ip, - extra={ - "method": request.method, - "path": request.url.path, - }, - ) - - return user - - -async def current_superuser( - request: Request, - user: User = Depends(_current_superuser_base), - audit_service: AuditService = Depends(get_audit_service), - request_id: str = Depends(generate_request_id), -) -> User: - user_agent, ip = extract_client_info(request) - - await audit_service.log( - action=AuditAction.PROTECTED_RESOURCE_ACCESS, - result=AuditResult.SUCCESS, - actor_id=user.id, - request_id=request_id, - user_agent=user_agent, - ip=ip, - extra={ - "method": request.method, - "path": request.url.path, - "superuser": True, - }, - ) - - return user - - -from .models import ( # noqa: E402 - OAuthAccount, +from .dependencies import current_superuser, current_user, fastapi_users +from .models import OAuthAccount, User +from .rbac import ( Permission, Role, RolePermission, UserRole, -) -from .rbac import ( # noqa: E402 owner_or_perm, require_permissions, require_roles, ) -from .schemas import ( # noqa: E402 +from .schemas import ( MessageResponse, TokenResponse, UserCreate, UserRead, UserUpdate, ) -from .service import AuthenticationService, get_auth_service # noqa: E402 +from .service import AuthenticationService, get_auth_service __all__ = [ "fastapi_users", diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py new file mode 100644 index 0000000..81d8983 --- /dev/null +++ b/src/auth/dependencies.py @@ -0,0 +1,68 @@ +from fastapi import Depends, Request +from fastapi_users import FastAPIUsers + +from src.audit import AuditService, get_audit_service +from src.audit.dependencies import generate_request_id +from src.audit.schemas import AuditAction, AuditResult +from src.auth.models import User +from src.http.utils import extract_client_info + +from .backend import auth_backend +from .manager import get_user_manager + +fastapi_users = FastAPIUsers[User, int]( + get_user_manager=get_user_manager, + auth_backends=[auth_backend], +) + +_current_user_base = fastapi_users.current_user(active=True) +_current_superuser_base = fastapi_users.current_user(active=True, superuser=True) + + +async def current_user( + request: Request, + user: User = Depends(_current_user_base), + audit_service: AuditService = Depends(get_audit_service), + request_id: str = Depends(generate_request_id), +) -> User: + user_agent, ip = extract_client_info(request) + + await audit_service.log( + action=AuditAction.PROTECTED_RESOURCE_ACCESS, + result=AuditResult.SUCCESS, + actor_id=user.id, + request_id=request_id, + user_agent=user_agent, + ip=ip, + extra={ + "method": request.method, + "path": request.url.path, + }, + ) + + return user + + +async def current_superuser( + request: Request, + user: User = Depends(_current_superuser_base), + audit_service: AuditService = Depends(get_audit_service), + request_id: str = Depends(generate_request_id), +) -> User: + user_agent, ip = extract_client_info(request) + + await audit_service.log( + action=AuditAction.PROTECTED_RESOURCE_ACCESS, + result=AuditResult.SUCCESS, + actor_id=user.id, + request_id=request_id, + user_agent=user_agent, + ip=ip, + extra={ + "method": request.method, + "path": request.url.path, + "superuser": True, + }, + ) + + return user diff --git a/src/auth/models.py b/src/auth/models.py index 8b7c301..9af115d 100644 --- a/src/auth/models.py +++ b/src/auth/models.py @@ -1,5 +1,3 @@ -from datetime import UTC, datetime - from sqlmodel import Field, SQLModel from src.mixins import TimestampMixin @@ -30,42 +28,3 @@ class OAuthAccount(SQLModel, table=True): refresh_token: str | None = Field(default=None, max_length=1024, nullable=True) account_id: str = Field(max_length=320, nullable=False, index=True) account_email: str = Field(max_length=320, nullable=False) - - -class UserRole(SQLModel, table=True): - __tablename__ = "user_roles" - - user_id: int = Field(foreign_key="users.id", primary_key=True) - role_id: int = Field(foreign_key="roles.id", primary_key=True) - assigned_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - ) - - -class RolePermission(SQLModel, table=True): - __tablename__ = "role_permissions" - - role_id: int = Field(foreign_key="roles.id", primary_key=True) - permission_id: int = Field(foreign_key="permissions.id", primary_key=True) - assigned_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - ) - - -class Role(SQLModel, TimestampMixin, table=True): - __tablename__ = "roles" - - id: int | None = Field(default=None, primary_key=True) - name: str = Field(unique=True, index=True, max_length=100) - description: str | None = Field(default=None, max_length=255) - is_system: bool = Field(default=False) - - -class Permission(SQLModel, TimestampMixin, table=True): - __tablename__ = "permissions" - - id: int | None = Field(default=None, primary_key=True) - code: str = Field(unique=True, index=True, max_length=150) - name: str = Field(max_length=100) - description: str | None = Field(default=None, max_length=255) - module: str = Field(max_length=100) diff --git a/src/auth/rbac/__init__.py b/src/auth/rbac/__init__.py new file mode 100644 index 0000000..5664e8c --- /dev/null +++ b/src/auth/rbac/__init__.py @@ -0,0 +1,24 @@ +"""RBAC (Role-Based Access Control) submodule for authorization.""" + +from .dependencies import ( + get_permission_service, + owner_or_perm, + require_permissions, + require_roles, +) +from .models import Permission, Role, RolePermission, UserRole +from .repository import PermissionRepository +from .service import PermissionService + +__all__ = [ + "PermissionRepository", + "PermissionService", + "get_permission_service", + "require_permissions", + "require_roles", + "owner_or_perm", + "Role", + "Permission", + "UserRole", + "RolePermission", +] diff --git a/src/auth/rbac.py b/src/auth/rbac/dependencies.py similarity index 62% rename from src/auth/rbac.py rename to src/auth/rbac/dependencies.py index be781b2..abe09a4 100644 --- a/src/auth/rbac.py +++ b/src/auth/rbac/dependencies.py @@ -2,147 +2,21 @@ from typing import Literal from fastapi import Depends, Request -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from src.audit import AuditService, generate_request_id, get_audit_service from src.audit.schemas import AuditAction, AuditResult -from src.http.utils import extract_client_info -from src.auth.models import Permission, Role, RolePermission, User, UserRole +from src.auth.dependencies import current_user from src.auth.exceptions import ( InsufficientPermissionException, InsufficientRoleException, ) +from src.auth.models import User +from src.http.utils import extract_client_info from src.session import get_session -from . import current_user - - -class PermissionRepository: - def __init__(self, session: AsyncSession): - self.session = session - - async def get_user_permissions(self, user_id: int) -> set[str]: - stmt = ( - select(Permission.code) - .distinct() - .join(RolePermission, Permission.id == RolePermission.permission_id) - .join(UserRole, RolePermission.role_id == UserRole.role_id) - .where(UserRole.user_id == user_id) - ) - result = await self.session.execute(stmt) - return set(result.scalars().all()) - - async def get_user_roles(self, user_id: int) -> set[str]: - stmt = ( - select(Role.name) - .distinct() - .join(UserRole, Role.id == UserRole.role_id) - .where(UserRole.user_id == user_id) - ) - result = await self.session.execute(stmt) - return set(result.scalars().all()) - - -class PermissionService: - def __init__(self, repository: PermissionRepository): - self.repository = repository - - @staticmethod - def _split_permission(perm: str) -> tuple[str, str | None]: - """Split permission into (module, action).""" - if ":" in perm: - module, action = perm.split(":", 1) - # Treat empty action as absent (not a wildcard). - return module, action if action != "" else None - return perm, None - - def _match_permission( - self, - required_perm: str, - user_perm: str, - wildcard_support: bool, - ) -> bool: - # Wildcard disabled: exact match only - if not wildcard_support: - return required_perm == user_perm - - # "*": requires global permission, only user="*" satisfies - if required_perm == "*": - return user_perm == "*" - # user="*": global permission satisfies any requirement - if user_perm == "*": - return True - - req_module, req_action = self._split_permission(required_perm) - user_module, user_action = self._split_permission(user_perm) - - # Different modules: no match - if req_module != user_module: - return False - - # required_perm must be "module:action" format (not bare "module") - if req_action is None: - return False - - # user has "module" (full module access): matches any "module:action" - if user_action is None: - return True - - # "module:*": module wildcard, matches any action - if req_action == "*": - return True - - # Exact action match or user has "module:*" - return user_action == req_action or user_action == "*" - - def _has_permission( - self, - required_perm: str, - user_perms: set[str], - wildcard_support: bool, - ) -> bool: - return any( - self._match_permission(required_perm, user_perm, wildcard_support) - for user_perm in user_perms - ) - - async def check_permissions( - self, - user_id: int, - required_perms: Sequence[str], - match: Literal["all", "any"] = "all", - wildcard_support: bool = True, - ) -> bool: - if match not in ("all", "any"): - raise ValueError("match must be 'all' or 'any'") - - user_perms = await self.repository.get_user_permissions(user_id) - - if match == "all": - return all( - self._has_permission(req, user_perms, wildcard_support) - for req in required_perms - ) - return any( - self._has_permission(req, user_perms, wildcard_support) - for req in required_perms - ) - - async def check_roles( - self, - user_id: int, - required_roles: Sequence[str], - match: Literal["all", "any"] = "all", - ) -> bool: - if match not in ("all", "any"): - raise ValueError("match must be 'all' or 'any'") - - user_roles = await self.repository.get_user_roles(user_id) - - if match == "all": - return all(req in user_roles for req in required_roles) - return any(req in user_roles for req in required_roles) +from .repository import PermissionRepository +from .service import PermissionService async def get_permission_service( diff --git a/src/auth/rbac/models.py b/src/auth/rbac/models.py new file mode 100644 index 0000000..13bc85c --- /dev/null +++ b/src/auth/rbac/models.py @@ -0,0 +1,44 @@ +from datetime import UTC, datetime + +from sqlmodel import Field, SQLModel + +from src.mixins import TimestampMixin + + +class UserRole(SQLModel, table=True): + __tablename__ = "user_roles" + + user_id: int = Field(foreign_key="users.id", primary_key=True) + role_id: int = Field(foreign_key="roles.id", primary_key=True) + assigned_at: datetime = Field( + default_factory=lambda: datetime.now(UTC), + ) + + +class RolePermission(SQLModel, table=True): + __tablename__ = "role_permissions" + + role_id: int = Field(foreign_key="roles.id", primary_key=True) + permission_id: int = Field(foreign_key="permissions.id", primary_key=True) + assigned_at: datetime = Field( + default_factory=lambda: datetime.now(UTC), + ) + + +class Role(SQLModel, TimestampMixin, table=True): + __tablename__ = "roles" + + id: int | None = Field(default=None, primary_key=True) + name: str = Field(unique=True, index=True, max_length=100) + description: str | None = Field(default=None, max_length=255) + is_system: bool = Field(default=False) + + +class Permission(SQLModel, TimestampMixin, table=True): + __tablename__ = "permissions" + + id: int | None = Field(default=None, primary_key=True) + code: str = Field(unique=True, index=True, max_length=150) + name: str = Field(max_length=100) + description: str | None = Field(default=None, max_length=255) + module: str = Field(max_length=100) diff --git a/src/auth/rbac/repository.py b/src/auth/rbac/repository.py new file mode 100644 index 0000000..0f88acd --- /dev/null +++ b/src/auth/rbac/repository.py @@ -0,0 +1,30 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from .models import Permission, Role, RolePermission, UserRole + + +class PermissionRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_user_permissions(self, user_id: int) -> set[str]: + stmt = ( + select(Permission.code) + .distinct() + .join(RolePermission, Permission.id == RolePermission.permission_id) + .join(UserRole, RolePermission.role_id == UserRole.role_id) + .where(UserRole.user_id == user_id) + ) + result = await self.session.execute(stmt) + return set(result.scalars().all()) + + async def get_user_roles(self, user_id: int) -> set[str]: + stmt = ( + select(Role.name) + .distinct() + .join(UserRole, Role.id == UserRole.role_id) + .where(UserRole.user_id == user_id) + ) + result = await self.session.execute(stmt) + return set(result.scalars().all()) diff --git a/src/auth/rbac/service.py b/src/auth/rbac/service.py new file mode 100644 index 0000000..0633bf8 --- /dev/null +++ b/src/auth/rbac/service.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence +from typing import Literal + +from .repository import PermissionRepository + + +class PermissionService: + def __init__(self, repository: PermissionRepository): + self.repository = repository + + @staticmethod + def _split_permission(perm: str) -> tuple[str, str | None]: + """Split permission into (module, action).""" + if ":" in perm: + module, action = perm.split(":", 1) + # Treat empty action as absent (not a wildcard). + return module, action if action != "" else None + return perm, None + + def _match_permission( + self, + required_perm: str, + user_perm: str, + wildcard_support: bool, + ) -> bool: + # Wildcard disabled: exact match only + if not wildcard_support: + return required_perm == user_perm + + # "*": requires global permission, only user="*" satisfies + if required_perm == "*": + return user_perm == "*" + # user="*": global permission satisfies any requirement + if user_perm == "*": + return True + + req_module, req_action = self._split_permission(required_perm) + user_module, user_action = self._split_permission(user_perm) + + # Different modules: no match + if req_module != user_module: + return False + + # required_perm must be "module:action" format (not bare "module") + if req_action is None: + return False + + # user has "module" (full module access): matches any "module:action" + if user_action is None: + return True + + # "module:*": module wildcard, matches any action + if req_action == "*": + return True + + # Exact action match or user has "module:*" + return user_action == req_action or user_action == "*" + + def _has_permission( + self, + required_perm: str, + user_perms: set[str], + wildcard_support: bool, + ) -> bool: + return any( + self._match_permission(required_perm, user_perm, wildcard_support) + for user_perm in user_perms + ) + + async def check_permissions( + self, + user_id: int, + required_perms: Sequence[str], + match: Literal["all", "any"] = "all", + wildcard_support: bool = True, + ) -> bool: + if match not in ("all", "any"): + raise ValueError("match must be 'all' or 'any'") + + user_perms = await self.repository.get_user_permissions(user_id) + + if match == "all": + return all( + self._has_permission(req, user_perms, wildcard_support) + for req in required_perms + ) + return any( + self._has_permission(req, user_perms, wildcard_support) + for req in required_perms + ) + + async def check_roles( + self, + user_id: int, + required_roles: Sequence[str], + match: Literal["all", "any"] = "all", + ) -> bool: + if match not in ("all", "any"): + raise ValueError("match must be 'all' or 'any'") + + user_roles = await self.repository.get_user_roles(user_id) + + if match == "all": + return all(req in user_roles for req in required_roles) + return any(req in user_roles for req in required_roles) diff --git a/tests/auth/test_models.py b/tests/auth/test_models.py index 592987c..1b1fede 100644 --- a/tests/auth/test_models.py +++ b/tests/auth/test_models.py @@ -2,7 +2,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from src.auth.models import Permission, Role, RolePermission, UserRole +from src.auth.rbac.models import Permission, Role, RolePermission, UserRole class TestRoleSchema: diff --git a/tests/integration/test_audit_rbac.py b/tests/integration/test_audit_rbac.py index f997d21..fec2878 100644 --- a/tests/integration/test_audit_rbac.py +++ b/tests/integration/test_audit_rbac.py @@ -5,9 +5,9 @@ from src.auth import require_permissions, require_roles from src.auth.models import User +from src.auth.rbac.models import Permission, RolePermission, UserRole from src.audit.schemas import AuditAction, AuditLog, AuditResult from src.shared.errors import ErrorCode -from src.auth.models import Permission, RolePermission, UserRole @pytest.fixture diff --git a/tests/integration/test_audit_request_tracing.py b/tests/integration/test_audit_request_tracing.py index d581a1a..3efff6a 100644 --- a/tests/integration/test_audit_request_tracing.py +++ b/tests/integration/test_audit_request_tracing.py @@ -117,7 +117,7 @@ async def test_rbac_audit_includes_request_id( test_app_with_protected_route, test_user, test_db ): from src.auth import require_permissions - from src.auth.models import Permission, RolePermission + from src.auth.rbac.models import Permission, RolePermission, UserRole from src.main import app async with test_db() as session: @@ -129,8 +129,6 @@ async def test_rbac_audit_includes_request_id( session.add(role_perm) await session.commit() - from src.auth.models import UserRole - user_role = UserRole(user_id=test_user.id, role_id=1) session.add(user_role) await session.commit() diff --git a/tests/integration/test_rbac.py b/tests/integration/test_rbac.py index 1b46c76..65de85a 100644 --- a/tests/integration/test_rbac.py +++ b/tests/integration/test_rbac.py @@ -4,8 +4,8 @@ from src.auth import owner_or_perm, require_permissions, require_roles from src.auth.models import User +from src.auth.rbac.models import Permission, RolePermission, UserRole from src.shared.errors import ErrorCode -from src.auth.models import Permission, RolePermission, UserRole @pytest.fixture