Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 5 additions & 76 deletions src/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,22 @@
from fastapi import Depends, Request
from fastapi_users import FastAPIUsers

from src.audit import AuditService, get_audit_service
from src.audit.dependencies import generate_request_id
from src.audit.schemas import AuditAction, AuditResult
from src.http.utils import extract_client_info
from src.auth.models import User

from .backend import auth_backend
from .manager import get_user_manager

fastapi_users = FastAPIUsers[User, int](
get_user_manager=get_user_manager,
auth_backends=[auth_backend],
)

_current_user_base = fastapi_users.current_user(active=True)
_current_superuser_base = fastapi_users.current_user(active=True, superuser=True)


async def current_user(
request: Request,
user: User = Depends(_current_user_base),
audit_service: AuditService = Depends(get_audit_service),
request_id: str = Depends(generate_request_id),
) -> User:
user_agent, ip = extract_client_info(request)

await audit_service.log(
action=AuditAction.PROTECTED_RESOURCE_ACCESS,
result=AuditResult.SUCCESS,
actor_id=user.id,
request_id=request_id,
user_agent=user_agent,
ip=ip,
extra={
"method": request.method,
"path": request.url.path,
},
)

return user


async def current_superuser(
request: Request,
user: User = Depends(_current_superuser_base),
audit_service: AuditService = Depends(get_audit_service),
request_id: str = Depends(generate_request_id),
) -> User:
user_agent, ip = extract_client_info(request)

await audit_service.log(
action=AuditAction.PROTECTED_RESOURCE_ACCESS,
result=AuditResult.SUCCESS,
actor_id=user.id,
request_id=request_id,
user_agent=user_agent,
ip=ip,
extra={
"method": request.method,
"path": request.url.path,
"superuser": True,
},
)

return user


from .models import ( # noqa: E402
OAuthAccount,
from .dependencies import current_superuser, current_user, fastapi_users
from .models import OAuthAccount, User
from .rbac import (
Permission,
Role,
RolePermission,
UserRole,
)
from .rbac import ( # noqa: E402
owner_or_perm,
require_permissions,
require_roles,
)
from .schemas import ( # noqa: E402
from .schemas import (
MessageResponse,
TokenResponse,
UserCreate,
UserRead,
UserUpdate,
)
from .service import AuthenticationService, get_auth_service # noqa: E402
from .service import AuthenticationService, get_auth_service

__all__ = [
"fastapi_users",
Expand Down
68 changes: 68 additions & 0 deletions src/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from fastapi import Depends, Request
from fastapi_users import FastAPIUsers

from src.audit import AuditService, get_audit_service
from src.audit.dependencies import generate_request_id
from src.audit.schemas import AuditAction, AuditResult
from src.auth.models import User
from src.http.utils import extract_client_info

from .backend import auth_backend
from .manager import get_user_manager

fastapi_users = FastAPIUsers[User, int](
get_user_manager=get_user_manager,
auth_backends=[auth_backend],
)

_current_user_base = fastapi_users.current_user(active=True)
_current_superuser_base = fastapi_users.current_user(active=True, superuser=True)


async def current_user(
request: Request,
user: User = Depends(_current_user_base),
audit_service: AuditService = Depends(get_audit_service),
request_id: str = Depends(generate_request_id),
) -> User:
user_agent, ip = extract_client_info(request)

await audit_service.log(
action=AuditAction.PROTECTED_RESOURCE_ACCESS,
result=AuditResult.SUCCESS,
actor_id=user.id,
request_id=request_id,
user_agent=user_agent,
ip=ip,
extra={
"method": request.method,
"path": request.url.path,
},
)

return user


async def current_superuser(
request: Request,
user: User = Depends(_current_superuser_base),
audit_service: AuditService = Depends(get_audit_service),
request_id: str = Depends(generate_request_id),
) -> User:
user_agent, ip = extract_client_info(request)

await audit_service.log(
action=AuditAction.PROTECTED_RESOURCE_ACCESS,
result=AuditResult.SUCCESS,
actor_id=user.id,
request_id=request_id,
user_agent=user_agent,
ip=ip,
extra={
"method": request.method,
"path": request.url.path,
"superuser": True,
},
)

return user
41 changes: 0 additions & 41 deletions src/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from datetime import UTC, datetime

from sqlmodel import Field, SQLModel

from src.mixins import TimestampMixin
Expand Down Expand Up @@ -30,42 +28,3 @@ class OAuthAccount(SQLModel, table=True):
refresh_token: str | None = Field(default=None, max_length=1024, nullable=True)
account_id: str = Field(max_length=320, nullable=False, index=True)
account_email: str = Field(max_length=320, nullable=False)


class UserRole(SQLModel, table=True):
__tablename__ = "user_roles"

user_id: int = Field(foreign_key="users.id", primary_key=True)
role_id: int = Field(foreign_key="roles.id", primary_key=True)
assigned_at: datetime = Field(
default_factory=lambda: datetime.now(UTC),
)


class RolePermission(SQLModel, table=True):
__tablename__ = "role_permissions"

role_id: int = Field(foreign_key="roles.id", primary_key=True)
permission_id: int = Field(foreign_key="permissions.id", primary_key=True)
assigned_at: datetime = Field(
default_factory=lambda: datetime.now(UTC),
)


class Role(SQLModel, TimestampMixin, table=True):
__tablename__ = "roles"

id: int | None = Field(default=None, primary_key=True)
name: str = Field(unique=True, index=True, max_length=100)
description: str | None = Field(default=None, max_length=255)
is_system: bool = Field(default=False)


class Permission(SQLModel, TimestampMixin, table=True):
__tablename__ = "permissions"

id: int | None = Field(default=None, primary_key=True)
code: str = Field(unique=True, index=True, max_length=150)
name: str = Field(max_length=100)
description: str | None = Field(default=None, max_length=255)
module: str = Field(max_length=100)
24 changes: 24 additions & 0 deletions src/auth/rbac/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""RBAC (Role-Based Access Control) submodule for authorization."""

from .dependencies import (
get_permission_service,
owner_or_perm,
require_permissions,
require_roles,
)
from .models import Permission, Role, RolePermission, UserRole
from .repository import PermissionRepository
from .service import PermissionService

__all__ = [
"PermissionRepository",
"PermissionService",
"get_permission_service",
"require_permissions",
"require_roles",
"owner_or_perm",
"Role",
"Permission",
"UserRole",
"RolePermission",
]
136 changes: 5 additions & 131 deletions src/auth/rbac.py → src/auth/rbac/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,147 +2,21 @@
from typing import Literal

from fastapi import Depends, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from src.audit import AuditService, generate_request_id, get_audit_service
from src.audit.schemas import AuditAction, AuditResult
from src.http.utils import extract_client_info
from src.auth.models import Permission, Role, RolePermission, User, UserRole
from src.auth.dependencies import current_user
from src.auth.exceptions import (
InsufficientPermissionException,
InsufficientRoleException,
)
from src.auth.models import User
from src.http.utils import extract_client_info
from src.session import get_session

from . import current_user


class PermissionRepository:
def __init__(self, session: AsyncSession):
self.session = session

async def get_user_permissions(self, user_id: int) -> set[str]:
stmt = (
select(Permission.code)
.distinct()
.join(RolePermission, Permission.id == RolePermission.permission_id)
.join(UserRole, RolePermission.role_id == UserRole.role_id)
.where(UserRole.user_id == user_id)
)
result = await self.session.execute(stmt)
return set(result.scalars().all())

async def get_user_roles(self, user_id: int) -> set[str]:
stmt = (
select(Role.name)
.distinct()
.join(UserRole, Role.id == UserRole.role_id)
.where(UserRole.user_id == user_id)
)
result = await self.session.execute(stmt)
return set(result.scalars().all())


class PermissionService:
def __init__(self, repository: PermissionRepository):
self.repository = repository

@staticmethod
def _split_permission(perm: str) -> tuple[str, str | None]:
"""Split permission into (module, action)."""
if ":" in perm:
module, action = perm.split(":", 1)
# Treat empty action as absent (not a wildcard).
return module, action if action != "" else None
return perm, None

def _match_permission(
self,
required_perm: str,
user_perm: str,
wildcard_support: bool,
) -> bool:
# Wildcard disabled: exact match only
if not wildcard_support:
return required_perm == user_perm

# "*": requires global permission, only user="*" satisfies
if required_perm == "*":
return user_perm == "*"
# user="*": global permission satisfies any requirement
if user_perm == "*":
return True

req_module, req_action = self._split_permission(required_perm)
user_module, user_action = self._split_permission(user_perm)

# Different modules: no match
if req_module != user_module:
return False

# required_perm must be "module:action" format (not bare "module")
if req_action is None:
return False

# user has "module" (full module access): matches any "module:action"
if user_action is None:
return True

# "module:*": module wildcard, matches any action
if req_action == "*":
return True

# Exact action match or user has "module:*"
return user_action == req_action or user_action == "*"

def _has_permission(
self,
required_perm: str,
user_perms: set[str],
wildcard_support: bool,
) -> bool:
return any(
self._match_permission(required_perm, user_perm, wildcard_support)
for user_perm in user_perms
)

async def check_permissions(
self,
user_id: int,
required_perms: Sequence[str],
match: Literal["all", "any"] = "all",
wildcard_support: bool = True,
) -> bool:
if match not in ("all", "any"):
raise ValueError("match must be 'all' or 'any'")

user_perms = await self.repository.get_user_permissions(user_id)

if match == "all":
return all(
self._has_permission(req, user_perms, wildcard_support)
for req in required_perms
)
return any(
self._has_permission(req, user_perms, wildcard_support)
for req in required_perms
)

async def check_roles(
self,
user_id: int,
required_roles: Sequence[str],
match: Literal["all", "any"] = "all",
) -> bool:
if match not in ("all", "any"):
raise ValueError("match must be 'all' or 'any'")

user_roles = await self.repository.get_user_roles(user_id)

if match == "all":
return all(req in user_roles for req in required_roles)
return any(req in user_roles for req in required_roles)
from .repository import PermissionRepository
from .service import PermissionService


async def get_permission_service(
Expand Down
Loading
Loading