diff --git a/.env-example b/.env-example index da373d2..e284bf1 100644 --- a/.env-example +++ b/.env-example @@ -1,18 +1,22 @@ # PostgreSQL -# (Currently hardcoded in docker-compose.yaml) +POSTGRES_USER=mlsec2 +POSTGRES_PASSWORD=mlsec2_pw -# MinIO -MINIO_ROOT_USER=minioadmin -MINIO_ROOT_PASSWORD=minioadmin +# MinIO Object Storage +MINIO_ROOT_USER=mlsec2 +MINIO_ROOT_PASSWORD=mlsec2_pw -# Docker Configuration -DOCKER_GID=999 +# RabbitMQ +RABBITMQ_USER=mlsec2 +RABBITMQ_PASSWORD=mlsec2_pw # Gateway Secret -# (Currently hardcoded in docker-compose.yaml, could be moved here if needed) -GATEWAY_SECRET=welovemarcus +GATEWAY_SECRET=mlsec2_pw + +# Docker Configuration +DOCKER_GID=999 -# VirusTotal API key (for attack similarity evaluation) +# VirusTotal API key (for attack behavioral evaluation) VIRUSTOTAL_API_KEY= # Email MFA (SMTP) configuration diff --git a/.github/workflows/api_ci.yaml b/.github/workflows/api_ci.yaml index 38caca6..1653289 100644 --- a/.github/workflows/api_ci.yaml +++ b/.github/workflows/api_ci.yaml @@ -20,7 +20,7 @@ jobs: --health-retries 5 env: - DATABASE_URL: postgresql://postgres:password123@localhost:5433/mlsec_test + DATABASE_URL: postgresql://mlsec2:mlsec2_pw@localhost:5433/mlsec_test REDIS_URL: redis://localhost:6379/0 steps: @@ -33,7 +33,7 @@ jobs: - name: Wait for test database readiness run: | for i in {1..30}; do - if docker exec test-postgres-db pg_isready -U postgres -d mlsec_test; then + if docker exec test-postgres-db pg_isready -U mlsec2 -d mlsec_test; then exit 0 fi sleep 2 diff --git a/.github/workflows/worker_ci.yaml b/.github/workflows/worker_ci.yaml index 8227200..a33ca2e 100644 --- a/.github/workflows/worker_ci.yaml +++ b/.github/workflows/worker_ci.yaml @@ -9,11 +9,11 @@ jobs: runs-on: ubuntu-latest env: - DATABASE_URL: postgresql://postgres:password123@localhost:5433/mlsec_test + DATABASE_URL: postgresql://mlsec2:mlsec2_pw@localhost:5433/mlsec_test REDIS_URL: redis://localhost:6379/0 MINIO_ENDPOINT: localhost:9000 - MINIO_ACCESS_KEY: minioadmin - MINIO_SECRET_KEY: minioadmin + MINIO_ACCESS_KEY: mlsec2 + MINIO_SECRET_KEY: mlsec2_pw CELERY_BROKER_URL: redis://localhost:6379/1 CELERY_RESULT_BACKEND: redis://localhost:6379/2 @@ -39,7 +39,7 @@ jobs: - name: Wait for test database readiness run: | for i in {1..30}; do - if docker exec test-postgres-db pg_isready -U postgres -d mlsec_test; then + if docker exec test-postgres-db pg_isready -U mlsec2 -d mlsec_test; then exit 0 fi sleep 2 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0c2e987 --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +.PHONY: up down build ps logs worker-count + +# Parse num_workers from config.yaml +NUM_WORKERS=$(shell awk '/worker:/ {found=1} found && /num_workers:/ {print $$2; exit}' config.yaml) + +up: + @echo "Starting platform with $(NUM_WORKERS) workers..." + docker compose up --scale worker=$(NUM_WORKERS) + +down: + docker compose down + +build: + docker compose build + +ps: + docker compose ps + +logs: + docker compose logs -f + +worker-count: + @echo $(NUM_WORKERS) diff --git a/config.yaml b/config.yaml index f67f8e6..9743559 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,5 @@ worker: - num_workers: 4 # number of concurrent Celery worker processes (maps to --concurrency) + num_workers: 4 # number of concurrent Worker containers (Total Possible Competitor Containers = num_workers*batch_size) defense_job: mem_limit: "1g" nano_cpus: 1000000000 @@ -8,19 +8,19 @@ worker: max_uncompressed_size_mb: 1024 evaluation: requests_timeout_seconds: 5 - batch_size: 4 + batch_size: 2 defense_max_ram: 1024 # MB - soft RAM threshold; sample marked evaded and container restarted if exceeded defense_max_time: 5000 # ms - per-sample time limit; exceeded = evaded defense_max_timeout: 20000 # ms - forced restart threshold (must be >= defense_max_time) defense_max_restarts: 3 # max container restarts before error state stats_sampling_rate: 25 # Number of samples evaluated before checking container stats again (This has a massive impact on total evaluation time) heuristic_validation: - enable_heuristic_validation: true + enable_heuristic_validation: false heurval_malware_fpr_minimum: 0.0 heurval_malware_tpr_minimum: 0.30 heurval_goodware_fpr_minimum: 0.0 heurval_goodware_tpr_minimum: 0.30 - reject_heurval_failures: true + reject_heurval_failures: false source: # Resource limits max_zip_size_mb: 512 @@ -41,9 +41,8 @@ worker: cleanup_pulled_images: true # Remove pulled images after evaluation (Docker Hub sources only) minio: bucket_name: "mlsec-submissions" - access_key: "mlsec_minio_admin" - secret_key: "mlsec_minio_password_change_in_production" attack: + skip_seeding: true # true = skip template seeding and all behavioral checks check_similarity: false # false = skip evaluation, accept all validated attacks reject_dissimilar_attacks: false # only applies when check_similarity=true # true = reject if score < minimum_attack_similarity @@ -52,7 +51,10 @@ worker: max_zip_size_mb: 100 sandbox_backend: "virustotal" # "virustotal" | "local" cache_persistence_duration: 300 # seconds of inactivity before clearing the sample cache + cache_max_size_gb: 10 # Maximum size of local sample cache before pruning application: login_code: 'ABC' + defense_submission_cooldown: 30 # seconds between defense submissions per user; 0 = no cooldown + attack_submission_cooldown: 30 # seconds between attack submissions per user; 0 = no cooldown email_mfa_enabled: true diff --git a/docker-compose.prod.yaml b/docker-compose.prod.yaml deleted file mode 100644 index 958699a..0000000 --- a/docker-compose.prod.yaml +++ /dev/null @@ -1,169 +0,0 @@ -services: - postgres: - image: postgres:18 - container_name: postgres-db - restart: unless-stopped - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password123 - POSTGRES_DB: mlsec - volumes: - - pgdata:/var/lib/postgresql - - ./services/postgres/init:/docker-entrypoint-initdb.d - networks: - - backend_net - - minio: - image: minio/minio:latest - container_name: mlsec-minio - restart: unless-stopped - environment: - MINIO_ROOT_USER: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} - command: server /data --console-address ":9001" - volumes: - - miniodata:/data - networks: - - backend_net - - rabbitmq: - image: rabbitmq:3-management - container_name: rabbitmq - restart: unless-stopped - environment: - RABBITMQ_DEFAULT_USER: mlsec - RABBITMQ_DEFAULT_PASS: mlsec - volumes: - - rabbitmqdata:/var/lib/rabbitmq - networks: - - backend_net - - redis: - image: redis:7-alpine - container_name: mlsec-redis - restart: unless-stopped - command: redis-server --appendonly yes - volumes: - - redisdata:/data - networks: - - backend_net - - api: - build: - context: ./services/api - dockerfile: Dockerfile - container_name: mlsec-api - restart: unless-stopped - environment: - DATABASE_URL: postgresql://postgres:password123@postgres:5432/mlsec - CELERY_BROKER_URL: amqp://mlsec:mlsec@rabbitmq:5672// - REDIS_URL: redis://redis:6379/0 - MINIO_ENDPOINT: minio:9000 - MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} - MINIO_SECURE: "false" - # Production-specific settings - ADMIN_LOCALHOST_ONLY: "false" - ADMIN_ALLOWED_HOSTS: '["mlsec2.com", "api:8000"]' - ADMIN_TRUSTED_PROXY_HOSTS: '["127.0.0.1", "::1", "172.16.0.0/12", "172.17.0.0/12", "172.18.0.0/12", "172.19.0.0/12", "172.20.0.0/12", "172.21.0.0/12", "172.22.0.0/12"]' - ADMIN_ALLOWED_NETWORKS: '["172.16.0.0/12"]' - CORS_ALLOW_ORIGINS: '["http://mlsec2.com", "http://localhost:4321"]' - CORS_ALLOW_ORIGIN_REGEX: '^https?://(mlsec2\.com|localhost)(:\d+)?$' - volumes: - - ./config.yaml:/app/config.yaml:ro - networks: - - backend_net - depends_on: - - postgres - - rabbitmq - - redis - - minio - - worker: - build: - context: ./services/worker - dockerfile: Dockerfile - container_name: mlsec-worker - restart: unless-stopped - environment: - DATABASE_URL: postgresql://postgres:password123@postgres:5432/mlsec - CELERY_BROKER_URL: amqp://mlsec:mlsec@rabbitmq:5672// - CELERY_DEFAULT_QUEUE: mlsec - REDIS_URL: redis://redis:6379/0 - MINIO_ENDPOINT: minio:9000 - MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} - MINIO_SECURE: "false" - VIRUSTOTAL_API_KEY: ${VIRUSTOTAL_API_KEY:-} - CELERY_CONCURRENCY: "${WORKER_CONCURRENCY:-4}" - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - ./config.yaml:/app/config.yaml:ro - - worker_cache:/app/cache - group_add: - - "${DOCKER_GID:-999}" - networks: - - backend_net - - defense_net - depends_on: - - postgres - - rabbitmq - - redis - - minio - - frontend: - build: - context: ./services/frontend - dockerfile: Dockerfile.prod - container_name: mlsec-frontend - restart: unless-stopped - environment: - PUBLIC_API_URL: "http://mlsec2.com" - API_INTERNAL_URL: "http://api:8000" - networks: - - backend_net - depends_on: - - api - - nginx: - image: nginx:alpine - container_name: mlsec-nginx - restart: unless-stopped - ports: - - "80:80" - volumes: - - ./services/nginx/nginx.conf:/etc/nginx/conf.d/default.conf:ro - networks: - - backend_net - depends_on: - - api - - frontend - - # Gateway to control traffic for student containers - mlsec-gateway: - build: - context: ./services/gateway - dockerfile: Dockerfile - container_name: mlsec-gateway - restart: unless-stopped - cap_add: - - NET_ADMIN - sysctls: - - net.ipv4.ip_forward=1 - networks: - - defense_net - -networks: - backend_net: - name: backend_net - driver: bridge - defense_net: - name: defense_net - driver: bridge - -volumes: - pgdata: - rabbitmqdata: - redisdata: - miniodata: - worker_cache: diff --git a/docker-compose.yaml b/docker-compose.yaml index ea9aebf..03522d4 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -4,8 +4,8 @@ services: image: postgres:18 container_name: postgres-db environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password123 + POSTGRES_USER: ${POSTGRES_USER:-mlsec2} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-mlsec2_pw} POSTGRES_DB: mlsec ports: - "5432:5432" @@ -20,8 +20,8 @@ services: image: postgres:18 container_name: test-postgres-db environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password123 + POSTGRES_USER: ${POSTGRES_USER:-mlsec2} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-mlsec2_pw} POSTGRES_DB: mlsec_test ports: - "5433:5432" @@ -35,8 +35,8 @@ services: image: minio/minio:latest container_name: mlsec-minio environment: - MINIO_ROOT_USER: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} + MINIO_ROOT_USER: ${MINIO_ROOT_USER:-mlsec2} + MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mlsec2_pw} command: server /data --console-address ":9001" ports: - "9000:9000" # API @@ -55,8 +55,8 @@ services: image: rabbitmq:3-management container_name: rabbitmq environment: - RABBITMQ_DEFAULT_USER: mlsec - RABBITMQ_DEFAULT_PASS: mlsec + RABBITMQ_DEFAULT_USER: ${RABBITMQ_USER:-mlsec2} + RABBITMQ_DEFAULT_PASS: ${RABBITMQ_PASSWORD:-mlsec2_pw} ports: - "5672:5672" # AMQP - "15672:15672" # Management UI @@ -88,12 +88,12 @@ services: dockerfile: Dockerfile container_name: mlsec-api environment: - DATABASE_URL: postgresql://postgres:password123@postgres:5432/mlsec - CELERY_BROKER_URL: amqp://mlsec:mlsec@rabbitmq:5672// + DATABASE_URL: postgresql://${POSTGRES_USER:-mlsec2}:${POSTGRES_PASSWORD:-mlsec2_pw}@postgres:5432/mlsec + CELERY_BROKER_URL: amqp://${RABBITMQ_USER:-mlsec2}:${RABBITMQ_PASSWORD:-mlsec2_pw}@rabbitmq:5672// REDIS_URL: redis://redis:6379/0 MINIO_ENDPOINT: minio:9000 - MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} + MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec2} + MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec2_pw} MINIO_SECURE: "false" ADMIN_ALLOWED_HOSTS: '["172.22.0.1"]' # Docker bridge gateway for localhost SSH tunnel access ADMIN_TRUSTED_PROXY_HOSTS: '["127.0.0.1", "::1", "172.16.0.0/12"]' # Trust nginx on any Docker bridge subnet @@ -127,16 +127,16 @@ services: context: ./services/worker dockerfile: Dockerfile environment: - DATABASE_URL: postgresql://postgres:password123@postgres:5432/mlsec - CELERY_BROKER_URL: amqp://mlsec:mlsec@rabbitmq:5672// + DATABASE_URL: postgresql://${POSTGRES_USER:-mlsec2}:${POSTGRES_PASSWORD:-mlsec2_pw}@postgres:5432/mlsec + CELERY_BROKER_URL: amqp://${RABBITMQ_USER:-mlsec2}:${RABBITMQ_PASSWORD:-mlsec2_pw}@rabbitmq:5672// CELERY_DEFAULT_QUEUE: mlsec REDIS_URL: redis://redis:6379/0 MINIO_ENDPOINT: minio:9000 - MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec_minio_admin} - MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec_minio_password_change_in_production} + MINIO_ACCESS_KEY: ${MINIO_ROOT_USER:-mlsec2} + MINIO_SECRET_KEY: ${MINIO_ROOT_PASSWORD:-mlsec2_pw} MINIO_SECURE: "false" VIRUSTOTAL_API_KEY: ${VIRUSTOTAL_API_KEY:-} - CELERY_CONCURRENCY: "${WORKER_CONCURRENCY:-4}" + CELERY_CONCURRENCY: "${WORKER_CONCURRENCY:-1}" # Forces use of a single thread per Worker container volumes: - /var/run/docker.sock:/var/run/docker.sock - ./config.yaml:/app/config.yaml:ro @@ -174,14 +174,25 @@ services: container_name: mlsec-nginx ports: - "80:80" + - "443:443" volumes: - ./services/nginx/nginx.conf:/etc/nginx/conf.d/default.conf:ro + - certbot-conf:/etc/letsencrypt:ro + - certbot-www:/var/www/certbot:ro networks: - backend_net depends_on: - api - frontend + certbot: + image: certbot/certbot + container_name: mlsec-certbot + volumes: + - certbot-conf:/etc/letsencrypt + - certbot-www:/var/www/certbot + entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew --quiet; sleep 12h & wait $${!}; done'" + # Adminer (Can remove in production if we dont need it) adminer: image: adminer:latest @@ -209,5 +220,7 @@ volumes: redisdata: miniodata: worker_cache: + certbot-conf: + certbot-www: diff --git a/services/api/core/admin.py b/services/api/core/admin.py index 2053fdc..62ad2b8 100644 --- a/services/api/core/admin.py +++ b/services/api/core/admin.py @@ -149,7 +149,6 @@ def _hash_token(token: str) -> str: def require_admin_origin(request: Request, *, require_present: bool = True) -> None: """Ensure Origin/Referer points to localhost (and is present if required).""" - settings = get_settings() origin = request.headers.get("origin") referer = request.headers.get("referer") @@ -168,22 +167,18 @@ def _origin_host(value: str) -> str | None: if origin: origin_host = _origin_host(origin) - if not _is_loopback_host(origin_host) and not _is_in_allowed_hosts( - origin_host, settings.admin_allowed_hosts - ): + if not _is_loopback_host(origin_host): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Admin actions require localhost or allowed origin", + detail="Admin actions require localhost origin", ) if referer: referer_host = _origin_host(referer) - if not _is_loopback_host(referer_host) and not _is_in_allowed_hosts( - referer_host, settings.admin_allowed_hosts - ): + if not _is_loopback_host(referer_host): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Admin actions require localhost or allowed origin", + detail="Admin actions require localhost origin", ) diff --git a/services/api/core/config.py b/services/api/core/config.py index 16a6e96..fd87b78 100644 --- a/services/api/core/config.py +++ b/services/api/core/config.py @@ -16,14 +16,16 @@ class MinIOConfig(BaseModel): endpoint: str = "minio:9000" - access_key: str = "minioadmin" - secret_key: str = "minioadmin" + access_key: str = "mlsec2" + secret_key: str = "mlsec2_pw" bucket_name: str = "mlsec-submissions" secure: bool = False class ApplicationConfig(BaseModel): join_code: str | None = None + defense_submission_cooldown: int = 0 + attack_submission_cooldown: int = 0 email_mfa_enabled: bool | None = None @@ -53,6 +55,8 @@ def get_config() -> AppConfig: minio=MinIOConfig(**minio_data), application=ApplicationConfig( join_code=join_code, + defense_submission_cooldown=int(app_data.get("defense_submission_cooldown", 0)), + attack_submission_cooldown=int(app_data.get("attack_submission_cooldown", 0)), email_mfa_enabled=email_mfa_enabled, ), ) diff --git a/services/api/core/database.py b/services/api/core/database.py index 48c1cb0..579a731 100644 --- a/services/api/core/database.py +++ b/services/api/core/database.py @@ -15,7 +15,7 @@ from core.settings import get_settings # TODO: Change this to .env for resolving password, port, etc. -DEFAULT_DATABASE_URL = "postgresql://postgres:password123@postgres:5432/mlsec" +DEFAULT_DATABASE_URL = "postgresql://mlsec2:mlsec2_pw@postgres:5432/mlsec" logger = logging.getLogger(__name__) diff --git a/services/api/core/settings.py b/services/api/core/settings.py index dbf2125..ef7493e 100644 --- a/services/api/core/settings.py +++ b/services/api/core/settings.py @@ -26,8 +26,8 @@ class Settings(BaseSettings): # MinIO object storage minio_endpoint: str = "minio:9000" - minio_access_key: str = "minioadmin" - minio_secret_key: str = "minioadmin" + minio_access_key: str = "mlsec2" + minio_secret_key: str = "mlsec2_pw" minio_secure: bool = False minio_bucket_name: str = "mlsec-submissions" diff --git a/services/api/core/storage.py b/services/api/core/storage.py index 570709e..3cba640 100644 --- a/services/api/core/storage.py +++ b/services/api/core/storage.py @@ -21,8 +21,8 @@ def get_minio_client() -> Minio: """Singleton MinIO client factory.""" cfg = get_config().minio http_client = urllib3.PoolManager( - timeout=urllib3.Timeout(connect=5, read=30), - retries=urllib3.Retry(total=0), + timeout=urllib3.Timeout(connect=5, read=600), + retries=urllib3.Retry(total=2), ) return Minio( endpoint=cfg.endpoint, diff --git a/services/api/core/submission_control.py b/services/api/core/submission_control.py index 5e9e60f..9cabaa6 100644 --- a/services/api/core/submission_control.py +++ b/services/api/core/submission_control.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from dataclasses import dataclass from datetime import datetime, timezone @@ -91,7 +92,12 @@ def set_manual_closed( closed: bool, updated_by: str | None, ) -> SubmissionControl: - """Toggle the manual close flag and return the updated control state.""" + """Toggle the manual close flag and return the updated control state. + + When opening submissions (closed=False), also clears a lapsed scheduled + close time so the submission window is actually open afterwards. + """ + now = _utcnow() row = ( db.execute( text( @@ -100,6 +106,11 @@ def set_manual_closed( VALUES (1, :manual_closed, :updated_at, :updated_by) ON CONFLICT (id) DO UPDATE SET manual_closed = EXCLUDED.manual_closed, + close_at = CASE + WHEN NOT :manual_closed AND submission_control.close_at <= :now + THEN NULL + ELSE submission_control.close_at + END, updated_at = EXCLUDED.updated_at, updated_by = EXCLUDED.updated_by RETURNING manual_closed, close_at, updated_at, updated_by @@ -107,8 +118,9 @@ def set_manual_closed( ), { "manual_closed": closed, - "updated_at": _utcnow(), + "updated_at": now, "updated_by": updated_by, + "now": now, }, ) .mappings() @@ -176,3 +188,54 @@ def ensure_submissions_open(db: Session) -> None: status_code=status.HTTP_403_FORBIDDEN, detail="Submissions are closed (deadline passed)", ) + + +def get_cooldown_remaining( + db: Session, + *, + user_id: str, + submission_type: str, + cooldown_seconds: int, +) -> int | None: + """Return remaining cooldown seconds, or None if no cooldown is active.""" + if cooldown_seconds <= 0: + return None + row = db.execute( + text( + """ + SELECT MAX(created_at) + FROM submissions + WHERE user_id = :user_id + AND submission_type = :submission_type + AND deleted_at IS NULL + """ + ), + {"user_id": user_id, "submission_type": submission_type}, + ).fetchone() + if row is None or row[0] is None: + return None + last_submitted = _as_utc(row[0]) + elapsed = (_utcnow() - last_submitted).total_seconds() + remaining = cooldown_seconds - elapsed + return math.ceil(remaining) if remaining > 0 else None + + +def check_cooldown( + db: Session, + *, + user_id: str, + submission_type: str, + cooldown_seconds: int, +) -> None: + """Raise HTTP 429 if the user submitted within the cooldown window.""" + remaining = get_cooldown_remaining( + db, + user_id=user_id, + submission_type=submission_type, + cooldown_seconds=cooldown_seconds, + ) + if remaining: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=f"Please wait {remaining} seconds before submitting again.", + ) diff --git a/services/api/core/submissions.py b/services/api/core/submissions.py index 6769b26..7081d46 100644 --- a/services/api/core/submissions.py +++ b/services/api/core/submissions.py @@ -45,32 +45,23 @@ def require_submission_of_type( def validate_docker_image_format(image: str) -> None: """ - Validate Docker image URL/name format. + Validate Docker image reference format. Accepts formats like: + - nginx - nginx:latest - - user/repo:tag - - hub.docker.com/r/user/image - - registry.io/project/image:tag + - user/image:tag + - registry.io/user/image:tag Raises: HTTPException(400): If format is invalid """ image = image.strip() - if not image: + pattern = r"^[a-zA-Z0-9][a-zA-Z0-9._\-/]*(:[a-zA-Z0-9._\-]+)?$" + if not re.match(pattern, image): raise HTTPException( - status_code=400, detail="Docker image cannot be empty") - - # Basic validation - detailed validation happens in worker - # Just check for obviously bad patterns - if " " in image: - raise HTTPException( - status_code=400, detail="Docker image name cannot contain spaces" - ) - - if image.startswith("-") or image.endswith("-"): - raise HTTPException( - status_code=400, detail="Docker image name cannot start or end with dash" + status_code=400, + detail="Invalid Docker image format. Expected: image, image:tag, user/image:tag, or registry/path:tag", ) @@ -83,11 +74,11 @@ def validate_github_url_format(url: str) -> None: Raises: HTTPException(400): If format is invalid """ - pattern = r"^https://github\.com/[\w-]+/[\w-]+(\.git)?$" + pattern = r"^https://github\.com/[\w-]+/[\w-]+(/tree/(?!.*\.\.)([\w.\-/]+))?(\.git)?$" if not re.match(pattern, url.strip()): raise HTTPException( status_code=400, - detail="Invalid GitHub URL format. Must be https://github.com/username/repository", + detail="Invalid GitHub URL format. Expected: https://github.com/username/repository or https://github.com/username/repository/tree/branch", ) diff --git a/services/api/main.py b/services/api/main.py index 93ce861..b5a0583 100644 --- a/services/api/main.py +++ b/services/api/main.py @@ -68,8 +68,7 @@ def create_app() -> FastAPI: app.include_router(queue_router) app.include_router(submissions_router, prefix="/api") app.include_router(leaderboard_router) - app.include_router(admin_router) # this needs to be fixed but I think this will fix our problems for now - app.include_router(admin_router, prefix="/api") + app.include_router(admin_router) return app diff --git a/services/api/routers/admin.py b/services/api/routers/admin.py index 5074fec..0d802e9 100644 --- a/services/api/routers/admin.py +++ b/services/api/routers/admin.py @@ -1,5 +1,6 @@ from __future__ import annotations +import csv import io import json import logging @@ -8,6 +9,7 @@ from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, Query, Request, UploadFile, status +from fastapi.responses import StreamingResponse from sqlalchemy import text from sqlalchemy.orm import Session @@ -53,6 +55,9 @@ AdminEvaluationPairRecord, AdminSubmissionEvaluationsResponse, AdminActivateSubmissionResponse, + JobDetailResponse, + JobDetailSubmission, + JobDetailEvalRun, ) router = APIRouter( @@ -461,6 +466,191 @@ def get_recent_jobs( return AdminJobLogsResponse(count=len(items), items=items) +@router.get("/logs/jobs/{job_id}/detail", response_model=JobDetailResponse) +def get_job_detail( + job_id: str, + _: AuthenticatedUser = Depends(require_admin_user), + db: Session = Depends(get_db), +) -> JobDetailResponse: + """Return extended detail for a single job record.""" + row = ( + db.execute( + text( + """ + SELECT id, job_type, status, requested_by_user_id, payload, created_at, updated_at + FROM jobs + WHERE id = :id + """ + ), + {"id": job_id}, + ) + .mappings() + .fetchone() + ) + if row is None: + raise HTTPException(status_code=404, detail="Job not found") + + job = AdminJobLogRecord(**row) + payload = row["payload"] or {} + submission: JobDetailSubmission | None = None + eval_runs: list[JobDetailEvalRun] = [] + + if job.job_type == "D": + sub_id = payload.get("defense_submission_id") + if sub_id: + sub_row = ( + db.execute( + text( + """ + SELECT s.id, s.version, s.display_name, s.status, + d.source_type + FROM submissions s + LEFT JOIN defense_submission_details d ON d.submission_id = s.id + WHERE s.id = :id + """ + ), + {"id": sub_id}, + ) + .mappings() + .fetchone() + ) + heurval_done: int | None = None + heurval_total: int | None = None + hv_row = ( + db.execute( + text( + """ + SELECT hr.sample_set_id, + COUNT(hfr.id) AS done, + (SELECT COUNT(*) FROM heurval_samples hs + WHERE hs.sample_set_id = hr.sample_set_id) AS total + FROM heurval_results hr + LEFT JOIN heurval_file_results hfr ON hfr.heurval_result_id = hr.id + WHERE hr.defense_submission_id = :sub_id + GROUP BY hr.sample_set_id + LIMIT 1 + """ + ), + {"sub_id": sub_id}, + ) + .mappings() + .fetchone() + ) + if hv_row: + heurval_done = int(hv_row["done"]) + heurval_total = int(hv_row["total"]) + if sub_row: + submission = JobDetailSubmission( + submission_id=str(sub_row["id"]), + version=sub_row["version"], + display_name=sub_row["display_name"], + status=sub_row["status"], + source_type=sub_row["source_type"], + heurval_done=heurval_done, + heurval_total=heurval_total, + ) + run_rows = ( + db.execute( + text( + """ + SELECT er.id, + er.attack_submission_id, + er.status, + er.duration_ms, + COUNT(efr.id) AS files_done, + (SELECT COUNT(*) FROM attack_files af + WHERE af.attack_submission_id = er.attack_submission_id) AS files_total + FROM evaluation_runs er + LEFT JOIN evaluation_file_results efr ON efr.evaluation_run_id = er.id + WHERE er.defense_submission_id = :id + GROUP BY er.id, er.attack_submission_id, er.status, er.duration_ms + ORDER BY er.created_at DESC + LIMIT 10 + """ + ), + {"id": sub_id}, + ) + .mappings() + .fetchall() + ) + eval_runs = [ + JobDetailEvalRun( + id=str(r["id"]), + counterpart_id=str(r["attack_submission_id"]), + status=r["status"], + duration_ms=r["duration_ms"], + files_done=int(r["files_done"]), + files_total=int(r["files_total"]), + ) + for r in run_rows + ] + + elif job.job_type == "A": + sub_id = payload.get("attack_submission_id") + if sub_id: + sub_row = ( + db.execute( + text( + """ + SELECT s.id, s.version, s.display_name, s.status, + a.file_count + FROM submissions s + LEFT JOIN attack_submission_details a ON a.submission_id = s.id + WHERE s.id = :id + """ + ), + {"id": sub_id}, + ) + .mappings() + .fetchone() + ) + if sub_row: + submission = JobDetailSubmission( + submission_id=str(sub_row["id"]), + version=sub_row["version"], + display_name=sub_row["display_name"], + status=sub_row["status"], + file_count=sub_row["file_count"], + ) + run_rows = ( + db.execute( + text( + """ + SELECT er.id, + er.defense_submission_id, + er.status, + er.duration_ms, + COUNT(efr.id) AS files_done, + (SELECT COUNT(*) FROM attack_files af + WHERE af.attack_submission_id = er.attack_submission_id) AS files_total + FROM evaluation_runs er + LEFT JOIN evaluation_file_results efr ON efr.evaluation_run_id = er.id + WHERE er.attack_submission_id = :id + GROUP BY er.id, er.defense_submission_id, er.status, er.duration_ms + ORDER BY er.created_at DESC + LIMIT 10 + """ + ), + {"id": sub_id}, + ) + .mappings() + .fetchall() + ) + eval_runs = [ + JobDetailEvalRun( + id=str(r["id"]), + counterpart_id=str(r["defense_submission_id"]), + status=r["status"], + duration_ms=r["duration_ms"], + files_done=int(r["files_done"]), + files_total=int(r["files_total"]), + ) + for r in run_rows + ] + + return JobDetailResponse(job=job, submission=submission, evaluation_runs=eval_runs) + + @router.get("/logs/evaluations", response_model=AdminEvaluationLogsResponse) def get_recent_evaluations( _: AuthenticatedUser = Depends(require_admin_user), @@ -1521,7 +1711,7 @@ def activate_submission( sub_row = db.execute( text(""" - SELECT id, user_id, submission_type + SELECT id, user_id, submission_type, status FROM submissions WHERE id = CAST(:sid AS uuid) AND deleted_at IS NULL """), @@ -1530,6 +1720,13 @@ def activate_submission( if sub_row is None: raise HTTPException(status_code=404, detail="Submission not found") + status: str = sub_row["status"] + if status not in ("validated", "evaluated"): + raise HTTPException( + status_code=409, + detail="Submission must be validated or evaluated before it can be set as active", + ) + user_id: str = str(sub_row["user_id"]) submission_type: str = sub_row["submission_type"] @@ -1560,6 +1757,27 @@ def activate_submission( ) db.commit() + # Setting to active mimics an initial submission + from routers.queue import _insert_job, _publish_task + from schemas.jobs import JobType + + if submission_type == "defense": + j_type = JobType.DEFENSE + payload = {"defense_submission_id": submission_id} + else: + j_type = JobType.ATTACK + payload = {"attack_submission_id": submission_id} + + job_id = _insert_job( + db=db, + job_type=j_type.value, + payload=payload, + requested_by_user_id=current_user.user_id, + ) + _publish_task(job_type=j_type, job_id=job_id, payload=payload) + + logger.info(f"Enqueued {submission_type} job {job_id} after admin set active") + client_ip, user_agent = _request_meta(request) log_audit_event( event_type=event_type, @@ -1599,3 +1817,235 @@ def activate_submission( message=str(exc), ) raise + + +# --------------------------------------------------------------------------- +# CSV exports +# --------------------------------------------------------------------------- + +def _csv_response(rows: list[list], filename: str) -> StreamingResponse: + """Build a StreamingResponse from a 2-D list of CSV rows.""" + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerows(rows) + buf.seek(0) + return StreamingResponse( + iter([buf.getvalue()]), + media_type="text/csv", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + + +def _submission_label(username: str, display_name: str | None, version: str) -> str: + return f"{username} / {display_name or version}" + + +@router.get("/export/scores/all") +def export_all_evaluation_scores( + _: AuthenticatedUser = Depends(require_admin_user), + db: Session = Depends(get_db), +) -> StreamingResponse: + """Download confusion-matrix CSV (TP, FP, FN, TN) for all active submission pairs.""" + axis_rows = db.execute( + text(""" + SELECT u.username, s.display_name, s.version, s.id::text, a.submission_type + FROM active_submissions a + JOIN submissions s ON s.id = a.submission_id + JOIN users u ON u.id = a.user_id + WHERE u.disabled_at IS NULL AND s.deleted_at IS NULL + ORDER BY a.submission_type, u.username + """) + ).fetchall() + + attackers = [r for r in axis_rows if r[4] == "attack"] + defenders = [r for r in axis_rows if r[4] == "defense"] + + if not attackers or not defenders: + return _csv_response( + [["No active submission pairs available."]], + "evaluation_scores_all.csv", + ) + + attack_ids = [r[3] for r in attackers] + defense_ids = [r[3] for r in defenders] + + file_rows = db.execute( + text(""" + SELECT eps.defense_submission_id::text, + eps.attack_submission_id::text, + af.is_malware, + efr.model_output + FROM evaluation_pair_scores eps + JOIN evaluation_runs er ON er.id = eps.latest_evaluation_run_id + JOIN evaluation_file_results efr ON efr.evaluation_run_id = er.id + JOIN attack_files af ON af.id = efr.attack_file_id + WHERE eps.defense_submission_id::text = ANY(:def_ids) + AND eps.attack_submission_id::text = ANY(:atk_ids) + AND eps.latest_evaluation_run_id IS NOT NULL + AND efr.model_output IS NOT NULL + AND af.is_malware IS NOT NULL + """), + {"def_ids": defense_ids, "atk_ids": attack_ids}, + ).fetchall() + + confusion: dict[tuple[str, str], dict[str, int]] = {} + for fr in file_rows: + key = (fr[0], fr[1]) + if key not in confusion: + confusion[key] = {"tp": 0, "fp": 0, "fn": 0, "tn": 0} + if fr[3] == 1 and fr[2]: confusion[key]["tp"] += 1 + elif fr[3] == 1 and not fr[2]: confusion[key]["fp"] += 1 + elif fr[3] == 0 and fr[2]: confusion[key]["fn"] += 1 + elif fr[3] == 0 and not fr[2]: confusion[key]["tn"] += 1 + + header = ["Defense \\ Attack"] + [_submission_label(r[0], r[1], r[2]) for r in attackers] + data_rows: list[list] = [header] + for d in defenders: + did = d[3] + cells = [] + for a in attackers: + c = confusion.get((did, a[3])) + cells.append(f"({c['tp']},{c['fp']},{c['fn']},{c['tn']})" if c else "") + data_rows.append([_submission_label(d[0], d[1], d[2])] + cells) + + return _csv_response(data_rows, "evaluation_scores_all.csv") + + +@router.get("/export/scores/individual") +def export_individual_evaluation_scores( + defense_submission_id: UUID = Query(...), + attack_submission_id: UUID = Query(...), + _: AuthenticatedUser = Depends(require_admin_user), + db: Session = Depends(get_db), +) -> StreamingResponse: + """Download per-file model output for a single defense/attack pair.""" + rows = db.execute( + text(""" + SELECT af.filename, efr.model_output, af.is_malware + FROM evaluation_pair_scores eps + JOIN evaluation_runs er ON er.id = eps.latest_evaluation_run_id + JOIN evaluation_file_results efr ON efr.evaluation_run_id = er.id + JOIN attack_files af ON af.id = efr.attack_file_id + WHERE eps.defense_submission_id = :def_id + AND eps.attack_submission_id = :atk_id + AND eps.latest_evaluation_run_id IS NOT NULL + ORDER BY af.filename + """), + {"def_id": str(defense_submission_id), "atk_id": str(attack_submission_id)}, + ).fetchall() + + if not rows: + return _csv_response( + [["No evaluation data found for this pair."]], + "evaluation_scores_individual.csv", + ) + + filenames = [r[0] or "unknown" for r in rows] + outputs = [str(r[1]) if r[1] is not None else "" for r in rows] + ground = [str(r[2]) if r[2] is not None else "" for r in rows] + + return _csv_response( + [ + [""] + filenames, + ["Model Output"] + outputs, + ["Is Malware (ground truth)"] + ground, + ], + "evaluation_scores_individual.csv", + ) + + +@router.get("/export/validation-scores") +def export_validation_scores( + _: AuthenticatedUser = Depends(require_admin_user), + db: Session = Depends(get_db), +) -> StreamingResponse: + """Download heuristic-validation result grid: defenses vs validation samples.""" + rows = db.execute( + text(""" + SELECT u.username, s.display_name, s.version, s.id::text, + hs.filename, hfr.model_output + FROM heurval_results hr + JOIN submissions s ON s.id = hr.defense_submission_id + JOIN users u ON u.id = s.user_id + JOIN heurval_file_results hfr ON hfr.heurval_result_id = hr.id + JOIN heurval_samples hs ON hs.id = hfr.sample_id + WHERE s.deleted_at IS NULL + ORDER BY s.created_at, hs.filename + """) + ).fetchall() + + if not rows: + return _csv_response([["No validation scores available."]], "validation_scores.csv") + + sub_order: list[str] = [] + sub_labels: dict[str, str] = {} + cells: dict[str, dict[str, int | None]] = {} + all_files: set[str] = set() + + for r in rows: + sid = r[3] + if sid not in sub_labels: + sub_order.append(sid) + sub_labels[sid] = _submission_label(r[0], r[1], r[2]) + cells[sid] = {} + cells[sid][r[4]] = r[5] + all_files.add(r[4]) + + sorted_files = sorted(all_files) + data_rows: list[list] = [["Defense"] + sorted_files] + for sid in sub_order: + row_cells = [str(cells[sid].get(f, "")) for f in sorted_files] + data_rows.append([sub_labels[sid]] + row_cells) + + return _csv_response(data_rows, "validation_scores.csv") + + +@router.get("/export/behavioral-analysis") +def export_behavioral_analysis( + _: AuthenticatedUser = Depends(require_admin_user), + db: Session = Depends(get_db), +) -> StreamingResponse: + """Download behavioral analysis status grid: attacks vs template files.""" + rows = db.execute( + text(""" + SELECT u.username, s.display_name, s.version, s.id::text, + orig.filename AS template_filename, + af.behavior_status + FROM submissions s + JOIN users u ON u.id = s.user_id + JOIN attack_files af ON af.attack_submission_id = s.id + JOIN attack_files orig ON orig.id = af.original_file_id + WHERE s.submission_type = 'attack' + AND s.deleted_at IS NULL + AND af.original_file_id IS NOT NULL + ORDER BY s.created_at, orig.filename + """) + ).fetchall() + + if not rows: + return _csv_response( + [["No behavioral analysis data available."]], + "behavioral_analysis.csv", + ) + + sub_order: list[str] = [] + sub_labels: dict[str, str] = {} + cells: dict[str, dict[str, str]] = {} + all_files: set[str] = set() + + for r in rows: + sid = r[3] + if sid not in sub_labels: + sub_order.append(sid) + sub_labels[sid] = _submission_label(r[0], r[1], r[2]) + cells[sid] = {} + cells[sid][r[4]] = r[5] or "" + all_files.add(r[4]) + + sorted_files = sorted(all_files) + data_rows: list[list] = [["Attack"] + sorted_files] + for sid in sub_order: + row_cells = [cells[sid].get(f, "") for f in sorted_files] + data_rows.append([sub_labels[sid]] + row_cells) + + return _csv_response(data_rows, "behavioral_analysis.csv") diff --git a/services/api/routers/submissions.py b/services/api/routers/submissions.py index 60df912..b1d0378 100644 --- a/services/api/routers/submissions.py +++ b/services/api/routers/submissions.py @@ -24,13 +24,15 @@ validate_github_url_format, validate_semver_format, ) -from core.submission_control import ensure_submissions_open +from core.config import get_config +from core.submission_control import check_cooldown, ensure_submissions_open, get_cooldown_remaining from routers.queue import _insert_job, _publish_task from schemas.jobs import JobType from schemas.submissions import ( CreateDefenseDockerRequest, CreateDefenseGitHubRequest, SetActiveResponse, + SubmissionDetailResponse, SubmissionListItem, SubmissionHistoryResponse, SubmissionResponse, @@ -106,6 +108,7 @@ def create_defense_docker( """ # Enforce admin-controlled submission window before accepting new work. ensure_submissions_open(db) + check_cooldown(db, user_id=str(current_user.user_id), submission_type='defense', cooldown_seconds=get_config().application.defense_submission_cooldown) # 1. Validate format validate_docker_image_format(req.docker_image) validate_semver_format(req.version) @@ -198,6 +201,7 @@ def create_defense_github( """ # Enforce admin-controlled submission window before accepting new work. ensure_submissions_open(db) + check_cooldown(db, user_id=str(current_user.user_id), submission_type='defense', cooldown_seconds=get_config().application.defense_submission_cooldown) # 1. Validate format validate_github_url_format(req.git_repo) validate_semver_format(req.version) @@ -292,6 +296,7 @@ async def create_defense_zip( """ # Enforce admin-controlled submission window before accepting new work. ensure_submissions_open(db) + check_cooldown(db, user_id=str(current_user.user_id), submission_type='defense', cooldown_seconds=get_config().application.defense_submission_cooldown) settings = get_settings() # 1. Validate file @@ -448,6 +453,7 @@ async def create_attack_zip( """ # Enforce admin-controlled submission window before accepting new work. ensure_submissions_open(db) + check_cooldown(db, user_id=str(current_user.user_id), submission_type='attack', cooldown_seconds=get_config().application.attack_submission_cooldown) settings = get_settings() # 1. Validate file @@ -720,7 +726,98 @@ def set_active_submission( except Exception: logger.warning("Failed to publish leaderboard update after active submission change") + # Setting to active mimics an initial submission + from routers.queue import _insert_job, _publish_task + from schemas.jobs import JobType + + if sub_type == "defense": + j_type = JobType.DEFENSE + payload = {"defense_submission_id": submission_id} + else: + j_type = JobType.ATTACK + payload = {"attack_submission_id": submission_id} + + job_id = _insert_job( + db=db, + job_type=j_type.value, + payload=payload, + requested_by_user_id=current_user.user_id, + ) + _publish_task(job_type=j_type, job_id=job_id, payload=payload) + + logger.info(f"Enqueued {sub_type} job {job_id} after setting active") + return SetActiveResponse( submission_id=submission_id, submission_type=sub_type, ) + + +@router.get("/{submission_id}/detail", response_model=SubmissionDetailResponse) +def get_submission_detail( + submission_id: str, + current_user: AuthenticatedUser = Depends(get_authenticated_user), + db: Session = Depends(get_db), +) -> SubmissionDetailResponse: + """Return source metadata for a submission belonging to the authenticated user.""" + row = db.execute( + text( + """ + SELECT + s.id, + s.submission_type, + s.created_at, + d.source_type, + d.sha256, + d.docker_image, + d.git_repo, + a.zip_sha256 + FROM submissions s + LEFT JOIN defense_submission_details d ON d.submission_id = s.id + LEFT JOIN attack_submission_details a ON a.submission_id = s.id + WHERE s.id = :id + AND s.user_id = :user_id + AND s.deleted_at IS NULL + """ + ), + {"id": submission_id, "user_id": str(current_user.user_id)}, + ).mappings().fetchone() + + if row is None: + raise HTTPException(status_code=404, detail="Submission not found") + + return SubmissionDetailResponse( + submission_id=str(row["id"]), + created_at=row["created_at"].isoformat(), + source_type=row["source_type"], + sha256=row["sha256"] or row["zip_sha256"], + docker_image=row["docker_image"], + git_repo=row["git_repo"], + ) + + +@router.get("/cooldown") +def get_submission_cooldown( + current_user: AuthenticatedUser = Depends(get_authenticated_user), + db: Session = Depends(get_db), +) -> dict: + """Return remaining cooldown seconds for each submission type. + + A null value means no cooldown is currently active. + """ + cfg = get_config() + user_id = str(current_user.user_id) + return { + "defense_remaining_seconds": get_cooldown_remaining( + db, + user_id=user_id, + submission_type="defense", + cooldown_seconds=cfg.application.defense_submission_cooldown, + ), + "attack_remaining_seconds": get_cooldown_remaining( + db, + user_id=user_id, + submission_type="attack", + cooldown_seconds=cfg.application.attack_submission_cooldown, + ), + } diff --git a/services/api/schemas/admin.py b/services/api/schemas/admin.py index 07d7a5a..ea76a00 100644 --- a/services/api/schemas/admin.py +++ b/services/api/schemas/admin.py @@ -38,6 +38,32 @@ class AdminJobLogsResponse(BaseModel): items: list[AdminJobLogRecord] +class JobDetailSubmission(BaseModel): + submission_id: str + version: str + display_name: str | None + status: str + source_type: str | None = None + file_count: int | None = None + heurval_done: int | None = None + heurval_total: int | None = None + + +class JobDetailEvalRun(BaseModel): + id: str + counterpart_id: str + status: str | None + duration_ms: int | None + files_done: int | None = None + files_total: int | None = None + + +class JobDetailResponse(BaseModel): + job: AdminJobLogRecord + submission: JobDetailSubmission | None + evaluation_runs: list[JobDetailEvalRun] + + class AdminEvaluationLogRecord(BaseModel): id: UUID defense_submission_id: UUID diff --git a/services/api/schemas/submissions.py b/services/api/schemas/submissions.py index ea736a2..964cb17 100644 --- a/services/api/schemas/submissions.py +++ b/services/api/schemas/submissions.py @@ -11,7 +11,7 @@ class CreateDefenseDockerRequest(BaseModel): """Request schema for Docker Hub defense submission.""" - docker_image: str = Field(..., min_length=1, max_length=500) + docker_image: str = Field(..., pattern=r"^[a-zA-Z0-9][a-zA-Z0-9._\-/]*(:[a-zA-Z0-9._\-]+)?$", max_length=500) version: str = Field(..., pattern=r"^\d+\.\d+\.\d+$") display_name: str | None = Field(None, max_length=200) @@ -25,17 +25,20 @@ def validate_docker_format(cls, v: str) -> str: class CreateDefenseGitHubRequest(BaseModel): """Request schema for GitHub repository defense submission.""" - git_repo: str = Field(..., pattern=r"^https://github\.com/[\w-]+/[\w-]+") + git_repo: str = Field(..., max_length=500) version: str = Field(..., pattern=r"^\d+\.\d+\.\d+$") display_name: str | None = Field(None, max_length=200) @field_validator("git_repo") @classmethod def validate_github_url(cls, v: str) -> str: - """Strip whitespace and .git suffix from GitHub URL.""" + """Strip whitespace, remove .git suffix, and enforce GitHub URL format.""" + import re v = v.strip() if v.endswith(".git"): v = v[:-4] + if not re.match(r"^https://github\.com/[\w-]+/[\w-]+(/tree/(?!.*\.\.)([\w.\-/]+))?$", v): + raise ValueError("Provided path is incorrectly formatted") return v @@ -113,3 +116,14 @@ class SubmissionHistoryResponse(BaseModel): total: int limit: int offset: int + + +class SubmissionDetailResponse(BaseModel): + """Source metadata for a single submission, fetched on demand.""" + + submission_id: str + created_at: str + source_type: str | None + sha256: str | None + docker_image: str | None + git_repo: str | None diff --git a/services/api/tests/conftest.py b/services/api/tests/conftest.py index eba47fb..2137f05 100644 --- a/services/api/tests/conftest.py +++ b/services/api/tests/conftest.py @@ -5,11 +5,17 @@ from sqlalchemy.orm import sessionmaker from fastapi.testclient import TestClient +import os +os.environ["CELERY_BROKER_URL"] = "memory://" +os.environ["CELERY_DEFAULT_QUEUE"] = "mlsec" + from main import app from core.database import Base, get_db +from core.celery_app import get_celery +from unittest.mock import MagicMock -TEST_DB_URL = "postgresql://postgres:password123@localhost:5433/mlsec_test" +TEST_DB_URL = os.getenv("DATABASE_URL", "postgresql://mlsec2:mlsec2_pw@localhost:5433/mlsec_test") engine = create_engine(TEST_DB_URL) TestingSessionLocal = sessionmaker(bind=engine) @@ -94,6 +100,14 @@ def override_get_db(): app.dependency_overrides.pop(get_db, None) +@pytest.fixture(autouse=True) +def mock_celery(monkeypatch): + """Mock Celery to prevent physical broker connections during tests.""" + mock_app = MagicMock() + monkeypatch.setattr("core.celery_app.get_celery", lambda: mock_app) + return mock_app + + @pytest.fixture() def fake_redis(): """Provide fake Redis client for testing without external dependency.""" diff --git a/services/api/tests/test_audit.py b/services/api/tests/test_audit.py new file mode 100644 index 0000000..7482626 --- /dev/null +++ b/services/api/tests/test_audit.py @@ -0,0 +1,119 @@ +"""Tests for core/audit.py.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, call +from uuid import UUID, uuid4 + +import pytest + +import core.audit as audit_module +from core.audit import log_audit_event + + +def _make_engine_mock(): + mock_conn = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__enter__ = MagicMock(return_value=mock_conn) + mock_ctx.__exit__ = MagicMock(return_value=False) + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_ctx + return mock_engine, mock_conn + + +def test_log_audit_event_minimal(monkeypatch): + mock_engine, mock_conn = _make_engine_mock() + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + log_audit_event(event_type="test.event") + + mock_conn.execute.assert_called_once() + + +def test_log_audit_event_all_fields(monkeypatch): + mock_engine, mock_conn = _make_engine_mock() + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + user_id = uuid4() + log_audit_event( + event_type="auth.login", + user_id=user_id, + email="user@example.com", + ip_address="127.0.0.1", + user_agent="TestAgent/1.0", + success=True, + message="Login successful", + metadata={"key": "value", "count": 42}, + ) + + mock_conn.execute.assert_called_once() + _, params = mock_conn.execute.call_args[0] + assert params["event_type"] == "auth.login" + assert params["user_id"] == str(user_id) + assert params["email"] == "user@example.com" + assert params["ip_address"] == "127.0.0.1" + assert params["user_agent"] == "TestAgent/1.0" + assert params["success"] is True + assert params["message"] == "Login successful" + assert '"key": "value"' in params["metadata"] + assert '"count": 42' in params["metadata"] + + +def test_log_audit_event_none_user_id(monkeypatch): + mock_engine, mock_conn = _make_engine_mock() + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + log_audit_event(event_type="anon.event", user_id=None) + + _, params = mock_conn.execute.call_args[0] + assert params["user_id"] is None + + +def test_log_audit_event_none_metadata_serializes_as_none(monkeypatch): + mock_engine, mock_conn = _make_engine_mock() + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + log_audit_event(event_type="test.event", metadata=None) + + _, params = mock_conn.execute.call_args[0] + assert params["metadata"] is None + + +def test_log_audit_event_swallows_engine_exception(monkeypatch): + def _bad_engine(): + raise RuntimeError("DB is down") + + monkeypatch.setattr(audit_module, "get_engine", _bad_engine) + + result = log_audit_event(event_type="test.event", message="should not crash") + + assert result is None + + +def test_log_audit_event_swallows_execute_exception(monkeypatch): + mock_conn = MagicMock() + mock_conn.execute.side_effect = Exception("execute failed") + mock_ctx = MagicMock() + mock_ctx.__enter__ = MagicMock(return_value=mock_conn) + mock_ctx.__exit__ = MagicMock(return_value=False) + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_ctx + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + result = log_audit_event(event_type="test.event") + + assert result is None + + +def test_log_audit_event_metadata_dict_is_json_serialized(monkeypatch): + mock_engine, mock_conn = _make_engine_mock() + monkeypatch.setattr(audit_module, "get_engine", lambda: mock_engine) + + metadata = {"action": "disable", "target_user": "abc-123"} + log_audit_event(event_type="admin.user_disabled", metadata=metadata) + + _, params = mock_conn.execute.call_args[0] + import json + parsed = json.loads(params["metadata"]) + assert parsed["action"] == "disable" + assert parsed["target_user"] == "abc-123" diff --git a/services/api/tests/test_config.py b/services/api/tests/test_config.py new file mode 100644 index 0000000..237e3fb --- /dev/null +++ b/services/api/tests/test_config.py @@ -0,0 +1,159 @@ +"""Tests for core/config.py.""" + +from __future__ import annotations + +import pytest + +import core.config as config_module +from core.config import AppConfig, get_config + + +@pytest.fixture(autouse=True) +def clear_cache(): + config_module.get_config.cache_clear() + yield + config_module.get_config.cache_clear() + + +def test_get_config_no_file_returns_defaults(monkeypatch, tmp_path): + missing = tmp_path / "nonexistent.yaml" + monkeypatch.setattr(config_module, "Path", lambda p: missing) + + config = get_config() + + assert config.minio.endpoint == "minio:9000" + assert config.minio.access_key == "mlsec2" + assert config.minio.secret_key == "mlsec2_pw" + assert config.minio.bucket_name == "mlsec-submissions" + assert config.minio.secure is False + assert config.application.join_code is None + assert config.application.defense_submission_cooldown == 0 + assert config.application.attack_submission_cooldown == 0 + + +def test_get_config_valid_yaml(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "worker:\n" + " minio:\n" + " endpoint: custom-minio:9001\n" + " access_key: mykey\n" + " secret_key: mysecret\n" + " bucket_name: my-bucket\n" + " secure: true\n" + "application:\n" + " join_code: secret123\n" + " defense_submission_cooldown: 300\n" + " attack_submission_cooldown: 600\n" + ) + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert config.minio.endpoint == "custom-minio:9001" + assert config.minio.access_key == "mykey" + assert config.minio.secret_key == "mysecret" + assert config.minio.bucket_name == "my-bucket" + assert config.minio.secure is True + assert config.application.join_code == "secret123" + assert config.application.defense_submission_cooldown == 300 + assert config.application.attack_submission_cooldown == 600 + + +def test_get_config_login_code_fallback(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "worker:\n" + " minio: {}\n" + "application:\n" + " login_code: fallback_code\n" + ) + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert config.application.join_code == "fallback_code" + + +def test_get_config_join_code_takes_precedence_over_login_code(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "worker:\n" + " minio: {}\n" + "application:\n" + " join_code: primary\n" + " login_code: fallback\n" + ) + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert config.application.join_code == "primary" + + +def test_get_config_malformed_yaml_returns_defaults(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("[invalid yaml {{{") + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert isinstance(config, AppConfig) + assert config.minio.endpoint == "minio:9000" + assert config.application.defense_submission_cooldown == 0 + + +def test_get_config_empty_yaml_returns_defaults(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("") + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert isinstance(config, AppConfig) + assert config.minio.endpoint == "minio:9000" + assert config.application.join_code is None + + +def test_get_config_partial_application_section(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "worker:\n" + " minio:\n" + " endpoint: minio:9000\n" + "application:\n" + " defense_submission_cooldown: 120\n" + ) + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert config.application.defense_submission_cooldown == 120 + assert config.application.attack_submission_cooldown == 0 + assert config.application.join_code is None + + +def test_get_config_no_application_section(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text( + "worker:\n" + " minio:\n" + " endpoint: custom:9000\n" + ) + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + config = get_config() + + assert config.application.join_code is None + assert config.application.defense_submission_cooldown == 0 + + +def test_get_config_is_cached(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("") + monkeypatch.setattr(config_module, "Path", lambda p: config_file) + + first = get_config() + second = get_config() + + assert first is second diff --git a/services/api/tests/test_core_admin.py b/services/api/tests/test_core_admin.py new file mode 100644 index 0000000..ecec69a --- /dev/null +++ b/services/api/tests/test_core_admin.py @@ -0,0 +1,390 @@ +"""Tests for core/admin.py.""" + +from __future__ import annotations + +import hashlib +import secrets +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from fastapi import HTTPException +from sqlalchemy import text + +from core.admin import ( + _hosts_match, + _is_from_trusted_proxy, + _is_in_allowed_hosts, + _is_in_allowed_networks, + _is_loopback_host, + consume_admin_action_token, + issue_admin_action_token, + require_admin_origin, + require_localhost_request, + require_admin_action_token, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeSettings: + admin_localhost_only = True + admin_trusted_proxy_hosts = [] + admin_forwarded_for_header = "x-forwarded-for" + admin_allowed_hosts = [] + admin_allowed_networks = [] + admin_action_token_ttl_minutes = 5 + + +def _make_request(host="127.0.0.1", headers=None, client=True): + if client: + client_obj = SimpleNamespace(host=host) + else: + client_obj = None + return SimpleNamespace(client=client_obj, headers=headers or {}) + + +def _create_user_and_session(db_session): + uid = uuid4().hex[:8] + user_row = db_session.execute( + text( + "INSERT INTO users (username, email, is_admin) " + "VALUES (:username, :email, true) RETURNING id" + ), + {"username": f"admin_{uid}", "email": f"admin_{uid}@test.com"}, + ).fetchone() + user_id = str(user_row[0]) + + token = f"test-session-{uuid4()}" + token_hash = hashlib.sha256(token.encode()).hexdigest() + now = datetime.now(timezone.utc) + session_row = db_session.execute( + text( + "INSERT INTO user_sessions (user_id, token_hash, expires_at, last_seen_at) " + "VALUES (:user_id, :token_hash, :expires_at, :last_seen_at) RETURNING id" + ), + { + "user_id": user_id, + "token_hash": token_hash, + "expires_at": now + timedelta(hours=2), + "last_seen_at": now, + }, + ).fetchone() + session_id = str(session_row[0]) + db_session.flush() + return user_id, session_id + + +# --------------------------------------------------------------------------- +# _is_loopback_host +# --------------------------------------------------------------------------- + + +class TestIsLoopbackHost: + def test_localhost_string(self): + assert _is_loopback_host("localhost") is True + + def test_ipv4_loopback(self): + assert _is_loopback_host("127.0.0.1") is True + + def test_ipv4_loopback_alt(self): + assert _is_loopback_host("127.0.0.2") is True + + def test_ipv6_loopback(self): + assert _is_loopback_host("::1") is True + + def test_private_ip_is_not_loopback(self): + assert _is_loopback_host("192.168.1.1") is False + + def test_public_ip_is_not_loopback(self): + assert _is_loopback_host("8.8.8.8") is False + + def test_none_returns_false(self): + assert _is_loopback_host(None) is False + + def test_invalid_string_returns_false(self): + assert _is_loopback_host("not-an-ip") is False + + def test_ipv4_mapped_loopback(self): + assert _is_loopback_host("::ffff:127.0.0.1") is True + + +# --------------------------------------------------------------------------- +# _hosts_match +# --------------------------------------------------------------------------- + + +class TestHostsMatch: + def test_identical_strings(self): + assert _hosts_match("localhost", "localhost") is True + + def test_case_insensitive(self): + assert _hosts_match("LOCALHOST", "localhost") is True + + def test_same_ip_different_format(self): + assert _hosts_match("127.0.0.1", "127.0.0.1") is True + + def test_different_hosts(self): + assert _hosts_match("192.168.1.1", "192.168.1.2") is False + + def test_hostname_vs_ip(self): + assert _hosts_match("example.com", "127.0.0.1") is False + + +# --------------------------------------------------------------------------- +# _is_from_trusted_proxy +# --------------------------------------------------------------------------- + + +class TestIsFromTrustedProxy: + def test_exact_match(self): + assert _is_from_trusted_proxy("10.0.0.1", ["10.0.0.1"]) is True + + def test_cidr_match(self): + assert _is_from_trusted_proxy("10.0.0.5", ["10.0.0.0/24"]) is True + + def test_no_match(self): + assert _is_from_trusted_proxy("172.16.0.1", ["10.0.0.0/8"]) is False + + def test_none_host_returns_false(self): + assert _is_from_trusted_proxy(None, ["10.0.0.1"]) is False + + def test_empty_list_returns_false(self): + assert _is_from_trusted_proxy("127.0.0.1", []) is False + + +# --------------------------------------------------------------------------- +# _is_in_allowed_hosts +# --------------------------------------------------------------------------- + + +class TestIsInAllowedHosts: + def test_matching_host(self): + assert _is_in_allowed_hosts("10.1.2.3", ["10.1.2.3"]) is True + + def test_non_matching_host(self): + assert _is_in_allowed_hosts("10.1.2.4", ["10.1.2.3"]) is False + + def test_none_host_returns_false(self): + assert _is_in_allowed_hosts(None, ["10.1.2.3"]) is False + + def test_empty_list_returns_false(self): + assert _is_in_allowed_hosts("10.1.2.3", []) is False + + +# --------------------------------------------------------------------------- +# _is_in_allowed_networks +# --------------------------------------------------------------------------- + + +class TestIsInAllowedNetworks: + def test_ip_in_cidr(self): + assert _is_in_allowed_networks("10.0.0.50", ["10.0.0.0/24"]) is True + + def test_ip_outside_cidr(self): + assert _is_in_allowed_networks("10.0.1.1", ["10.0.0.0/24"]) is False + + def test_none_host_returns_false(self): + assert _is_in_allowed_networks(None, ["10.0.0.0/24"]) is False + + def test_invalid_network_skipped(self): + assert _is_in_allowed_networks("10.0.0.1", ["not-a-network"]) is False + + +# --------------------------------------------------------------------------- +# require_localhost_request +# --------------------------------------------------------------------------- + + +class TestRequireLocalhostRequest: + def test_loopback_host_allowed(self, monkeypatch): + monkeypatch.setattr("core.admin.get_settings", _FakeSettings) + request = _make_request(host="127.0.0.1") + require_localhost_request(request) + + def test_non_loopback_raises_403(self, monkeypatch): + monkeypatch.setattr("core.admin.get_settings", _FakeSettings) + request = _make_request(host="192.168.1.100") + with pytest.raises(HTTPException) as exc_info: + require_localhost_request(request) + assert exc_info.value.status_code == 403 + + def test_non_loopback_in_allowed_hosts_passes(self, monkeypatch): + class _Settings(_FakeSettings): + admin_allowed_hosts = ["10.0.0.5"] + + monkeypatch.setattr("core.admin.get_settings", _Settings) + request = _make_request(host="10.0.0.5") + require_localhost_request(request) + + def test_non_loopback_in_allowed_networks_passes(self, monkeypatch): + class _Settings(_FakeSettings): + admin_allowed_networks = ["10.0.0.0/24"] + + monkeypatch.setattr("core.admin.get_settings", _Settings) + request = _make_request(host="10.0.0.42") + require_localhost_request(request) + + def test_localhost_only_false_skips_check(self, monkeypatch): + class _Settings(_FakeSettings): + admin_localhost_only = False + + monkeypatch.setattr("core.admin.get_settings", _Settings) + request = _make_request(host="8.8.8.8") + require_localhost_request(request) + + def test_ipv6_loopback_allowed(self, monkeypatch): + monkeypatch.setattr("core.admin.get_settings", _FakeSettings) + request = _make_request(host="::1") + require_localhost_request(request) + + +# --------------------------------------------------------------------------- +# require_admin_origin +# --------------------------------------------------------------------------- + + +class TestRequireAdminOrigin: + def test_no_origin_or_referer_with_require_present_raises(self): + request = _make_request(headers={}) + with pytest.raises(HTTPException) as exc_info: + require_admin_origin(request, require_present=True) + assert exc_info.value.status_code == 403 + + def test_no_origin_or_referer_with_require_present_false_passes(self): + request = _make_request(headers={}) + require_admin_origin(request, require_present=False) + + def test_localhost_origin_passes(self): + request = _make_request(headers={"origin": "http://localhost:4321"}) + require_admin_origin(request) + + def test_non_localhost_origin_raises(self): + request = _make_request(headers={"origin": "https://evil.com"}) + with pytest.raises(HTTPException) as exc_info: + require_admin_origin(request) + assert exc_info.value.status_code == 403 + + def test_localhost_referer_passes(self): + request = _make_request(headers={"referer": "http://localhost/admin"}) + require_admin_origin(request) + + def test_non_localhost_referer_raises(self): + request = _make_request(headers={"referer": "https://evil.com/path"}) + with pytest.raises(HTTPException) as exc_info: + require_admin_origin(request) + assert exc_info.value.status_code == 403 + + def test_127_origin_passes(self): + request = _make_request(headers={"origin": "http://127.0.0.1:3000"}) + require_admin_origin(request) + + +# --------------------------------------------------------------------------- +# issue_admin_action_token / require_admin_action_token / consume_admin_action_token +# --------------------------------------------------------------------------- + + +class TestAdminActionToken: + def test_issue_returns_token_and_expiry(self, db_session): + _, session_id = _create_user_and_session(db_session) + + token, expires_at = issue_admin_action_token(db_session, session_id=session_id) + + assert isinstance(token, str) + assert len(token) > 0 + assert isinstance(expires_at, datetime) + assert expires_at > datetime.now(timezone.utc) + + def test_issued_token_is_hashed_in_db(self, db_session): + _, session_id = _create_user_and_session(db_session) + + token, _ = issue_admin_action_token(db_session, session_id=session_id) + + expected_hash = hashlib.sha256(token.encode()).hexdigest() + row = db_session.execute( + text( + "SELECT token_hash FROM admin_action_tokens WHERE session_id = :session_id" + ), + {"session_id": session_id}, + ).fetchone() + assert row is not None + assert row[0] == expected_hash + + def test_require_valid_token_returns_token(self, db_session): + _, session_id = _create_user_and_session(db_session) + token, _ = issue_admin_action_token(db_session, session_id=session_id) + request = _make_request(headers={"x-admin-action": token}) + + result = require_admin_action_token(request, db=db_session, session_id=session_id) + + assert result == token + + def test_require_missing_header_raises_403(self, db_session): + _, session_id = _create_user_and_session(db_session) + request = _make_request(headers={}) + + with pytest.raises(HTTPException) as exc_info: + require_admin_action_token(request, db=db_session, session_id=session_id) + assert exc_info.value.status_code == 403 + + def test_require_wrong_token_raises_403(self, db_session): + _, session_id = _create_user_and_session(db_session) + issue_admin_action_token(db_session, session_id=session_id) + request = _make_request(headers={"x-admin-action": "wrong-token"}) + + with pytest.raises(HTTPException) as exc_info: + require_admin_action_token(request, db=db_session, session_id=session_id) + assert exc_info.value.status_code == 403 + + def test_require_expired_token_raises_403(self, db_session): + _, session_id = _create_user_and_session(db_session) + raw_token = secrets.token_urlsafe(32) + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() + past = datetime.now(timezone.utc) - timedelta(minutes=10) + db_session.execute( + text( + "INSERT INTO admin_action_tokens (session_id, token_hash, expires_at) " + "VALUES (:session_id, :token_hash, :expires_at)" + ), + {"session_id": session_id, "token_hash": token_hash, "expires_at": past}, + ) + db_session.flush() + request = _make_request(headers={"x-admin-action": raw_token}) + + with pytest.raises(HTTPException) as exc_info: + require_admin_action_token(request, db=db_session, session_id=session_id) + assert exc_info.value.status_code == 403 + + def test_consume_removes_token(self, db_session): + _, session_id = _create_user_and_session(db_session) + token, _ = issue_admin_action_token(db_session, session_id=session_id) + + consume_admin_action_token(db_session, session_id=session_id, token=token) + db_session.flush() + + row = db_session.execute( + text( + "SELECT token_hash FROM admin_action_tokens WHERE session_id = :session_id" + ), + {"session_id": session_id}, + ).fetchone() + assert row is None + + def test_issue_replaces_existing_token(self, db_session): + _, session_id = _create_user_and_session(db_session) + token1, _ = issue_admin_action_token(db_session, session_id=session_id) + token2, _ = issue_admin_action_token(db_session, session_id=session_id) + + count_row = db_session.execute( + text( + "SELECT COUNT(*) FROM admin_action_tokens WHERE session_id = :session_id" + ), + {"session_id": session_id}, + ).fetchone() + assert count_row[0] == 1 + assert token1 != token2 diff --git a/services/api/tests/test_storage.py b/services/api/tests/test_storage.py new file mode 100644 index 0000000..660c39e --- /dev/null +++ b/services/api/tests/test_storage.py @@ -0,0 +1,189 @@ +"""Tests for core/storage.py.""" + +from __future__ import annotations + +import hashlib +import io +from unittest.mock import MagicMock, patch + +import pytest +from minio.error import S3Error + +import core.storage as storage_module +from core.storage import ( + delete_object, + ensure_bucket_exists, + upload_attack_template, + upload_attack_zip, + upload_defense_zip, + upload_heurval_sample, + upload_heurval_set_zip, +) + + +def _make_s3_error(): + fake_response = MagicMock() + return S3Error(fake_response, "TestError", "something went wrong", "/", "req-1", "host-1") + + +@pytest.fixture(autouse=True) +def mock_storage(monkeypatch): + mock_client = MagicMock() + mock_config = MagicMock() + mock_config.minio.bucket_name = "test-bucket" + monkeypatch.setattr(storage_module, "get_minio_client", lambda: mock_client) + monkeypatch.setattr(storage_module, "get_config", lambda: mock_config) + return mock_client + + +class TestUploadDefenseZip: + def test_returns_correct_keys(self, mock_storage): + content = b"fake zip content" + user_id = "user-123" + submission_id = "sub-456" + + result = upload_defense_zip(io.BytesIO(content), user_id, submission_id) + + assert result["object_key"] == f"defense/{user_id}/{submission_id}.zip" + assert result["sha256"] == hashlib.sha256(content).hexdigest() + assert result["size_bytes"] == len(content) + + def test_calls_put_object_with_correct_args(self, mock_storage): + content = b"zip data" + upload_defense_zip(io.BytesIO(content), "uid", "sid") + + mock_storage.put_object.assert_called_once() + call_kwargs = mock_storage.put_object.call_args.kwargs + assert call_kwargs["bucket_name"] == "test-bucket" + assert call_kwargs["object_name"] == "defense/uid/sid.zip" + assert call_kwargs["length"] == len(content) + assert call_kwargs["content_type"] == "application/zip" + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.put_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + upload_defense_zip(io.BytesIO(b"data"), "uid", "sid") + + +class TestUploadAttackZip: + def test_returns_correct_keys(self, mock_storage): + content = b"attack zip" + user_id = "user-abc" + submission_id = "sub-xyz" + + result = upload_attack_zip(io.BytesIO(content), user_id, submission_id) + + assert result["object_key"] == f"attack/{user_id}/{submission_id}.zip" + assert result["sha256"] == hashlib.sha256(content).hexdigest() + assert result["size_bytes"] == len(content) + + def test_calls_put_object_with_correct_args(self, mock_storage): + content = b"atk data" + upload_attack_zip(io.BytesIO(content), "uid", "sid") + + call_kwargs = mock_storage.put_object.call_args.kwargs + assert call_kwargs["bucket_name"] == "test-bucket" + assert call_kwargs["object_name"] == "attack/uid/sid.zip" + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.put_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + upload_attack_zip(io.BytesIO(b"data"), "uid", "sid") + + +class TestUploadAttackTemplate: + def test_returns_correct_keys(self, mock_storage): + content = b"template zip content" + template_id = "tmpl-001" + + result = upload_attack_template(content, template_id) + + assert result["object_key"] == f"template/{template_id}.zip" + assert result["sha256"] == hashlib.sha256(content).hexdigest() + assert result["size_bytes"] == len(content) + + def test_sha256_matches_actual_hash(self, mock_storage): + content = b"hello world" + result = upload_attack_template(content, "t1") + expected = hashlib.sha256(content).hexdigest() + assert result["sha256"] == expected + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.put_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + upload_attack_template(b"data", "tmpl-id") + + +class TestUploadHeurvalSample: + def test_returns_correct_keys(self, mock_storage): + content = b"sample content" + set_id = "set-001" + label = "malware" + filename = "evil.exe" + + result = upload_heurval_sample(content, set_id, label, filename) + + assert result["object_key"] == f"heurval/{set_id}/{label}/{filename}" + assert result["sha256"] == hashlib.sha256(content).hexdigest() + assert result["size_bytes"] == len(content) + + def test_uses_basename_for_safety(self, mock_storage): + content = b"data" + result = upload_heurval_sample( + content, "s1", "goodware", "/some/nested/path/file.exe" + ) + assert result["object_key"] == "heurval/s1/goodware/file.exe" + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.put_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + upload_heurval_sample(b"data", "set", "malware", "file.exe") + + +class TestUploadHeurvalSetZip: + def test_returns_correct_keys(self, mock_storage): + content = b"heurval zip" + set_id = "hvset-001" + + result = upload_heurval_set_zip(content, set_id) + + assert result["object_key"] == f"heurval/{set_id}/samples.zip" + assert result["sha256"] == hashlib.sha256(content).hexdigest() + assert result["size_bytes"] == len(content) + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.put_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + upload_heurval_set_zip(b"data", "set-id") + + +class TestDeleteObject: + def test_calls_remove_object_with_correct_args(self, mock_storage): + object_key = "defense/user-1/sub-1.zip" + delete_object(object_key) + mock_storage.remove_object.assert_called_once_with( + bucket_name="test-bucket", + object_name=object_key, + ) + + def test_raises_s3_error_on_failure(self, mock_storage): + mock_storage.remove_object.side_effect = _make_s3_error() + with pytest.raises(S3Error): + delete_object("some/key.zip") + + +class TestEnsureBucketExists: + def test_creates_bucket_when_not_exists(self, mock_storage): + mock_storage.bucket_exists.return_value = False + ensure_bucket_exists() + mock_storage.make_bucket.assert_called_once_with("test-bucket") + + def test_skips_make_bucket_when_already_exists(self, mock_storage): + mock_storage.bucket_exists.return_value = True + ensure_bucket_exists() + mock_storage.make_bucket.assert_not_called() + + def test_raises_s3_error_when_bucket_check_fails(self, mock_storage): + mock_storage.bucket_exists.side_effect = _make_s3_error() + with pytest.raises(S3Error): + ensure_bucket_exists() diff --git a/services/api/tests/test_submission_control.py b/services/api/tests/test_submission_control.py new file mode 100644 index 0000000..e1d470e --- /dev/null +++ b/services/api/tests/test_submission_control.py @@ -0,0 +1,343 @@ +"""Tests for core/submission_control.py.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +import pytest +from fastapi import HTTPException +from sqlalchemy import text + +from core.submission_control import ( + SubmissionControl, + _as_utc, + check_cooldown, + ensure_submissions_open, + get_cooldown_remaining, + get_submission_control, + set_close_at, + set_manual_closed, +) + + +# --------------------------------------------------------------------------- +# Pure-Python unit tests (no DB) +# --------------------------------------------------------------------------- + + +class TestSubmissionControlIsClosed: + def test_open_by_default(self): + sc = SubmissionControl( + manual_closed=False, + close_at=None, + updated_at=None, + updated_by=None, + ) + assert sc.is_closed() is False + + def test_manual_closed_true(self): + sc = SubmissionControl( + manual_closed=True, + close_at=None, + updated_at=None, + updated_by=None, + ) + assert sc.is_closed() is True + + def test_close_at_in_past(self): + past = datetime.now(timezone.utc) - timedelta(minutes=5) + sc = SubmissionControl( + manual_closed=False, + close_at=past, + updated_at=None, + updated_by=None, + ) + assert sc.is_closed() is True + + def test_close_at_in_future(self): + future = datetime.now(timezone.utc) + timedelta(hours=1) + sc = SubmissionControl( + manual_closed=False, + close_at=future, + updated_at=None, + updated_by=None, + ) + assert sc.is_closed() is False + + def test_both_manual_and_past_close_at(self): + past = datetime.now(timezone.utc) - timedelta(seconds=1) + sc = SubmissionControl( + manual_closed=True, + close_at=past, + updated_at=None, + updated_by=None, + ) + assert sc.is_closed() is True + + def test_is_closed_accepts_explicit_now(self): + future = datetime.now(timezone.utc) + timedelta(hours=1) + sc = SubmissionControl( + manual_closed=False, + close_at=future, + updated_at=None, + updated_by=None, + ) + now_way_ahead = future + timedelta(hours=2) + assert sc.is_closed(now=now_way_ahead) is True + + +class TestAsUtc: + def test_none_returns_none(self): + assert _as_utc(None) is None + + def test_naive_datetime_gets_utc_tzinfo(self): + naive = datetime(2024, 6, 1, 12, 0, 0) + result = _as_utc(naive) + assert result.tzinfo is timezone.utc + assert result.replace(tzinfo=None) == naive + + def test_utc_datetime_unchanged(self): + utc_dt = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + result = _as_utc(utc_dt) + assert result == utc_dt + + def test_non_utc_aware_datetime_converted(self): + from datetime import timezone as tz + eastern = timezone(timedelta(hours=-5)) + dt = datetime(2024, 6, 1, 7, 0, 0, tzinfo=eastern) + result = _as_utc(dt) + assert result.tzinfo == timezone.utc + assert result.hour == 12 + + +# --------------------------------------------------------------------------- +# DB-backed integration tests +# --------------------------------------------------------------------------- + + +def _create_user(db_session, *, username: str = None, email: str = None) -> str: + username = username or f"sc_user_{uuid4().hex[:8]}" + email = email or f"{uuid4().hex[:8]}@test.com" + row = db_session.execute( + text( + "INSERT INTO users (username, email, is_admin) " + "VALUES (:username, :email, false) RETURNING id" + ), + {"username": username, "email": email}, + ).fetchone() + return str(row[0]) + + +class TestGetSubmissionControl: + def test_returns_defaults_when_no_row(self, db_session): + sc = get_submission_control(db_session) + assert sc.manual_closed is False + assert sc.close_at is None + + def test_returns_correct_state_when_row_exists(self, db_session): + future = datetime.now(timezone.utc) + timedelta(hours=2) + user_id = _create_user(db_session) + db_session.execute( + text( + "INSERT INTO submission_control (id, manual_closed, close_at, updated_by) " + "VALUES (1, true, :close_at, :updated_by) " + "ON CONFLICT (id) DO UPDATE " + "SET manual_closed = EXCLUDED.manual_closed, " + " close_at = EXCLUDED.close_at, " + " updated_by = EXCLUDED.updated_by" + ), + {"close_at": future, "updated_by": user_id}, + ) + db_session.flush() + + sc = get_submission_control(db_session) + + assert sc.manual_closed is True + assert sc.close_at is not None + assert sc.close_at > datetime.now(timezone.utc) + assert sc.updated_by == user_id + + +class TestSetManualClosed: + def test_set_manual_closed_true(self, db_session): + user_id = _create_user(db_session) + sc = set_manual_closed(db_session, closed=True, updated_by=user_id) + assert sc.manual_closed is True + assert sc.updated_by == user_id + + def test_set_manual_closed_false(self, db_session): + user_id = _create_user(db_session) + set_manual_closed(db_session, closed=True, updated_by=user_id) + sc = set_manual_closed(db_session, closed=False, updated_by=user_id) + assert sc.manual_closed is False + + def test_open_clears_lapsed_close_at(self, db_session): + user_id = _create_user(db_session) + past = datetime.now(timezone.utc) - timedelta(hours=1) + set_close_at(db_session, close_at=past, updated_by=user_id) + sc = set_manual_closed(db_session, closed=False, updated_by=user_id) + assert sc.close_at is None + + def test_open_preserves_future_close_at(self, db_session): + user_id = _create_user(db_session) + future = datetime.now(timezone.utc) + timedelta(hours=2) + set_close_at(db_session, close_at=future, updated_by=user_id) + sc = set_manual_closed(db_session, closed=False, updated_by=user_id) + assert sc.close_at is not None + + +class TestSetCloseAt: + def test_set_future_close_at(self, db_session): + user_id = _create_user(db_session) + future = datetime.now(timezone.utc) + timedelta(hours=3) + sc = set_close_at(db_session, close_at=future, updated_by=user_id) + assert sc.close_at is not None + assert sc.close_at > datetime.now(timezone.utc) + + def test_clear_close_at(self, db_session): + user_id = _create_user(db_session) + future = datetime.now(timezone.utc) + timedelta(hours=3) + set_close_at(db_session, close_at=future, updated_by=user_id) + sc = set_close_at(db_session, close_at=None, updated_by=user_id) + assert sc.close_at is None + + +class TestEnsureSubmissionsOpen: + def test_does_not_raise_when_open(self, db_session): + ensure_submissions_open(db_session) + + def test_raises_403_when_manual_closed(self, db_session): + user_id = _create_user(db_session) + set_manual_closed(db_session, closed=True, updated_by=user_id) + with pytest.raises(HTTPException) as exc_info: + ensure_submissions_open(db_session) + assert exc_info.value.status_code == 403 + assert "administrator" in exc_info.value.detail + + def test_raises_403_when_deadline_passed(self, db_session): + user_id = _create_user(db_session) + past = datetime.now(timezone.utc) - timedelta(minutes=1) + set_close_at(db_session, close_at=past, updated_by=user_id) + with pytest.raises(HTTPException) as exc_info: + ensure_submissions_open(db_session) + assert exc_info.value.status_code == 403 + assert "deadline" in exc_info.value.detail + + +class TestGetCooldownRemaining: + def test_returns_none_when_cooldown_zero(self, db_session): + user_id = _create_user(db_session) + result = get_cooldown_remaining( + db_session, + user_id=user_id, + submission_type="defense", + cooldown_seconds=0, + ) + assert result is None + + def test_returns_none_when_no_prior_submissions(self, db_session): + user_id = _create_user(db_session) + result = get_cooldown_remaining( + db_session, + user_id=user_id, + submission_type="defense", + cooldown_seconds=3600, + ) + assert result is None + + def test_returns_remaining_when_within_cooldown(self, db_session): + user_id = _create_user(db_session) + just_now = datetime.now(timezone.utc) - timedelta(seconds=10) + db_session.execute( + text( + "INSERT INTO submissions (user_id, submission_type, version, status, created_at) " + "VALUES (:user_id, 'defense', '1.0.0', 'submitted', :created_at)" + ), + {"user_id": user_id, "created_at": just_now}, + ) + db_session.flush() + + result = get_cooldown_remaining( + db_session, + user_id=user_id, + submission_type="defense", + cooldown_seconds=3600, + ) + assert result is not None + assert result > 0 + assert result <= 3600 + + def test_returns_none_when_cooldown_expired(self, db_session): + user_id = _create_user(db_session) + long_ago = datetime.now(timezone.utc) - timedelta(hours=2) + db_session.execute( + text( + "INSERT INTO submissions (user_id, submission_type, version, status, created_at) " + "VALUES (:user_id, 'defense', '1.0.0', 'submitted', :created_at)" + ), + {"user_id": user_id, "created_at": long_ago}, + ) + db_session.flush() + + result = get_cooldown_remaining( + db_session, + user_id=user_id, + submission_type="defense", + cooldown_seconds=60, + ) + assert result is None + + def test_ignores_deleted_submissions(self, db_session): + user_id = _create_user(db_session) + just_now = datetime.now(timezone.utc) - timedelta(seconds=10) + db_session.execute( + text( + "INSERT INTO submissions " + "(user_id, submission_type, version, status, created_at, deleted_at) " + "VALUES (:user_id, 'defense', '1.0.0', 'submitted', :created_at, :deleted_at)" + ), + {"user_id": user_id, "created_at": just_now, "deleted_at": just_now}, + ) + db_session.flush() + + result = get_cooldown_remaining( + db_session, + user_id=user_id, + submission_type="defense", + cooldown_seconds=3600, + ) + assert result is None + + +class TestCheckCooldown: + def test_does_not_raise_when_no_cooldown(self, db_session): + user_id = _create_user(db_session) + check_cooldown( + db_session, + user_id=user_id, + submission_type="attack", + cooldown_seconds=0, + ) + + def test_raises_429_when_within_cooldown(self, db_session): + user_id = _create_user(db_session) + just_now = datetime.now(timezone.utc) - timedelta(seconds=5) + db_session.execute( + text( + "INSERT INTO submissions (user_id, submission_type, version, status, created_at) " + "VALUES (:user_id, 'attack', '1.0.0', 'submitted', :created_at)" + ), + {"user_id": user_id, "created_at": just_now}, + ) + db_session.flush() + + with pytest.raises(HTTPException) as exc_info: + check_cooldown( + db_session, + user_id=user_id, + submission_type="attack", + cooldown_seconds=3600, + ) + assert exc_info.value.status_code == 429 + assert "wait" in exc_info.value.detail.lower() diff --git a/services/api/tests/test_submissions.py b/services/api/tests/test_submissions.py index 27e9996..ebddba6 100644 --- a/services/api/tests/test_submissions.py +++ b/services/api/tests/test_submissions.py @@ -296,6 +296,71 @@ def test_create_defense_github_strips_git_extension(self, client, db_session, mo assert response.status_code == 201 + def test_create_defense_github_with_branch(self, client, db_session, monkeypatch): + """Test that a URL with /tree/ is accepted and stored verbatim.""" + from routers import submissions as submissions_module + + fake = _FakeCelery() + monkeypatch.setattr(submissions_module, "_publish_task", + lambda **kwargs: fake.send_task("", kwargs)) + + user_id = _create_user(db_session) + token = _create_session_token(db_session, user_id=user_id) + + response = client.post( + "/api/submissions/defense/github", + json={ + "git_repo": "https://github.com/user/repo/tree/my-branch", + "version": "1.0.0", + }, + headers=_make_auth_headers(token), + ) + + assert response.status_code == 201 + details_row = db_session.execute( + text("SELECT git_repo FROM defense_submission_details WHERE submission_id = :id"), + {"id": response.json()["submission_id"]}, + ).fetchone() + assert details_row[0] == "https://github.com/user/repo/tree/my-branch" + + def test_create_defense_github_with_slash_branch(self, client, db_session, monkeypatch): + """Test that a URL with a multi-segment branch (feature/foo) is accepted.""" + from routers import submissions as submissions_module + + fake = _FakeCelery() + monkeypatch.setattr(submissions_module, "_publish_task", + lambda **kwargs: fake.send_task("", kwargs)) + + user_id = _create_user(db_session) + token = _create_session_token(db_session, user_id=user_id) + + response = client.post( + "/api/submissions/defense/github", + json={ + "git_repo": "https://github.com/user/repo/tree/feature/my-feature", + "version": "1.0.0", + }, + headers=_make_auth_headers(token), + ) + + assert response.status_code == 201 + + def test_create_defense_github_invalid_branch_empty(self, client, db_session): + """Test that /tree/ with no branch name is rejected.""" + user_id = _create_user(db_session) + token = _create_session_token(db_session, user_id=user_id) + + response = client.post( + "/api/submissions/defense/github", + json={ + "git_repo": "https://github.com/user/repo/tree/", + "version": "1.0.0", + }, + headers=_make_auth_headers(token), + ) + + assert response.status_code == 422 + # ============================================================================ # Defense ZIP Submission Tests @@ -524,6 +589,8 @@ def test_validate_github_url_format_valid(self): validate_github_url_format("https://github.com/user/repo") validate_github_url_format("https://github.com/user-name/repo-name") validate_github_url_format("https://github.com/user/repo.git") + validate_github_url_format("https://github.com/user/repo/tree/my-branch") + validate_github_url_format("https://github.com/user/repo/tree/feature/foo") def test_validate_github_url_format_invalid(self): """Test GitHub URL validation with invalid URLs.""" @@ -541,6 +608,10 @@ def test_validate_github_url_format_invalid(self): validate_github_url_format("github.com/user/repo") assert exc.value.status_code == 400 + with pytest.raises(HTTPException) as exc: + validate_github_url_format("https://github.com/user/repo/tree/") + assert exc.value.status_code == 400 + def test_validate_docker_image_format_valid(self): """Test Docker image format validation.""" from core.submissions import validate_docker_image_format diff --git a/services/frontend/Dockerfile.prod b/services/frontend/Dockerfile.prod deleted file mode 100644 index 4e4f50e..0000000 --- a/services/frontend/Dockerfile.prod +++ /dev/null @@ -1,25 +0,0 @@ -FROM node:20-alpine AS build - -WORKDIR /app - -COPY package.json package-lock.json* ./ -RUN npm install - -COPY . . -RUN npm run build - -FROM node:20-alpine AS runtime - -WORKDIR /app - -COPY --from=build /app/dist ./dist -COPY --from=build /app/package.json ./package.json -COPY --from=build /app/node_modules ./node_modules - -ENV HOST=0.0.0.0 -ENV PORT=4321 -ENV NODE_ENV=production - -EXPOSE 4321 - -CMD ["node", "./dist/server/entry.mjs"] diff --git a/services/frontend/astro.config.mjs b/services/frontend/astro.config.mjs index dbc42a9..8e11199 100644 --- a/services/frontend/astro.config.mjs +++ b/services/frontend/astro.config.mjs @@ -3,16 +3,12 @@ import { defineConfig } from 'astro/config'; import react from '@astrojs/react'; import tailwind from '@astrojs/tailwind'; -import node from '@astrojs/node'; const apiTarget = process.env.API_INTERNAL_URL || 'http://127.0.0.1:8000'; // https://astro.build/config export default defineConfig({ output: 'server', - adapter: node({ - mode: 'standalone', - }), integrations: [react(), tailwind()], vite: { server: { diff --git a/services/frontend/package.json b/services/frontend/package.json index 70f81d8..f34cd5b 100644 --- a/services/frontend/package.json +++ b/services/frontend/package.json @@ -9,7 +9,6 @@ "astro": "astro" }, "dependencies": { - "@astrojs/node": "^9.0.0", "@astrojs/react": "^4.4.2", "@tailwindcss/vite": "^4.1.18", "@types/react": "^19.2.10", diff --git a/services/frontend/public/favicon.svg b/services/frontend/public/favicon.svg index f157bd1..71d77f3 100644 --- a/services/frontend/public/favicon.svg +++ b/services/frontend/public/favicon.svg @@ -1,9 +1,251 @@ - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/services/frontend/public/mlsec-title.svg b/services/frontend/public/mlsec-title.svg new file mode 100644 index 0000000..8d27f1d --- /dev/null +++ b/services/frontend/public/mlsec-title.svg @@ -0,0 +1,243 @@ + + + + + + + diff --git a/services/frontend/src/components/Attack_submit.astro b/services/frontend/src/components/Attack_submit.astro index 732523b..f2c1b53 100644 --- a/services/frontend/src/components/Attack_submit.astro +++ b/services/frontend/src/components/Attack_submit.astro @@ -1,9 +1,11 @@ --- --- -
+

Attack Submission

+ +
@@ -90,6 +92,74 @@ } }); + function _extractDetail(detail: unknown): string { + if (typeof detail === 'string') return detail; + if (Array.isArray(detail) && detail.length > 0) { + return (detail as { msg?: string }[]).map(e => e.msg ?? String(e)).join('; '); + } + return 'An unexpected error occurred.'; + } + + function _attachDismiss(el: HTMLElement) { + const btn = document.createElement('button'); + btn.type = 'button'; + btn.textContent = '×'; + btn.className = 'ml-3 text-current opacity-60 hover:opacity-100 leading-none flex-shrink-0'; + btn.setAttribute('aria-label', 'Dismiss'); + btn.addEventListener('click', () => { + el.className = 'hidden text-sm rounded-lg px-3 py-2'; + el.textContent = ''; + }); + el.appendChild(btn); + } + + (function initAttackCooldown() { + const cooldownEl = document.getElementById('attack-cooldown') as HTMLParagraphElement; + let timer: ReturnType | null = null; + + function startCountdown(seconds: number) { + let remaining = seconds; + if (timer) clearInterval(timer); + + function tick() { + if (remaining <= 0) { + cooldownEl.classList.add('hidden'); + cooldownEl.textContent = ''; + if (timer) clearInterval(timer); + return; + } + const mm = String(Math.floor(remaining / 60)).padStart(2, '0'); + const ss = String(remaining % 60).padStart(2, '0'); + cooldownEl.textContent = `You can submit again in ${mm}:${ss}`; + cooldownEl.classList.remove('hidden'); + remaining--; + } + + tick(); + timer = setInterval(tick, 1000); + } + + async function checkCooldown() { + try { + const res = await fetch('/api/submissions/cooldown'); + if (!res.ok) return; + const data = await res.json(); + const secs: number | null = data.attack_remaining_seconds; + if (secs && secs > 0) { + startCountdown(secs); + } + } catch {} + } + + checkCooldown(); + + document.addEventListener('submission-created', (e: Event) => { + if ((e as CustomEvent).detail?.type === 'attack') { + checkCooldown(); + } + }); + })(); + document.getElementById('attack-form')!.addEventListener('submit', async (e) => { e.preventDefault(); @@ -102,7 +172,8 @@ if (!file) { feedback.textContent = 'Please select a ZIP file.'; - feedback.className = 'text-sm rounded-lg px-3 py-2 bg-red-50 text-red-700 border border-red-200'; + feedback.className = 'flex items-center justify-between text-sm rounded-lg px-3 py-2 bg-red-50 text-red-700 border border-red-200'; + _attachDismiss(feedback); return; } @@ -118,10 +189,11 @@ const res = await fetch('/api/submissions/attack/zip', { method: 'POST', body: fd }); const data = await res.json(); - if (!res.ok) throw new Error(data.detail ?? `Error ${res.status}`); + if (!res.ok) throw new Error(_extractDetail(data.detail) || `Error ${res.status}`); feedback.textContent = `Submitted (ID: ${data.submission_id})`; - feedback.className = 'text-sm rounded-lg px-3 py-2 bg-green-50 text-green-700 border border-green-200'; + feedback.className = 'flex items-center justify-between text-sm rounded-lg px-3 py-2 bg-green-50 text-green-700 border border-green-200'; + _attachDismiss(feedback); (e.target as HTMLFormElement).reset(); attackIdleView.classList.remove('hidden'); attackChosenView.classList.add('hidden'); @@ -130,7 +202,8 @@ } catch (err: unknown) { feedback.textContent = (err instanceof Error ? err.message : null) ?? 'Submission failed.'; - feedback.className = 'text-sm rounded-lg px-3 py-2 bg-red-50 text-red-700 border border-red-200'; + feedback.className = 'flex items-center justify-between text-sm rounded-lg px-3 py-2 bg-red-50 text-red-700 border border-red-200'; + _attachDismiss(feedback); } finally { btn.disabled = false; btn.textContent = 'Submit Attack'; diff --git a/services/frontend/src/components/Defense_submit.astro b/services/frontend/src/components/Defense_submit.astro index 313150e..0ff006e 100644 --- a/services/frontend/src/components/Defense_submit.astro +++ b/services/frontend/src/components/Defense_submit.astro @@ -1,8 +1,10 @@ --- --- -
-

Model Submission

+
+

Defense Submission

+ + @@ -95,7 +97,7 @@ type="submit" class="w-full bg-primary hover:bg-secondary text-white font-semibold py-2.5 rounded-lg transition duration-200" > - Submit Model + Submit Defense @@ -146,6 +148,74 @@ diff --git a/services/frontend/src/components/EvaluationMatrix.tsx b/services/frontend/src/components/EvaluationMatrix.tsx index 64e2c82..eceea82 100644 --- a/services/frontend/src/components/EvaluationMatrix.tsx +++ b/services/frontend/src/components/EvaluationMatrix.tsx @@ -24,28 +24,28 @@ interface LeaderboardData { /** * Maps a score (0.0 to 1.0) to an RGB color. - * 0% = red rgb(220, 50, 50) - * 50% = white rgb(255, 255, 255) - * 100% = green rgb(50, 175, 80) + * 0% = orange rgb(254, 179, 56) + * 50% = white rgb(230, 230, 230) + * 100% = blue rgb(2, 81, 150) */ function scoreToColor(score: number): string { const s = Math.max(0, Math.min(1, score)); if (s <= 0.5) { const t = s / 0.5; - const r = Math.round(220 + (255 - 220) * t); - const g = Math.round(50 + (255 - 50) * t); - const b = Math.round(50 + (255 - 50) * t); + const r = Math.round(254 + (230 - 254) * t); + const g = Math.round(179 + (230 - 179) * t); + const b = Math.round(56 + (230 - 56) * t); return `rgb(${r},${g},${b})`; } const t = (s - 0.5) / 0.5; - const r = Math.round(255 + (50 - 255) * t); - const g = Math.round(255 + (175 - 255) * t); - const b = Math.round(255 + (80 - 255) * t); + const r = Math.round(230 + (2 - 230) * t); + const g = Math.round(230 + (81 - 230) * t); + const b = Math.round(230 + (150 - 230) * t); return `rgb(${r},${g},${b})`; } function textColorForScore(score: number): string { - return score < 0.35 || score > 0.65 ? '#ffffff' : '#374151'; + return score > 0.70 ? '#ffffff' : '#374151'; } export default function EvaluationMatrix() { @@ -82,10 +82,10 @@ export default function EvaluationMatrix() { const { attackers, defenders, scores } = data; - if (attackers.length === 0 && defenders.length === 0) { + if (attackers.length === 0 || defenders.length === 0) { return (

- No active submissions yet. Once participants activate a submission, the matrix will appear here. + The matrix will appear once there is at least one active attack submission and one active defense submission.

); } @@ -126,47 +126,77 @@ export default function EvaluationMatrix() {
)} -
- - +
+
+
+ + {/* Row 1: Attack axis label. Spans name col + all score cols (no separate corner cell). */} + + + )} + + {/* Row 2: Column headers. Defense label starts here and spans down through all defender rows. */} - - {attackers.map(atk => ( + ))} - - + {/* Defender rows. Defense label column is covered by the rowSpan above. */} {defenders.map((def, di) => ( - - + {attackers.map(atk => { const key = `${atk.submission_id}/${def.submission_id}`; const entry = scores[key]; const bgColor = entry && showGradient ? scoreToColor(entry.score) : undefined; const fgColor = entry && showGradient ? textColorForScore(entry.score) : '#374151'; - return (
+ {attackers.length > 0 && ( + + Attack +
- Defense \ Attack + + + Defense + + {attackers.map((atk, ai) => ( -
{atk.username}
- {atk.display_name && ( -
{atk.display_name}
- )} -
v{atk.version}
+
+
{atk.username}
+ {atk.display_name && ( +
{atk.display_name}
+ )} +
v{atk.version}
+
-
{def.username}
- {def.display_name && ( -
{def.display_name}
- )} -
v{def.version}
+
+
+
{def.username}
+ {def.display_name && ( +
{def.display_name}
+ )} +
v{def.version}
+
+
); diff --git a/services/frontend/src/components/Footer.astro b/services/frontend/src/components/Footer.astro index 31072ef..34f1c49 100644 --- a/services/frontend/src/components/Footer.astro +++ b/services/frontend/src/components/Footer.astro @@ -3,9 +3,7 @@ const currentYear = new Date().getFullYear(); ---
-
- Copyright © {currentYear} -  |  - All rights reserved +
+ Copyright © {currentYear}  |  All rights reserved
\ No newline at end of file diff --git a/services/frontend/src/components/LoginModal.astro b/services/frontend/src/components/LoginModal.astro index 4584f63..e48f926 100644 --- a/services/frontend/src/components/LoginModal.astro +++ b/services/frontend/src/components/LoginModal.astro @@ -203,6 +203,13 @@ openModal(tab); }); + window.addEventListener('pageshow', (e) => { + if (e.persisted) { + (document.getElementById('login-form') as HTMLFormElement)?.reset(); + (document.getElementById('register-form') as HTMLFormElement)?.reset(); + } + }); + closeBtn.addEventListener('click', closeModal); modal.addEventListener('click', (e) => { if (e.target === modal) closeModal(); }); document.addEventListener('keydown', (e) => { @@ -270,6 +277,9 @@ return; } + (document.getElementById("login-form") as HTMLFormElement).reset(); + (document.getElementById("register-form") as HTMLFormElement).reset(); + window.location.href = "/submission"; if (data.authenticated) { window.location.href = "/submission"; return; diff --git a/services/frontend/src/components/Navbar.astro b/services/frontend/src/components/Navbar.astro index 16bdc39..89b510e 100644 --- a/services/frontend/src/components/Navbar.astro +++ b/services/frontend/src/components/Navbar.astro @@ -9,45 +9,103 @@ import Account from './Account.astro'; const session = await getSession(Astro.request); --- -