diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index d0e61172..64c18ff2 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -311,12 +311,10 @@ class GetControlResponse(BaseModel): id: int = Field(..., description="Control ID") name: str = Field(..., description="Control name") - data: ControlDefinition | UnrenderedTemplateControl | None = Field( - None, + data: ControlDefinition | UnrenderedTemplateControl = Field( description=( "Control configuration data. A ControlDefinition for raw/rendered " - "controls, an UnrenderedTemplateControl for unrendered templates, " - "or None if not yet configured." + "controls or an UnrenderedTemplateControl for unrendered templates." ), ) diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 1475049c..0a8614c4 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -1055,7 +1055,7 @@ async def get_control( Dictionary containing: - id: Control ID - name: Control name - - data: Control definition or None if not configured + - data: Control definition or unrendered template control data Raises: httpx.HTTPError: If request fails or control not found diff --git a/sdks/python/src/agent_control/controls.py b/sdks/python/src/agent_control/controls.py index 26bc75e1..0c7999f2 100644 --- a/sdks/python/src/agent_control/controls.py +++ b/sdks/python/src/agent_control/controls.py @@ -107,7 +107,7 @@ async def get_control( Dictionary containing: - id: Control ID - name: Control name - - data: Control definition (condition, action, scope, etc.) or None if not configured + - data: Control definition or unrendered template control data Raises: httpx.HTTPError: If request fails @@ -117,8 +117,7 @@ async def get_control( async with AgentControlClient() as client: control = await get_control(client, control_id=5) print(f"Control: {control['name']}") - if control['data']: - print(f"Execution: {control['data']['execution']}") + print(f"Enabled: {control['data']['enabled']}") """ response = await client.http_client.get(f"/api/v1/controls/{control_id}") response.raise_for_status() diff --git a/sdks/typescript/src/generated/funcs/controls-update-metadata.ts b/sdks/typescript/src/generated/funcs/controls-update-metadata.ts index a9536f6f..161bcacf 100644 --- a/sdks/typescript/src/generated/funcs/controls-update-metadata.ts +++ b/sdks/typescript/src/generated/funcs/controls-update-metadata.ts @@ -48,7 +48,7 @@ import { Result } from "../types/fp.js"; * Raises: * HTTPException 404: Control not found * HTTPException 409: New name conflicts with existing control - * HTTPException 422: Cannot update enabled status (control has no data configured) + * HTTPException 422: Cannot update metadata for corrupted control data * HTTPException 500: Database error during update */ export function controlsUpdateMetadata( diff --git a/sdks/typescript/src/generated/models/get-control-response.ts b/sdks/typescript/src/generated/models/get-control-response.ts index bf67a1c9..8e65e936 100644 --- a/sdks/typescript/src/generated/models/get-control-response.ts +++ b/sdks/typescript/src/generated/models/get-control-response.ts @@ -18,7 +18,7 @@ import { } from "./unrendered-template-control.js"; /** - * Control configuration data. A ControlDefinition for raw/rendered controls, an UnrenderedTemplateControl for unrendered templates, or None if not yet configured. + * Control configuration data. A ControlDefinition for raw/rendered controls or an UnrenderedTemplateControl for unrendered templates. */ export type GetControlResponseData = | ControlDefinitionOutput @@ -29,9 +29,9 @@ export type GetControlResponseData = */ export type GetControlResponse = { /** - * Control configuration data. A ControlDefinition for raw/rendered controls, an UnrenderedTemplateControl for unrendered templates, or None if not yet configured. + * Control configuration data. A ControlDefinition for raw/rendered controls or an UnrenderedTemplateControl for unrendered templates. */ - data?: ControlDefinitionOutput | UnrenderedTemplateControl | null | undefined; + data: ControlDefinitionOutput | UnrenderedTemplateControl; /** * Control ID */ @@ -66,14 +66,10 @@ export const GetControlResponse$inboundSchema: z.ZodMiniType< GetControlResponse, unknown > = z.object({ - data: z.optional( - z.nullable( - smartUnion([ - ControlDefinitionOutput$inboundSchema, - UnrenderedTemplateControl$inboundSchema, - ]), - ), - ), + data: smartUnion([ + ControlDefinitionOutput$inboundSchema, + UnrenderedTemplateControl$inboundSchema, + ]), id: types.number(), name: types.string(), }); diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index cba949c7..e8218d17 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -223,7 +223,7 @@ export class Controls extends ClientSDK { * Raises: * HTTPException 404: Control not found * HTTPException 409: New name conflicts with existing control - * HTTPException 422: Cannot update enabled status (control has no data configured) + * HTTPException 422: Cannot update metadata for corrupted control data * HTTPException 500: Database error during update */ async updateMetadata( diff --git a/server/alembic/versions/c1e9f9c4a1d2_control_versions_and_soft_delete_legacy_controls.py b/server/alembic/versions/c1e9f9c4a1d2_control_versions_and_soft_delete_legacy_controls.py new file mode 100644 index 00000000..cb9a09ed --- /dev/null +++ b/server/alembic/versions/c1e9f9c4a1d2_control_versions_and_soft_delete_legacy_controls.py @@ -0,0 +1,227 @@ +"""add control versions and soft-delete unusable legacy controls + +Revision ID: c1e9f9c4a1d2 +Revises: 5f2b5f4e1a90 +Create Date: 2026-04-15 12:00:00.000000 + +""" + +from __future__ import annotations + +import datetime as dt +import logging +from typing import Any + +import sqlalchemy as sa +from alembic import op +from pydantic import ValidationError +from sqlalchemy import inspect +from sqlalchemy.dialects import postgresql + +from agent_control_models import ControlDefinition, UnrenderedTemplateControl + +# revision identifiers, used by Alembic. +revision = "c1e9f9c4a1d2" +down_revision = "5f2b5f4e1a90" +branch_labels = None +depends_on = None + +_logger = logging.getLogger("alembic.runtime.migration") + +_BACKFILL_NOTE = "Backfilled from existing control" + + +def _classify_control_payload(data: Any) -> tuple[bool, str | None]: + """Return whether a legacy control payload is still usable.""" + if data == {}: + return False, "empty payload" + if not isinstance(data, dict): + return False, "invalid control payload" + + try: + UnrenderedTemplateControl.model_validate(data) + except ValidationError: + pass + else: + return True, None + + try: + ControlDefinition.model_validate(data) + except ValidationError: + return False, "invalid control payload" + + return True, None + + +def _snapshot_payload( + *, + name: str, + data: Any, + deleted_at: dt.datetime | None, +) -> dict[str, Any]: + """Build the JSON snapshot persisted in control_versions.""" + return { + "name": name, + "data": data, + "deleted_at": deleted_at.isoformat() if deleted_at is not None else None, + "cloned_control_id": None, + } + + +def upgrade() -> None: + op.add_column("controls", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)) + op.drop_constraint("controls_name_key", "controls", type_="unique") + op.create_index( + "idx_controls_name_active", + "controls", + ["name"], + unique=True, + postgresql_where=sa.text("deleted_at IS NULL"), + ) + + op.create_table( + "control_versions", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("control_id", sa.Integer(), nullable=False), + sa.Column("version_num", sa.Integer(), nullable=False), + sa.Column("event_type", sa.String(length=255), nullable=False), + sa.Column("snapshot", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("note", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.ForeignKeyConstraint(["control_id"], ["controls.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "control_id", + "version_num", + name="uq_control_versions_control_version", + ), + ) + op.create_index( + "idx_control_versions_control_created", + "control_versions", + ["control_id", sa.literal_column("created_at DESC")], + unique=False, + ) + + bind = op.get_bind() + db_inspector = inspect(bind) + + controls = sa.table( + "controls", + sa.column("id", sa.Integer()), + sa.column("name", sa.String()), + sa.column("data", postgresql.JSONB(astext_type=sa.Text())), + sa.column("deleted_at", sa.DateTime(timezone=True)), + ) + control_versions = sa.table( + "control_versions", + sa.column("control_id", sa.Integer()), + sa.column("version_num", sa.Integer()), + sa.column("event_type", sa.String()), + sa.column("snapshot", postgresql.JSONB(astext_type=sa.Text())), + sa.column("note", sa.Text()), + ) + policy_controls = sa.table( + "policy_controls", + sa.column("policy_id", sa.Integer()), + sa.column("control_id", sa.Integer()), + ) + agent_controls = sa.table( + "agent_controls", + sa.column("agent_name", sa.String()), + sa.column("control_id", sa.Integer()), + ) + + store_publications = None + if db_inspector.has_table("control_stores_controls"): + store_publications = sa.table( + "control_stores_controls", + sa.column("store_id", sa.Integer()), + sa.column("control_id", sa.Integer()), + ) + + rows = bind.execute( + sa.select( + controls.c.id, + controls.c.name, + controls.c.data, + ).order_by(controls.c.id) + ).mappings() + + auto_deleted_controls: list[str] = [] + for row in rows: + control_id = int(row["id"]) + control_name = str(row["name"]) + control_data = row["data"] + usable, reason = _classify_control_payload(control_data) + + bind.execute( + sa.insert(control_versions).values( + control_id=control_id, + version_num=1, + event_type="migration_backfill", + snapshot=_snapshot_payload( + name=control_name, + data=control_data, + deleted_at=None, + ), + note=_BACKFILL_NOTE, + ) + ) + + if usable: + continue + + if store_publications is not None: + bind.execute( + sa.delete(store_publications).where( + store_publications.c.control_id == control_id + ) + ) + bind.execute( + sa.delete(policy_controls).where(policy_controls.c.control_id == control_id) + ) + bind.execute( + sa.delete(agent_controls).where(agent_controls.c.control_id == control_id) + ) + + deleted_at = dt.datetime.now(dt.UTC) + bind.execute( + sa.update(controls) + .where(controls.c.id == control_id) + .values(deleted_at=deleted_at) + ) + bind.execute( + sa.insert(control_versions).values( + control_id=control_id, + version_num=2, + event_type="migration_autodelete", + snapshot=_snapshot_payload( + name=control_name, + data=control_data, + deleted_at=deleted_at, + ), + note=f"Auto-soft-deleted during migration: {reason}", + ) + ) + auto_deleted_controls.append(f"{control_id}:{control_name}") + + if auto_deleted_controls: + _logger.warning( + "Auto-soft-deleted %d unusable controls during migration: %s", + len(auto_deleted_controls), + ", ".join(auto_deleted_controls), + ) + + +def downgrade() -> None: + op.drop_index("idx_control_versions_control_created", table_name="control_versions") + op.drop_table("control_versions") + op.drop_index("idx_controls_name_active", table_name="controls") + op.create_unique_constraint("controls_name_key", "controls", ["name"]) + op.drop_column("controls", "deleted_at") diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 9ec98a77..8b9b9f40 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -114,9 +114,6 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - if not control.data: - continue - # Skip unrendered template controls — they have no evaluators to validate. if ( isinstance(control.data, dict) @@ -399,6 +396,7 @@ async def list_agents( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", ~Control.data.has_key("enabled"), @@ -1248,7 +1246,9 @@ async def add_agent_control( """Associate a control directly with an agent (idempotent).""" agent = await _get_agent_or_404(agent_name, db) - control_result = await db.execute(select(Control).where(Control.id == control_id)) + control_result = await db.execute( + select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + ) control: Control | None = control_result.scalars().first() if control is None: raise NotFoundError( @@ -1314,7 +1314,9 @@ async def remove_agent_control( """Remove a direct control association from an agent (idempotent).""" agent = await _get_agent_or_404(agent_name, db) - control_result = await db.execute(select(Control.id).where(Control.id == control_id)) + control_result = await db.execute( + select(Control.id).where(Control.id == control_id, Control.deleted_at.is_(None)) + ) if control_result.first() is None: raise NotFoundError( error_code=ErrorCode.CONTROL_NOT_FOUND, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 16421477..fe3ec161 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -1,3 +1,6 @@ +import datetime as dt +from typing import cast + from agent_control_engine import list_evaluators from agent_control_models import ControlDefinition, TemplateControlInput, UnrenderedTemplateControl from agent_control_models.errors import ErrorCode, ValidationErrorItem @@ -27,6 +30,7 @@ from sqlalchemy import Integer, String, delete, func, literal, or_, select, union_all from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import Select from ..auth import require_admin_key from ..db import get_async_db @@ -67,6 +71,38 @@ _logger = get_logger(__name__) +def _select_active_control(control_id: int) -> Select[tuple[Control]]: + """Return a query for an active control row by ID.""" + return select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + + +def _select_active_control_name( + name: str, + *, + exclude_control_id: int | None = None, +) -> Select[tuple[int]]: + """Return a query for active controls matching the provided name.""" + stmt = select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + if exclude_control_id is not None: + stmt = stmt.where(Control.id != exclude_control_id) + return stmt + + +async def _get_active_control_or_404(control_id: int, db: AsyncSession) -> Control: + """Load an active control or raise CONTROL_NOT_FOUND.""" + res = await db.execute(_select_active_control(control_id)) + control = cast(Control | None, res.scalars().first()) + if control is None: + raise NotFoundError( + error_code=ErrorCode.CONTROL_NOT_FOUND, + detail=f"Control with ID '{control_id}' not found", + resource="Control", + resource_id=str(control_id), + hint="Verify the control ID is correct and the control has been created.", + ) + return control + + def _serialize_control_data( control_data: ControlDefinition | UnrenderedTemplateControl, ) -> dict[str, object]: @@ -463,7 +499,7 @@ async def create_control( HTTPException 500: Database error during creation """ # Uniqueness check - existing = await db.execute(select(Control.id).where(Control.name == request.name)) + existing = await db.execute(_select_active_control_name(request.name)) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, @@ -539,35 +575,12 @@ async def get_control( Raises: HTTPException 404: Control not found """ - res = await db.execute(select(Control).where(Control.id == control_id)) - control = res.scalars().first() - if control is None: - raise NotFoundError( - error_code=ErrorCode.CONTROL_NOT_FOUND, - detail=f"Control with ID '{control_id}' not found", - resource="Control", - resource_id=str(control_id), - hint="Verify the control ID is correct and the control has been created.", - ) - - # Parse data if present and non-empty - control_data: ControlDefinition | UnrenderedTemplateControl | None = None - if control.data: - try: - control_data = _parse_stored_control_data( - control.data, - control_name=control.name, - control_id=control_id, - ) - except Exception: - # Data exists but is corrupted - log and return None - _logger.warning( - "Control '%s' (id=%s) has corrupted data that failed validation", - control.name, - control_id, - exc_info=True, - ) - control_data = None + control = await _get_active_control_or_404(control_id, db) + control_data = _parse_stored_control_data( + control.data, + control_name=control.name, + control_id=control_id, + ) return GetControlResponse( id=control.id, @@ -602,16 +615,7 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - res = await db.execute(select(Control).where(Control.id == control_id)) - control = res.scalars().first() - if control is None: - raise NotFoundError( - error_code=ErrorCode.CONTROL_NOT_FOUND, - detail=f"Control with ID '{control_id}' not found", - resource="Control", - resource_id=str(control_id), - hint="Verify the control ID is correct and the control has been created.", - ) + control = await _get_active_control_or_404(control_id, db) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -650,16 +654,7 @@ async def set_control_data( HTTPException 404: Control not found HTTPException 500: Database error during update """ - res = await db.execute(select(Control).where(Control.id == control_id)) - control = res.scalars().first() - if control is None: - raise NotFoundError( - error_code=ErrorCode.CONTROL_NOT_FOUND, - detail=f"Control with ID '{control_id}' not found", - resource="Control", - resource_id=str(control_id), - hint="Verify the control ID is correct and the control has been created.", - ) + control = await _get_active_control_or_404(control_id, db) control_def = await _materialize_control_input( request.data, @@ -756,7 +751,7 @@ async def list_controls( Example: GET /controls?limit=10&enabled=true&step_type=tool """ - query = select(Control).order_by(Control.id.desc()) + query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) # Apply cursor if cursor is not None: @@ -821,7 +816,7 @@ async def list_controls( controls = list(result.scalars().all()) # Get total count (with same filters, but without cursor/limit) - total_query = select(func.count()).select_from(Control) + total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) if name is not None: total_query = total_query.where( Control.name.ilike(f"%{escape_like_pattern(name)}%", escape="\\") @@ -990,17 +985,7 @@ async def delete_control( HTTPException 409: Control is in use (and force=false) HTTPException 500: Database error during deletion """ - # Find the control - result = await db.execute(select(Control).where(Control.id == control_id)) - control = result.scalars().first() - if control is None: - raise NotFoundError( - error_code=ErrorCode.CONTROL_NOT_FOUND, - detail=f"Control with ID '{control_id}' not found", - resource="Control", - resource_id=str(control_id), - hint="Verify the control ID is correct and the control has been created.", - ) + control = await _get_active_control_or_404(control_id, db) # Check for associations with policies and direct agent links. policy_assoc_query = select( @@ -1072,15 +1057,15 @@ async def delete_control( len(dissociated_from_agents), ) - # Delete the control - await db.delete(control) + # Tombstone the control so backfilled version history remains referentially intact. + control.deleted_at = dt.datetime.now(dt.UTC) try: await db.commit() - _logger.info(f"Deleted control '{control.name}' ({control_id})") + _logger.info("Soft-deleted control '%s' (%s)", control.name, control_id) except Exception: await db.rollback() _logger.error( - f"Failed to delete control '{control.name}' ({control_id})", + f"Failed to soft-delete control '{control.name}' ({control_id})", exc_info=True, ) raise DatabaseError( @@ -1127,20 +1112,15 @@ async def patch_control( Raises: HTTPException 404: Control not found HTTPException 409: New name conflicts with existing control - HTTPException 422: Cannot update enabled status (control has no data configured) + HTTPException 422: Cannot update metadata for corrupted control data HTTPException 500: Database error during update """ - # Find the control - result = await db.execute(select(Control).where(Control.id == control_id)) - control = result.scalars().first() - if control is None: - raise NotFoundError( - error_code=ErrorCode.CONTROL_NOT_FOUND, - detail=f"Control with ID '{control_id}' not found", - resource="Control", - resource_id=str(control_id), - hint="Verify the control ID is correct and the control has been created.", - ) + control = await _get_active_control_or_404(control_id, db) + parsed_control = _parse_stored_control_data( + control.data, + control_name=control.name, + control_id=control_id, + ) # Track if anything changed updated = False @@ -1149,7 +1129,7 @@ async def patch_control( if request.name is not None and request.name != control.name: # Check for name collision existing = await db.execute( - select(Control.id).where(Control.name == request.name) + _select_active_control_name(request.name, exclude_control_id=control_id) ) if existing.first() is not None: raise ConflictError( @@ -1165,26 +1145,7 @@ async def patch_control( # Update enabled status if provided current_enabled: bool | None = None if request.enabled is not None: - if not control.data: - raise APIValidationError( - error_code=ErrorCode.VALIDATION_ERROR, - detail=( - f"Cannot update enabled status: control '{control.name}' " - "has no data configured" - ), - resource="Control", - hint=f"Use PUT /{control_id}/data to configure the control first.", - errors=[ - ValidationErrorItem( - resource="Control", - field="enabled", - code="no_data_configured", - message="Control must have data configured before enabling/disabling", - ) - ], - ) - - if _is_unrendered_template(control.data): + if isinstance(parsed_control, UnrenderedTemplateControl): if request.enabled: raise APIValidationError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1213,48 +1174,14 @@ async def patch_control( # enabled=False on an unrendered template is a no-op (already false). current_enabled = False else: - try: - ctrl_def = ControlDefinition.model_validate(control.data) - if ctrl_def.enabled != request.enabled: - new_data = dict(control.data) - new_data["enabled"] = request.enabled - control.data = new_data - updated = True - current_enabled = request.enabled if updated else ctrl_def.enabled - except ValidationError: - _logger.error( - "Control '%s' (%s) has corrupted data in patch request", - control.name, - control_id, - exc_info=True, - ) - raise APIValidationError( - error_code=ErrorCode.CORRUPTED_DATA, - detail=f"Control '{control.name}' has corrupted data", - resource="Control", - hint="Update the control data using PUT /{control_id}/data.", - errors=[ - ValidationErrorItem( - resource="Control", - field="data", - code="corrupted_data", - message=_CORRUPTED_CONTROL_DATA_MESSAGE, - ) - ], - ) - elif control.data: - # Get current enabled status for response - if _is_unrendered_template(control.data): - current_enabled = False - else: - try: - ctrl_def = ControlDefinition.model_validate(control.data) - current_enabled = ctrl_def.enabled - except ValidationError: - _logger.warning( - "Control '%s' has invalid data, using default", - control.name, - ) + if parsed_control.enabled != request.enabled: + new_data = dict(control.data) + new_data["enabled"] = request.enabled + control.data = new_data + updated = True + current_enabled = request.enabled if updated else parsed_control.enabled + else: + current_enabled = parsed_control.enabled # Commit if anything changed if updated: diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index dd242d14..2ab8bc33 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -117,7 +117,9 @@ async def add_control_to_policy( hint="Verify the policy ID is correct and the policy has been created.", ) - ctl_res = await db.execute(select(Control).where(Control.id == control_id)) + ctl_res = await db.execute( + select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + ) control = ctl_res.scalars().first() if control is None: raise NotFoundError( @@ -200,7 +202,9 @@ async def remove_control_from_policy( hint="Verify the policy ID is correct and the policy has been created.", ) - ctl_res = await db.execute(select(Control).where(Control.id == control_id)) + ctl_res = await db.execute( + select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + ) control = ctl_res.scalars().first() if control is None: raise NotFoundError( @@ -272,9 +276,9 @@ async def list_policy_controls( ) rows = await db.execute( - select(policy_controls.c.control_id).where( - policy_controls.c.policy_id == policy_id - ) + select(policy_controls.c.control_id) + .join(Control, Control.id == policy_controls.c.control_id) + .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) ) control_ids = [r[0] for r in rows.fetchall()] return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/models.py b/server/src/agent_control_server/models.py index 5da55730..583e6181 100644 --- a/server/src/agent_control_server/models.py +++ b/server/src/agent_control_server/models.py @@ -14,6 +14,8 @@ Integer, String, Table, + Text, + UniqueConstraint, text, ) from sqlalchemy.dialects.postgresql import JSONB @@ -71,13 +73,25 @@ class Policy(Base): class Control(Base): __tablename__ = "controls" + __table_args__ = ( + Index( + "idx_controls_name_active", + "name", + unique=True, + postgresql_where=text("deleted_at IS NULL"), + sqlite_where=text("deleted_at IS NULL"), + ), + ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) # JSONB payload describing control specifics data: Mapped[dict[str, Any]] = mapped_column( JSONB, server_default=text("'{}'::jsonb"), nullable=False ) + deleted_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) # Many-to-many backref: Control <> Policy policies: Mapped[list["Policy"]] = relationship( "Policy", secondary=lambda: policy_controls, back_populates="controls" @@ -88,6 +102,26 @@ class Control(Base): ) +class ControlVersion(Base): + __tablename__ = "control_versions" + __table_args__ = ( + UniqueConstraint("control_id", "version_num", name="uq_control_versions_control_version"), + Index("idx_control_versions_control_created", "control_id", text("created_at DESC")), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + control_id: Mapped[int] = mapped_column( + Integer, ForeignKey("controls.id"), nullable=False + ) + version_num: Mapped[int] = mapped_column(Integer, nullable=False) + event_type: Mapped[str] = mapped_column(String(255), nullable=False) + snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + note: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False + ) + + class Agent(Base): __tablename__ = "agents" __table_args__ = ( diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 7042972b..9806f81a 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -60,6 +60,7 @@ async def _list_db_controls_for_agent( stmt = ( select(Control) .join(control_ids_subquery, Control.id == control_ids_subquery.c.control_id) + .where(Control.deleted_at.is_(None)) .order_by(Control.id.desc()) ) @@ -155,7 +156,7 @@ async def list_controls_for_policy(policy_id: int, db: AsyncSession) -> list[Con stmt = ( select(Control) .join(policy_controls, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id) + .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) ) result = await db.execute(stmt) return list(result.scalars().unique().all()) diff --git a/server/tests/test_agents_additional.py b/server/tests/test_agents_additional.py index dc5124fb..1f0d9cc4 100644 --- a/server/tests/test_agents_additional.py +++ b/server/tests/test_agents_additional.py @@ -733,8 +733,8 @@ async def mock_db_missing_policy() -> AsyncGenerator[AsyncSession, None]: assert resp.json()["error_code"] == "POLICY_NOT_FOUND" -def test_set_agent_policy_skips_controls_without_data(client: TestClient) -> None: - # Given: an agent and a policy with a control that has no data configured +def test_set_agent_policy_rejects_controls_without_data(client: TestClient) -> None: + # Given: an agent and a policy with a control that has invalid empty data agent_name, _ = _init_agent(client) policy_id = _create_policy(client) control_id = _insert_unconfigured_control() @@ -744,9 +744,11 @@ def test_set_agent_policy_skips_controls_without_data(client: TestClient) -> Non # When: assigning the policy to the agent resp = client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") - # Then: assignment succeeds because empty data is ignored during validation - assert resp.status_code == 200 - assert resp.json()["success"] is True + # Then: assignment is rejected because the stored control data is corrupted + assert resp.status_code == 400 + body = resp.json() + assert body["error_code"] == "POLICY_CONTROL_INCOMPATIBLE" + assert any("corrupted data" in err.get("message", "").lower() for err in body["errors"]) def test_set_agent_policy_rejects_controls_without_evaluator_name(client: TestClient) -> None: diff --git a/server/tests/test_control_phase0_alembic_migration.py b/server/tests/test_control_phase0_alembic_migration.py new file mode 100644 index 00000000..3c5553cb --- /dev/null +++ b/server/tests/test_control_phase0_alembic_migration.py @@ -0,0 +1,318 @@ +"""Alembic coverage for Phase 0 control cleanup and audit backfill.""" + +from __future__ import annotations + +import json +import uuid +from copy import deepcopy +from pathlib import Path +from typing import Any + +import pytest +from alembic import command +from alembic.config import Config +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine, make_url + +from agent_control_server.config import db_config + +from .utils import VALID_CONTROL_PAYLOAD + +SERVER_DIR = Path(__file__).resolve().parents[1] +PRE_MIGRATION_REVISION = "5f2b5f4e1a90" +MIGRATION_REVISION = "c1e9f9c4a1d2" +_BASE_DB_URL = make_url(db_config.get_url()) + +pytestmark = pytest.mark.skipif( + _BASE_DB_URL.get_backend_name() != "postgresql", + reason="Phase 0 Alembic migration tests require PostgreSQL.", +) + + +def _unrendered_template_payload() -> dict[str, Any]: + return { + "template": { + "description": "Regex denial template", + "parameters": { + "pattern": { + "type": "regex_re2", + "label": "Pattern", + }, + }, + "definition_template": { + "description": "Template-backed control", + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": {"$param": "pattern"}}, + }, + }, + "action": {"decision": "deny"}, + }, + }, + "template_values": {}, + } + + +def _insert_control(engine: Engine, *, name: str, data: Any) -> int: + with engine.begin() as conn: + return int( + conn.execute( + text( + """ + INSERT INTO controls (name, data) + VALUES (:name, CAST(:data AS JSONB)) + RETURNING id + """ + ), + {"name": name, "data": json.dumps(data)}, + ).scalar_one() + ) + + +def _insert_policy(engine: Engine, *, name: str) -> int: + with engine.begin() as conn: + return int( + conn.execute( + text( + """ + INSERT INTO policies (name) + VALUES (:name) + RETURNING id + """ + ), + {"name": name}, + ).scalar_one() + ) + + +def _insert_agent(engine: Engine, *, name: str) -> None: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO agents (name, data) + VALUES (:name, '{}'::jsonb) + """ + ), + {"name": name}, + ) + + +def _associate_policy_control(engine: Engine, *, policy_id: int, control_id: int) -> None: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO policy_controls (policy_id, control_id) + VALUES (:policy_id, :control_id) + """ + ), + {"policy_id": policy_id, "control_id": control_id}, + ) + + +def _associate_agent_control(engine: Engine, *, agent_name: str, control_id: int) -> None: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO agent_controls (agent_name, control_id) + VALUES (:agent_name, :control_id) + """ + ), + {"agent_name": agent_name, "control_id": control_id}, + ) + + +def _fetch_control(engine: Engine, control_id: int) -> dict[str, Any]: + with engine.begin() as conn: + row = conn.execute( + text( + """ + SELECT id, name, data, deleted_at + FROM controls + WHERE id = :id + """ + ), + {"id": control_id}, + ).mappings().one() + return dict(row) + + +def _fetch_versions(engine: Engine, control_id: int) -> list[dict[str, Any]]: + with engine.begin() as conn: + rows = conn.execute( + text( + """ + SELECT version_num, event_type, snapshot, note + FROM control_versions + WHERE control_id = :control_id + ORDER BY version_num + """ + ), + {"control_id": control_id}, + ).mappings() + return [dict(row) for row in rows] + + +def _policy_control_count(engine: Engine, control_id: int) -> int: + with engine.begin() as conn: + return int( + conn.execute( + text( + "SELECT COUNT(*) FROM policy_controls WHERE control_id = :control_id" + ), + {"control_id": control_id}, + ).scalar_one() + ) + + +def _agent_control_count(engine: Engine, control_id: int) -> int: + with engine.begin() as conn: + return int( + conn.execute( + text( + "SELECT COUNT(*) FROM agent_controls WHERE control_id = :control_id" + ), + {"control_id": control_id}, + ).scalar_one() + ) + + +@pytest.fixture +def temp_db_url() -> str: + temp_db_name = f"agent_control_phase0_{uuid.uuid4().hex[:12]}" + admin_url = _BASE_DB_URL.set(database="postgres").render_as_string(hide_password=False) + target_url = _BASE_DB_URL.set(database=temp_db_name).render_as_string(hide_password=False) + + admin_engine = create_engine(admin_url, isolation_level="AUTOCOMMIT") + with admin_engine.connect() as conn: + conn.execute(text(f'CREATE DATABASE "{temp_db_name}"')) + admin_engine.dispose() + + try: + yield target_url + finally: + cleanup_engine = create_engine(admin_url, isolation_level="AUTOCOMMIT") + with cleanup_engine.connect() as conn: + conn.execute( + text( + """ + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE datname = :db_name AND pid <> pg_backend_pid() + """ + ), + {"db_name": temp_db_name}, + ) + conn.execute(text(f'DROP DATABASE IF EXISTS "{temp_db_name}"')) + cleanup_engine.dispose() + + +@pytest.fixture +def alembic_config(temp_db_url: str) -> Config: + cfg = Config(str(SERVER_DIR / "alembic.ini")) + cfg.set_main_option("script_location", str(SERVER_DIR / "alembic")) + cfg.set_main_option("sqlalchemy.url", temp_db_url) + return cfg + + +@pytest.fixture +def temp_engine(temp_db_url: str) -> Engine: + engine = create_engine(temp_db_url, future=True) + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture +def upgrade_to(alembic_config: Config): + def _upgrade(revision: str, *, sql: bool = False) -> None: + command.upgrade(alembic_config, revision, sql=sql) + + return _upgrade + + +def test_upgrade_backfills_versions_for_usable_controls( + upgrade_to, + temp_engine: Engine, +) -> None: + upgrade_to(PRE_MIGRATION_REVISION) + + rendered_id = _insert_control( + temp_engine, + name="rendered-control", + data=deepcopy(VALID_CONTROL_PAYLOAD), + ) + unrendered_id = _insert_control( + temp_engine, + name="unrendered-control", + data=_unrendered_template_payload(), + ) + + upgrade_to(MIGRATION_REVISION) + + for control_id, expected_name in ( + (rendered_id, "rendered-control"), + (unrendered_id, "unrendered-control"), + ): + control = _fetch_control(temp_engine, control_id) + versions = _fetch_versions(temp_engine, control_id) + + assert control["deleted_at"] is None + assert len(versions) == 1 + assert versions[0]["version_num"] == 1 + assert versions[0]["event_type"] == "migration_backfill" + assert versions[0]["note"] == "Backfilled from existing control" + assert versions[0]["snapshot"]["name"] == expected_name + assert versions[0]["snapshot"]["deleted_at"] is None + assert versions[0]["snapshot"]["cloned_control_id"] is None + + +def test_upgrade_soft_deletes_unusable_controls_and_removes_associations( + upgrade_to, + temp_engine: Engine, +) -> None: + upgrade_to(PRE_MIGRATION_REVISION) + + empty_id = _insert_control(temp_engine, name="empty-control", data={}) + corrupted_id = _insert_control(temp_engine, name="corrupted-control", data={"bad": "data"}) + + policy_id = _insert_policy(temp_engine, name="policy-phase0") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _insert_agent(temp_engine, name=agent_name) + _associate_policy_control(temp_engine, policy_id=policy_id, control_id=empty_id) + _associate_agent_control(temp_engine, agent_name=agent_name, control_id=corrupted_id) + + upgrade_to(MIGRATION_REVISION) + + empty_control = _fetch_control(temp_engine, empty_id) + corrupted_control = _fetch_control(temp_engine, corrupted_id) + empty_versions = _fetch_versions(temp_engine, empty_id) + corrupted_versions = _fetch_versions(temp_engine, corrupted_id) + + assert empty_control["deleted_at"] is not None + assert corrupted_control["deleted_at"] is not None + assert _policy_control_count(temp_engine, empty_id) == 0 + assert _agent_control_count(temp_engine, corrupted_id) == 0 + + assert [version["event_type"] for version in empty_versions] == [ + "migration_backfill", + "migration_autodelete", + ] + assert [version["event_type"] for version in corrupted_versions] == [ + "migration_backfill", + "migration_autodelete", + ] + assert empty_versions[1]["note"] == "Auto-soft-deleted during migration: empty payload" + assert ( + corrupted_versions[1]["note"] + == "Auto-soft-deleted during migration: invalid control payload" + ) + assert empty_versions[1]["snapshot"]["deleted_at"] is not None + assert corrupted_versions[1]["snapshot"]["deleted_at"] is not None diff --git a/server/tests/test_controls.py b/server/tests/test_controls.py index 87b364cb..eac24890 100644 --- a/server/tests/test_controls.py +++ b/server/tests/test_controls.py @@ -95,13 +95,14 @@ def test_create_control_without_data_returns_422(client: TestClient) -> None: def test_get_control_data_initially_unconfigured(client: TestClient) -> None: - # Given: a legacy control row with no data set + # Given: an invalid legacy control row with an empty payload control_id = create_unconfigured_control() # When: fetching its data resp = client.get(f"/api/v1/controls/{control_id}/data") - # Then: 422 because empty data is not a valid ControlDefinition (RFC 7807 format) + # Then: 422 because empty data is invalid stored control data assert resp.status_code == 422 response_data = resp.json() + assert response_data["error_code"] == "CORRUPTED_DATA" assert "invalid data" in response_data.get("detail", "").lower() @@ -274,20 +275,19 @@ def test_create_control_duplicate_name_409(client: TestClient) -> None: def test_get_control_returns_metadata(client: TestClient) -> None: - """Test GET /controls/{id} returns id, name, and None data for legacy rows.""" - # Given: a legacy control with a specific name and no configured data + """Test GET /controls/{id} rejects active controls with invalid stored data.""" + # Given: a legacy control with a specific name and an invalid empty payload name = f"test-control-{uuid.uuid4()}" control_id = create_unconfigured_control(name) # When: fetching the control get_resp = client.get(f"/api/v1/controls/{control_id}") - # Then: returns id, name, and data (None for legacy unconfigured rows) - assert get_resp.status_code == 200 + # Then: the API reports corrupted stored data instead of returning null data + assert get_resp.status_code == 422 body = get_resp.json() - assert body["id"] == control_id - assert body["name"] == name - assert body["data"] is None # Not configured yet + assert body["error_code"] == "CORRUPTED_DATA" + assert "invalid data" in body["detail"].lower() def test_get_control_with_data(client: TestClient) -> None: diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index bafbab67..d47cc8aa 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -88,7 +88,11 @@ async def mock_db_integrity_error() -> AsyncGenerator[AsyncSession, None]: def test_patch_control_rename_integrity_error_returns_conflict(client: TestClient) -> None: """DB uniqueness violations during rename should be surfaced as 409 conflicts.""" - control_obj = SimpleNamespace(id=1, name="old-control", data={}) + control_obj = SimpleNamespace( + id=1, + name="old-control", + data=deepcopy(VALID_CONTROL_PAYLOAD), + ) async def mock_db_integrity_error() -> AsyncGenerator[AsyncSession, None]: mock_session = AsyncMock(spec=AsyncSession) @@ -215,18 +219,19 @@ def test_list_controls_filters_and_pagination(client: TestClient) -> None: assert page2["controls"][0]["id"] != first_id -def test_patch_control_enabled_requires_data(client: TestClient) -> None: - # Given: a control without configured data +def test_patch_control_enabled_with_invalid_data_returns_corrupted_data( + client: TestClient, +) -> None: + # Given: a control with an invalid empty payload control_id, _ = _insert_unconfigured_control() - # When: toggling enabled without data + # When: toggling enabled resp = client.patch(f"/api/v1/controls/{control_id}", json={"enabled": False}) - # Then: validation error + # Then: corrupted-data validation is returned assert resp.status_code == 422 data = resp.json() - assert data["error_code"] == "VALIDATION_ERROR" - assert any(err.get("code") == "no_data_configured" for err in data.get("errors", [])) + assert data["error_code"] == "CORRUPTED_DATA" def test_patch_control_rename_conflict(client: TestClient) -> None: @@ -297,7 +302,10 @@ def test_patch_control_legacy_name_preserved_when_name_omitted( text( "INSERT INTO controls (name, data) VALUES (:name, CAST(:data AS JSONB))" ), - {"name": "legacy control name", "data": json.dumps({})}, + { + "name": "legacy control name", + "data": json.dumps(VALID_CONTROL_PAYLOAD), + }, ) row = conn.execute( text("SELECT id FROM controls WHERE name = 'legacy control name'") @@ -589,10 +597,14 @@ def test_delete_control_force_dissociates(client: TestClient) -> None: assert list_resp.status_code == 200 assert control_id not in list_resp.json()["control_ids"] + # And: the deleted control is hidden from active lookups + get_resp = client.get(f"/api/v1/controls/{control_id}") + assert get_resp.status_code == 404 + def test_delete_control_force_dissociates_direct_agent_links(client: TestClient) -> None: # Given: a control directly associated with an agent - control_id, _ = _create_control(client) + control_id, control_name = _create_control(client) _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -615,10 +627,49 @@ def test_delete_control_force_dissociates_direct_agent_links(client: TestClient) assert body.get("dissociated_from_policies", []) == [] assert body.get("dissociated_from_agents", []) == [agent_name] + # And: the deleted control no longer appears in list results + list_resp = client.get("/api/v1/controls", params={"name": control_name}) + assert list_resp.status_code == 200 + assert list_resp.json()["controls"] == [] + assert list_resp.json()["pagination"]["total"] == 0 + + +def test_create_control_allows_reusing_soft_deleted_name(client: TestClient) -> None: + # Given: a control name that has been soft-deleted + name = f"control-{uuid.uuid4()}" + original_id, _ = _create_control(client, name=name) + + delete_resp = client.delete(f"/api/v1/controls/{original_id}", params={"force": True}) + assert delete_resp.status_code == 200 + + # When: creating a new control with the same name + recreate_resp = client.put("/api/v1/controls", json={"name": name, "data": VALID_CONTROL_PAYLOAD}) + + # Then: creation succeeds because uniqueness only applies to active rows + assert recreate_resp.status_code == 200, recreate_resp.text + assert recreate_resp.json()["control_id"] != original_id + -def test_get_control_corrupted_data_returns_none(client: TestClient) -> None: +def test_patch_control_rename_allows_soft_deleted_name(client: TestClient) -> None: + # Given: a soft-deleted control name and a separate active control + deleted_name = f"control-{uuid.uuid4()}" + deleted_id, _ = _create_control(client, name=deleted_name) + delete_resp = client.delete(f"/api/v1/controls/{deleted_id}", params={"force": True}) + assert delete_resp.status_code == 200 + + control_id, _ = _create_control(client) + + # When: renaming the active control to the deleted control's name + resp = client.patch(f"/api/v1/controls/{control_id}", json={"name": deleted_name}) + + # Then: rename succeeds + assert resp.status_code == 200, resp.text + assert resp.json()["name"] == deleted_name + + +def test_get_control_corrupted_data_returns_422(client: TestClient) -> None: # Given: a control with corrupted data in DB - control_id, control_name = _create_control(client) + control_id, _ = _create_control(client) with engine.begin() as conn: conn.execute( text("UPDATE controls SET data = CAST(:data AS JSONB) WHERE id = :id"), @@ -628,11 +679,10 @@ def test_get_control_corrupted_data_returns_none(client: TestClient) -> None: # When: fetching the control resp = client.get(f"/api/v1/controls/{control_id}") - # Then: data is None but the control metadata is intact - assert resp.status_code == 200 + # Then: corrupted-data validation is returned + assert resp.status_code == 422 body = resp.json() - assert body["name"] == control_name - assert body["data"] is None + assert body["error_code"] == "CORRUPTED_DATA" def test_get_control_data_corrupted_returns_422(client: TestClient) -> None: @@ -669,8 +719,7 @@ def test_patch_control_enabled_with_corrupted_data(client: TestClient) -> None: assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "CORRUPTED_DATA" - assert body["errors"][0]["message"] == "Stored control data is corrupted and cannot be parsed." - assert "ValidationError" not in body["errors"][0]["message"] + assert "ValidationError" not in resp.text def test_set_control_data_agent_scoped_agent_not_found(client: TestClient) -> None: @@ -991,7 +1040,7 @@ def test_patch_control_enabled_preserves_extra_fields(client: TestClient) -> Non assert control.data.get("custom_meta") == {"source": "unit-test"} -def test_patch_control_rename_with_corrupted_data_returns_enabled_none( +def test_patch_control_rename_with_corrupted_data_returns_422( client: TestClient, ) -> None: # Given: a control with corrupted data in DB @@ -1008,6 +1057,6 @@ def test_patch_control_rename_with_corrupted_data_returns_enabled_none( json={"name": f"{control_name}-renamed"}, ) - # Then: rename succeeds and enabled is omitted (None) - assert resp.status_code == 200 - assert resp.json()["enabled"] is None + # Then: corrupted-data validation is returned + assert resp.status_code == 422 + assert resp.json()["error_code"] == "CORRUPTED_DATA" diff --git a/server/tests/test_policies.py b/server/tests/test_policies.py index e961a989..623142b7 100644 --- a/server/tests/test_policies.py +++ b/server/tests/test_policies.py @@ -166,6 +166,36 @@ def test_policy_remove_control_missing_control_returns_404(client: TestClient) - assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" +def test_policy_add_soft_deleted_control_returns_404(client: TestClient) -> None: + # Given: an existing policy and a control that has been soft-deleted + policy_id = _create_policy(client) + control_id = _create_control(client) + delete_resp = client.delete(f"/api/v1/controls/{control_id}") + assert delete_resp.status_code == 200 + + # When: adding the soft-deleted control to the policy + resp = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + + # Then: the deleted control is treated as not found + assert resp.status_code == 404 + assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" + + +def test_policy_remove_soft_deleted_control_returns_404(client: TestClient) -> None: + # Given: an existing policy and a control that has been soft-deleted + policy_id = _create_policy(client) + control_id = _create_control(client) + delete_resp = client.delete(f"/api/v1/controls/{control_id}") + assert delete_resp.status_code == 200 + + # When: removing the soft-deleted control from the policy + resp = client.delete(f"/api/v1/policies/{policy_id}/controls/{control_id}") + + # Then: the deleted control is treated as not found + assert resp.status_code == 404 + assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" + + def test_policy_add_control_db_error_returns_500( app: FastAPI, client: TestClient ) -> None: diff --git a/server/tests/test_policy_integration.py b/server/tests/test_policy_integration.py index f3f28f34..efdb5526 100644 --- a/server/tests/test_policy_integration.py +++ b/server/tests/test_policy_integration.py @@ -508,6 +508,32 @@ def test_agent_control_endpoints_return_404_for_missing_resources(client: TestCl assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" +def test_agent_add_soft_deleted_control_returns_404(client: TestClient) -> None: + """Direct add should reject a soft-deleted control as not found.""" + agent_name, _ = _create_agent(client) + control_id = _create_control(client) + delete_resp = client.delete(f"/api/v1/controls/{control_id}") + assert delete_resp.status_code == 200 + + resp = client.post(f"/api/v1/agents/{agent_name}/controls/{control_id}") + + assert resp.status_code == 404 + assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" + + +def test_agent_remove_soft_deleted_control_returns_404(client: TestClient) -> None: + """Direct remove should reject a soft-deleted control as not found.""" + agent_name, _ = _create_agent(client) + control_id = _create_control(client) + delete_resp = client.delete(f"/api/v1/controls/{control_id}") + assert delete_resp.status_code == 200 + + resp = client.delete(f"/api/v1/agents/{agent_name}/controls/{control_id}") + + assert resp.status_code == 404 + assert resp.json()["error_code"] == "CONTROL_NOT_FOUND" + + def test_agent_gets_controls_from_direct_associations(client: TestClient) -> None: """Agent should see controls directly associated with it.""" agent_name, _ = _create_agent(client) diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b7798a5b..ec6fa12e 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime as dt import uuid from copy import deepcopy @@ -76,6 +77,36 @@ async def test_list_controls_for_policy_returns_controls(async_db) -> None: assert names == {control_a.name, control_b.name} +@pytest.mark.asyncio +async def test_list_controls_for_policy_excludes_deleted_controls(async_db) -> None: + # Given: a policy associated with one active and one soft-deleted control + policy = Policy(name=f"policy-{uuid.uuid4()}") + active_control = Control(name=f"control-{uuid.uuid4()}", data=VALID_CONTROL_PAYLOAD) + deleted_control = Control( + name=f"deleted-control-{uuid.uuid4()}", + data=VALID_CONTROL_PAYLOAD, + deleted_at=dt.datetime.now(dt.UTC), + ) + async_db.add_all([policy, active_control, deleted_control]) + await async_db.flush() + + await async_db.execute( + insert(policy_controls).values( + [ + {"policy_id": policy.id, "control_id": active_control.id}, + {"policy_id": policy.id, "control_id": deleted_control.id}, + ] + ) + ) + await async_db.commit() + + # When: listing controls for the policy + controls = await list_controls_for_policy(policy.id, async_db) + + # Then: only active controls are returned + assert [control.id for control in controls] == [active_control.id] + + @pytest.mark.asyncio async def test_list_controls_for_agent_returns_controls(async_db) -> None: # Given: an agent associated with one policy control and one direct control @@ -111,6 +142,41 @@ async def test_list_controls_for_agent_returns_controls(async_db) -> None: assert ids == sorted(ids, reverse=True) +@pytest.mark.asyncio +async def test_list_controls_for_agent_excludes_deleted_controls(async_db) -> None: + # Given: an agent associated with one active and one soft-deleted control + active_control = Control(name=f"active-control-{uuid.uuid4()}", data=VALID_CONTROL_PAYLOAD) + deleted_control = Control( + name=f"deleted-control-{uuid.uuid4()}", + data=VALID_CONTROL_PAYLOAD, + deleted_at=dt.datetime.now(dt.UTC), + ) + agent = Agent(name=f"agent-{uuid.uuid4()}", data={}) + async_db.add_all([active_control, deleted_control, agent]) + await async_db.flush() + + await async_db.execute( + insert(agent_controls).values( + [ + {"agent_name": agent.name, "control_id": active_control.id}, + {"agent_name": agent.name, "control_id": deleted_control.id}, + ] + ) + ) + await async_db.commit() + + # When: listing controls for the agent + controls = await list_controls_for_agent( + agent.name, + async_db, + rendered_state="all", + enabled_state="all", + ) + + # Then: soft-deleted controls are excluded + assert [control.id for control in controls] == [active_control.id] + + @pytest.mark.asyncio async def test_list_controls_for_agent_filters_by_rendered_and_enabled_state(async_db) -> None: # Given: an agent with active, disabled, and unrendered associated controls