From 8879d4fb543927b74d75c1cf9734c07492ef7273 Mon Sep 17 00:00:00 2001 From: Joaquin Rodriguez Date: Wed, 11 Mar 2026 23:05:48 -0400 Subject: [PATCH 1/4] feat: support CastleDM DSNs with async driver injection Auto-inject async SQLAlchemy drivers for bare postgresql/mysql/sqlite URLs so acc castledm dsn output works directly as DATABASE_URL. Add MySQL dialect/session handling and test coverage to preserve read-only enforcement across backends. --- AGENTS.md | 3 +- README.md | 14 ++++++++- pyproject.toml | 1 + src/secure_sql_mcp/config.py | 23 +++++++++++++++ src/secure_sql_mcp/database.py | 4 +++ src/secure_sql_mcp/query_validator.py | 2 ++ tests/test_config.py | 39 ++++++++++++++++++++++++++ tests/test_mcp_interface.py | 23 +++++++++++++++ tests/test_query_validator_security.py | 19 +++++++++++++ 9 files changed, 126 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e8da51f..89a4c15 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,13 +11,14 @@ Core package: `src/secure_sql_mcp` - `config.py` - Loads env config. + - Injects async driver suffixes for bare `DATABASE_URL` schemes (`postgresql://`, `mysql://`, `sqlite://`). - Parses `ALLOWED_POLICY_FILE` in strict `table:columns` format. - `query_validator.py` - SQL safety checks (read-only, single statement). - Strict table/column authorization checks. - `database.py` - Async SQLAlchemy access. - - Read-only session preparation and query timeout/row caps. + - Read-only session preparation and query timeout/row caps (PostgreSQL, MySQL, SQLite). - `server.py` - MCP tool surface (`query`, `list_tables`, `describe_table`). - User/agent-facing responses. diff --git a/README.md b/README.md index 0a65b3f..9fe77c9 100644 --- a/README.md +++ b/README.md @@ -63,12 +63,24 @@ The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_P | Variable | Required | Default | Description | |----------|----------|---------|-------------| -| `DATABASE_URL` | Yes | — | SQLAlchemy async URL (e.g. `sqlite+aiosqlite:///./example.db` or `postgresql+asyncpg://...`) | +| `DATABASE_URL` | Yes | — | Database URL. Bare `postgresql://`, `mysql://`, and `sqlite://` URLs are accepted and auto-upgraded to async drivers (`+asyncpg`, `+aiomysql`, `+aiosqlite`). | | `ALLOWED_POLICY_FILE` | Yes | — | Path to the policy file | | `MAX_ROWS` | No | 100 | Maximum rows returned per query (1–10000) | | `QUERY_TIMEOUT` | No | 30 | Query timeout in seconds (1–300) | | `LOG_LEVEL` | No | INFO | Logging level (DEBUG, INFO, WARNING, ERROR) | +## CastleDM DSN Integration + +`acc castledm dsn` emits driver-agnostic URLs (`postgresql://...` or `mysql://...`). `secure-sql-mcp` accepts those values directly in `DATABASE_URL` and injects the async driver suffix automatically. + +Example: + +```bash +DATABASE_URL="$(acc castledm dsn maha intake --env=dev --read-only)" \ +ALLOWED_POLICY_FILE=./policy/allowed_policy.txt \ +python -m secure_sql_mcp.server +``` + ## Policy File Format `allowed_policy.txt`: diff --git a/pyproject.toml b/pyproject.toml index 466da88..8881041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "sqlglot", "sqlalchemy[asyncio]", "asyncpg", + "aiomysql", "aiosqlite", "pydantic-settings", ] diff --git a/src/secure_sql_mcp/config.py b/src/secure_sql_mcp/config.py index 34137d1..cf72a79 100644 --- a/src/secure_sql_mcp/config.py +++ b/src/secure_sql_mcp/config.py @@ -21,6 +21,29 @@ class Settings(BaseSettings): query_timeout: int = Field(default=30, alias="QUERY_TIMEOUT", ge=1, le=300) log_level: str = Field(default="INFO", alias="LOG_LEVEL") + @field_validator("database_url", mode="before") + @classmethod + def inject_async_driver(cls, value: Any) -> str: + """Ensure SQLAlchemy async URLs include an async driver suffix.""" + database_url = str(value).strip() + if "://" not in database_url: + return database_url + + scheme = database_url.split("://", 1)[0] + if "+" in scheme: + return database_url + + async_driver_map = { + "postgresql": "asyncpg", + "mysql": "aiomysql", + "sqlite": "aiosqlite", + } + driver = async_driver_map.get(scheme) + if driver is None: + return database_url + + return database_url.replace(f"{scheme}://", f"{scheme}+{driver}://", 1) + @model_validator(mode="after") def load_allowed_policy(self) -> Settings: """Load strict table:columns policy from file.""" diff --git a/src/secure_sql_mcp/database.py b/src/secure_sql_mcp/database.py index f799833..5b9125e 100644 --- a/src/secure_sql_mcp/database.py +++ b/src/secure_sql_mcp/database.py @@ -104,6 +104,10 @@ async def _prepare_read_only_session(self, conn: AsyncConnection) -> None: timeout_ms = int(self._settings.query_timeout) * 1000 await conn.execute(text("BEGIN READ ONLY")) await conn.execute(text(f"SET LOCAL statement_timeout = {timeout_ms}")) + elif self._settings.database_url.startswith("mysql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {timeout_ms}")) + await conn.execute(text("START TRANSACTION READ ONLY")) elif self._settings.database_url.startswith("sqlite"): await conn.execute(text("PRAGMA query_only = ON")) diff --git a/src/secure_sql_mcp/query_validator.py b/src/secure_sql_mcp/query_validator.py index dab4d73..090f5d6 100644 --- a/src/secure_sql_mcp/query_validator.py +++ b/src/secure_sql_mcp/query_validator.py @@ -239,6 +239,8 @@ def _validate_column_access( def _dialect(self) -> str | None: if self.settings.database_url.startswith("postgresql"): return "postgres" + if self.settings.database_url.startswith("mysql"): + return "mysql" if self.settings.database_url.startswith("sqlite"): return "sqlite" return None diff --git a/tests/test_config.py b/tests/test_config.py index 3c1b98f..6e2621f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,45 @@ from tests.conftest import write_policy +@pytest.mark.parametrize( + ("database_url", "expected_url"), + [ + ( + "postgresql://user:pass@localhost:5432/appdb", + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + ), + ( + "mysql://user:pass@localhost:3306/appdb", + "mysql+aiomysql://user:pass@localhost:3306/appdb", + ), + ( + "sqlite:///./tmp.db", + "sqlite+aiosqlite:///./tmp.db", + ), + ( + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + ), + ( + "mysql+aiomysql://user:pass@localhost:3306/appdb", + "mysql+aiomysql://user:pass@localhost:3306/appdb", + ), + ], +) +def test_database_url_injects_or_preserves_async_driver( + tmp_path: Path, database_url: str, expected_url: str +) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": database_url, + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + assert settings.database_url == expected_url + + def test_policy_invalid_format_no_colon_raises(tmp_path: Path) -> None: policy_path = tmp_path / "policy.txt" write_policy(policy_path, "customers id email\n") diff --git a/tests/test_mcp_interface.py b/tests/test_mcp_interface.py index 96325f8..7b69098 100644 --- a/tests/test_mcp_interface.py +++ b/tests/test_mcp_interface.py @@ -4,6 +4,7 @@ import json import sqlite3 from pathlib import Path +from unittest.mock import AsyncMock import pytest from pydantic import ValidationError @@ -318,3 +319,25 @@ async def _raise_db_error(_: str) -> object: assert "describe_table" in response assert "supersecret" not in response assert "internal-db" not in response + + +def test_prepare_read_only_session_mysql_sets_timeout_and_read_only(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "mysql://user:pass@localhost:3306/appdb", + "ALLOWED_POLICY_FILE": str(policy_path), + "QUERY_TIMEOUT": 12, + } + ) + db = AsyncDatabase(settings) + fake_conn = AsyncMock() + + asyncio.run(db._prepare_read_only_session(fake_conn)) + + executed_sql = [str(call.args[0]) for call in fake_conn.execute.await_args_list] + assert executed_sql == [ + "SET SESSION MAX_EXECUTION_TIME = 12000", + "START TRANSACTION READ ONLY", + ] diff --git a/tests/test_query_validator_security.py b/tests/test_query_validator_security.py index 4465275..db734af 100644 --- a/tests/test_query_validator_security.py +++ b/tests/test_query_validator_security.py @@ -151,3 +151,22 @@ def test_validator_blocks_except_with_disallowed_table(validator: QueryValidator result = validator.validate_query("SELECT id FROM customers EXCEPT SELECT id FROM secrets") assert not result.ok assert "Access to table 'secrets' is restricted" in (result.error or "") + + +def test_validator_uses_mysql_dialect_for_mysql_url(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + """, + ) + settings = Settings.model_validate( + { + "DATABASE_URL": "mysql://user:pass@localhost:3306/appdb", + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + validator = QueryValidator(settings) + + assert validator._dialect == "mysql" From e5b1f620d055b19a30b9f8d1ac829df6c7f789a3 Mon Sep 17 00:00:00 2001 From: Joaquin Rodriguez Date: Wed, 11 Mar 2026 23:09:02 -0400 Subject: [PATCH 2/4] docs: remove CastleDM-specific README section Keep the README provider-agnostic by removing the CastleDM DSN integration section while preserving the general DATABASE_URL behavior documentation. --- README.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/README.md b/README.md index 9fe77c9..2b49605 100644 --- a/README.md +++ b/README.md @@ -69,18 +69,6 @@ The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_P | `QUERY_TIMEOUT` | No | 30 | Query timeout in seconds (1–300) | | `LOG_LEVEL` | No | INFO | Logging level (DEBUG, INFO, WARNING, ERROR) | -## CastleDM DSN Integration - -`acc castledm dsn` emits driver-agnostic URLs (`postgresql://...` or `mysql://...`). `secure-sql-mcp` accepts those values directly in `DATABASE_URL` and injects the async driver suffix automatically. - -Example: - -```bash -DATABASE_URL="$(acc castledm dsn maha intake --env=dev --read-only)" \ -ALLOWED_POLICY_FILE=./policy/allowed_policy.txt \ -python -m secure_sql_mcp.server -``` - ## Policy File Format `allowed_policy.txt`: From a453b59e6be561795f08a850d0876e8ab93ace99 Mon Sep 17 00:00:00 2001 From: Joaquin Rodriguez Date: Fri, 13 Mar 2026 13:10:31 -0400 Subject: [PATCH 3/4] feat: add in-container OPA authz cutover Run OPA locally in the container and move authorization decisions to composed Rego policies so baseline constraints and ACL checks both gate access. This keeps enforcement fail-closed while preserving existing SQL safety semantics and tool contracts. --- .gitignore | 1 - Dockerfile | 16 +- README.md | 71 ++++++-- docker/entrypoint.sh | 18 ++ docker/wait_for_opa.py | 34 ++++ policy/allowed_policy.txt | 2 + policy/allowed_policy_castledm_test.txt | 7 + policy/data/acl.example.json | 14 ++ policy/rego/acl.rego | 57 ++++++ policy/rego/authz.rego | 18 ++ policy/rego/default_constraints.rego | 32 ++++ src/secure_sql_mcp/config.py | 86 +++++++++ src/secure_sql_mcp/opa_policy.py | 154 ++++++++++++++++ src/secure_sql_mcp/query_validator.py | 233 +++++++++++++++++++----- src/secure_sql_mcp/server.py | 34 +++- tests/test_config.py | 47 +++++ tests/test_mcp_interface.py | 8 +- tests/test_opa_policy.py | 95 ++++++++++ 18 files changed, 859 insertions(+), 68 deletions(-) create mode 100644 docker/entrypoint.sh create mode 100644 docker/wait_for_opa.py create mode 100644 policy/allowed_policy.txt create mode 100644 policy/allowed_policy_castledm_test.txt create mode 100644 policy/data/acl.example.json create mode 100644 policy/rego/acl.rego create mode 100644 policy/rego/authz.rego create mode 100644 policy/rego/default_constraints.rego create mode 100644 src/secure_sql_mcp/opa_policy.py create mode 100644 tests/test_opa_policy.py diff --git a/.gitignore b/.gitignore index ea53511..1c5bcc0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,3 @@ __pycache__/ dist/ build/ example.db -policy/ diff --git a/Dockerfile b/Dockerfile index 8f0ed61..89c691a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,15 +8,27 @@ RUN python -m venv /opt/venv \ && /opt/venv/bin/pip install --no-cache-dir . \ && find /opt/venv -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; true +FROM openpolicyagent/opa:1.5.1-static AS opa + FROM python:3.12-slim ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ - PATH="/opt/venv/bin:$PATH" + PATH="/opt/venv/bin:$PATH" \ + OPA_URL="http://127.0.0.1:8181" \ + OPA_DECISION_PATH="/v1/data/secure_sql/authz/decision" \ + OPA_TIMEOUT_MS="50" \ + OPA_FAIL_CLOSED="true" COPY --from=builder /opt/venv /opt/venv +COPY --from=opa /opa /usr/local/bin/opa +COPY policy /app/policy +COPY docker/entrypoint.sh /app/entrypoint.sh +COPY docker/wait_for_opa.py /app/wait_for_opa.py + +RUN chmod 0555 /usr/local/bin/opa /app/entrypoint.sh /app/wait_for_opa.py RUN useradd -r -s /usr/sbin/nologin appuser USER appuser -ENTRYPOINT ["python", "-m", "secure_sql_mcp.server"] +ENTRYPOINT ["/app/entrypoint.sh"] diff --git a/README.md b/README.md index 2b49605..1ac7f3c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Secure SQL MCP Server -Read-only SQL MCP server with strict table/column policy controls. +Read-only SQL MCP server with strict table/column policy controls, with OPA-based +authorization running inside the same container. [![CI](https://github.com/jrhuerta/secure-sql-mcp/actions/workflows/ci.yml/badge.svg)](https://github.com/jrhuerta/secure-sql-mcp/actions/workflows/ci.yml) [![GHCR](https://img.shields.io/badge/ghcr-jrhuerta%2Fsecure--sql--mcp-blue)](https://github.com/jrhuerta/secure-sql-mcp/pkgs/container/secure-sql-mcp) @@ -31,27 +32,33 @@ To use this server with Cursor, Claude Desktop, or other MCP clients, add it to **Claude Desktop** (`claude_desktop_config.json`): same structure under `mcpServers`. -The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt` (see Environment Variables below). The volume mounts the policy directory read-only. Pull the image first: `docker pull ghcr.io/jrhuerta/secure-sql-mcp:latest` +The `--env-file` should point to a file containing `DATABASE_URL` and +`ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt` (see Environment Variables below). +The volume mounts the policy directory read-only. Pull the image first: +`docker pull ghcr.io/jrhuerta/secure-sql-mcp:latest` ## Security Model - Database credentials stay server-side (env vars), never in prompts. - Only read queries are allowed. -- Policy is strict and file-based: - - one required file: `ALLOWED_POLICY_FILE` - - each line is `table:col1,col2,col3` or `table:*` +- OPA authorization runs in-process for the container image (local loopback only, no external port exposure). +- Policy is strict and deny-by-default: + - baseline constraints and ACL rules are evaluated by OPA + - ACL source can come from a native OPA data file or transformed legacy `ALLOWED_POLICY_FILE` - If a table/column is not explicitly allowed, it is blocked. ## Implemented Security Controls -- **Query shape enforcement** +- **OPA baseline constraints (`default_constraints`)** - Exactly one SQL statement is allowed per request. - Non-read operations are blocked (`INSERT`, `UPDATE`, `DELETE`, `DROP`, `ALTER`, `CREATE`, `TRUNCATE`, `GRANT`, `REVOKE`, `MERGE`, and related command expressions). -- **Strict access policy enforcement** + - Unqualified columns in multi-table queries are rejected under strict mode. +- **OPA ACL policy (`acl`)** - Deny-by-default for tables and columns. - Access checks apply across direct queries and composed queries (`JOIN`, `UNION`, subqueries, aliases). - - `SELECT *` is rejected unless the table policy is `table:*`. - - Unqualified columns in multi-table queries are rejected under strict mode. + - `SELECT *` is rejected unless ACL explicitly allows wildcard (`*`) for that table. +- **Composed authorization (`authz`)** + - Access is granted only when both `default_constraints` and `acl` allow. - **Runtime safety controls** - Query timeout and row cap are enforced server-side. - Row-cap truncation is explicit in response payloads. @@ -65,6 +72,11 @@ The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_P |----------|----------|---------|-------------| | `DATABASE_URL` | Yes | — | Database URL. Bare `postgresql://`, `mysql://`, and `sqlite://` URLs are accepted and auto-upgraded to async drivers (`+asyncpg`, `+aiomysql`, `+aiosqlite`). | | `ALLOWED_POLICY_FILE` | Yes | — | Path to the policy file | +| `OPA_URL` | No | `http://127.0.0.1:8181` in Docker image; unset otherwise | OPA base URL. When set, queries/tools are authorized via OPA. | +| `OPA_DECISION_PATH` | No | `/v1/data/secure_sql/authz/decision` | OPA decision endpoint path. | +| `OPA_TIMEOUT_MS` | No | `50` | OPA decision timeout in milliseconds. | +| `OPA_FAIL_CLOSED` | No | `true` | If `true`, OPA errors/timeouts block access. | +| `OPA_ACL_DATA_FILE` | No | unset | Optional JSON ACL file (`secure_sql.acl.tables`) preferred over transformed `ALLOWED_POLICY_FILE`. | | `MAX_ROWS` | No | 100 | Maximum rows returned per query (1–10000) | | `QUERY_TIMEOUT` | No | 30 | Query timeout in seconds (1–300) | | `LOG_LEVEL` | No | INFO | Logging level (DEBUG, INFO, WARNING, ERROR) | @@ -84,6 +96,18 @@ Rules: - `#` comments and blank lines are allowed. - Matching is case-insensitive. +## OPA Policy Layout + +- Rego bundle directory: `policy/rego/` + - `default_constraints.rego` + - `acl.rego` + - `authz.rego` +- Example ACL data file: `policy/data/acl.example.json` + +ACL source precedence at runtime: +1. If `OPA_ACL_DATA_FILE` is set, ACL input is loaded from that JSON file. +2. Otherwise, `ALLOWED_POLICY_FILE` is transformed into equivalent ACL input. + ## Agent Discoverability The MCP server exposes: @@ -115,12 +139,31 @@ QUERY_TIMEOUT=30 LOG_LEVEL=INFO EOF +# Optional when testing against an external/local OPA process outside the container: +# OPA_URL=http://127.0.0.1:8181 +# OPA_DECISION_PATH=/v1/data/secure_sql/authz/decision +# OPA_TIMEOUT_MS=50 +# OPA_FAIL_CLOSED=true + mkdir -p policy cat > policy/allowed_policy.txt <<'EOF' customers:id,email orders:* EOF +cat > policy/acl.json <<'EOF' +{ + "secure_sql": { + "acl": { + "tables": { + "customers": {"columns": ["id", "email"]}, + "orders": {"columns": ["*"]} + } + } + } +} +EOF + # Create tables for local testing (optional) sqlite3 example.db <<'SQL' CREATE TABLE IF NOT EXISTS customers (id INTEGER PRIMARY KEY, email TEXT NOT NULL, ssn TEXT); @@ -150,6 +193,7 @@ EOF cat > .env <<'EOF' DATABASE_URL=sqlite+aiosqlite:///./example.db ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt +OPA_ACL_DATA_FILE=/run/policy/acl.json MAX_ROWS=100 QUERY_TIMEOUT=30 LOG_LEVEL=INFO @@ -191,6 +235,7 @@ docker compose up --build - Avoid hardcoding credentials in shell history. - Mount policy files read-only (`:ro`) in Docker. - Keep `.env` and policy files out of version control. +- Keep OPA policy/data assets immutable in runtime containers. ## Dev Tooling @@ -221,6 +266,7 @@ What these suites validate: - strict deny-by-default table/column ACL checks, including join/union/subquery paths - protocol-level behavior over MCP stdio transport - timeout, row cap truncation, and non-leaky actionable DB error responses +- OPA fail-closed behavior and ACL source precedence ## CI Security Gate Expectations @@ -237,7 +283,7 @@ python -m pytest -q \ Recommended policy: - block merges on any failure in the security suites above -- require test updates when changing query validation, policy parsing, or MCP tool responses +- require test updates when changing query validation, OPA policy inputs, policy parsing, or MCP tool responses - keep security test fixtures deterministic (no shared state, no external DB dependency by default) ## Contributing @@ -258,8 +304,9 @@ Before merging security-sensitive changes, verify: - query validation still enforces exactly one statement per request - mutation/DDL/privilege SQL operations are blocked with actionable messaging -- table and column access remains deny-by-default against `ALLOWED_POLICY_FILE` -- `SELECT *` is rejected unless policy explicitly allows `table:*` +- table and column access remains deny-by-default against effective ACL source + (`OPA_ACL_DATA_FILE` when present, else transformed `ALLOWED_POLICY_FILE`) +- `SELECT *` is rejected unless ACL explicitly allows wildcard - multi-table queries still reject unqualified columns and enforce alias-aware ACLs - timeout and row-cap protections remain active and tested - DB error responses stay sanitized and do not expose credentials/internal connection details diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 0000000..dfd48a9 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env sh +set -eu + +OPA_BUNDLE_DIR="${OPA_BUNDLE_DIR:-/app/policy}" +OPA_ADDR="${OPA_ADDR:-127.0.0.1:8181}" + +opa run --server --addr "$OPA_ADDR" --bundle "$OPA_BUNDLE_DIR" & +OPA_PID=$! + +cleanup() { + kill "$OPA_PID" 2>/dev/null || true +} + +trap cleanup INT TERM EXIT + +python /app/wait_for_opa.py + +exec python -m secure_sql_mcp.server diff --git a/docker/wait_for_opa.py b/docker/wait_for_opa.py new file mode 100644 index 0000000..fa03036 --- /dev/null +++ b/docker/wait_for_opa.py @@ -0,0 +1,34 @@ +"""Wait for OPA health without extra image dependencies. + +This project uses a Python-based readiness check instead of curl/wget so the +runtime image can stay minimal and self-contained. `python:3.12-slim` already +ships Python (required by the MCP server), while curl is not guaranteed. +""" + +from __future__ import annotations + +import os +import time +from urllib import request + + +def main() -> None: + opa_url = os.environ.get("OPA_URL", "http://127.0.0.1:8181").rstrip("/") + health_url = f"{opa_url}/health" + + deadline = time.time() + 15 + last_error: Exception | None = None + while time.time() < deadline: + try: + with request.urlopen(health_url, timeout=0.5) as resp: # noqa: S310 + if 200 <= resp.status < 300: + return + except Exception as exc: # noqa: BLE001 + last_error = exc + time.sleep(0.2) + + raise SystemExit(f"OPA failed health check: {last_error}") + + +if __name__ == "__main__": + main() diff --git a/policy/allowed_policy.txt b/policy/allowed_policy.txt new file mode 100644 index 0000000..eb05c8c --- /dev/null +++ b/policy/allowed_policy.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:* diff --git a/policy/allowed_policy_castledm_test.txt b/policy/allowed_policy_castledm_test.txt new file mode 100644 index 0000000..7b32740 --- /dev/null +++ b/policy/allowed_policy_castledm_test.txt @@ -0,0 +1,7 @@ +tbl_alayamarket_demand_offers:id,visit_id,alayamarket_offer_id +tbl_alayamarket_demand_offers_audit:* +tbl_offers_responses:* +tbl_schedule_offer:* +tbl_schedule_offer_response:* +view_grouped_offer_responses:* +view_guid_computed_offer_responses:* diff --git a/policy/data/acl.example.json b/policy/data/acl.example.json new file mode 100644 index 0000000..9b644a2 --- /dev/null +++ b/policy/data/acl.example.json @@ -0,0 +1,14 @@ +{ + "secure_sql": { + "acl": { + "tables": { + "customers": { + "columns": ["id", "email"] + }, + "orders": { + "columns": ["*"] + } + } + } + } +} diff --git a/policy/rego/acl.rego b/policy/rego/acl.rego new file mode 100644 index 0000000..07217a7 --- /dev/null +++ b/policy/rego/acl.rego @@ -0,0 +1,57 @@ +package secure_sql.acl + +default allow := false + +acl_tables := object.get(object.get(input, "acl", {}), "tables", {}) + +is_query_tool if input.tool.name == "query" +is_list_tables_tool if input.tool.name == "list_tables" +is_describe_table_tool if input.tool.name == "describe_table" + +table_allowed(table) if object.get(acl_tables, table, null) != null + +allowed_columns(table) := object.get(object.get(acl_tables, table, {}), "columns", []) + +column_allowed(table, col) if { + allowed_columns(table)[_] == "*" +} + +column_allowed(table, col) if { + allowed_columns(table)[_] == col +} + +deny_reasons["table_restricted"] if { + is_query_tool + table := input.query.referenced_tables[_] + not table_allowed(table) +} + +deny_reasons["column_restricted"] if { + is_query_tool + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not column_allowed(table, col) +} + +deny_reasons["star_not_allowed"] if { + is_query_tool + table := input.query.star_tables[_] + not column_allowed(table, "*") +} + +deny_reasons["table_restricted"] if { + is_describe_table_tool + not table_allowed(input.table) +} + +allow if is_list_tables_tool + +allow if { + is_describe_table_tool + table_allowed(input.table) +} + +allow if { + is_query_tool + count(deny_reasons) == 0 +} diff --git a/policy/rego/authz.rego b/policy/rego/authz.rego new file mode 100644 index 0000000..069e4a6 --- /dev/null +++ b/policy/rego/authz.rego @@ -0,0 +1,18 @@ +package secure_sql.authz + +import data.secure_sql.acl +import data.secure_sql.default_constraints + +deny_reasons[reason] if default_constraints.deny_reasons[reason] +deny_reasons[reason] if acl.deny_reasons[reason] + +allow if { + default_constraints.allow + acl.allow + count(deny_reasons) == 0 +} + +decision := { + "allow": allow, + "deny_reasons": [reason | deny_reasons[reason]], +} diff --git a/policy/rego/default_constraints.rego b/policy/rego/default_constraints.rego new file mode 100644 index 0000000..2285601 --- /dev/null +++ b/policy/rego/default_constraints.rego @@ -0,0 +1,32 @@ +package secure_sql.default_constraints + +default allow := false + +is_query_tool if input.tool.name == "query" + +deny_reasons["multiple_statements"] if { + is_query_tool + input.query.statement_count != 1 +} + +deny_reasons["disallowed_operation"] if { + is_query_tool + input.query.has_disallowed_operation +} + +deny_reasons["not_read_query"] if { + is_query_tool + not input.query.is_read_statement +} + +deny_reasons["unqualified_multi_table_column"] if { + is_query_tool + input.query.has_unqualified_multi_table_columns +} + +allow if not is_query_tool + +allow if { + is_query_tool + count(deny_reasons) == 0 +} diff --git a/src/secure_sql_mcp/config.py b/src/secure_sql_mcp/config.py index cf72a79..52677b3 100644 --- a/src/secure_sql_mcp/config.py +++ b/src/secure_sql_mcp/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from pathlib import Path from typing import Any @@ -17,6 +18,14 @@ class Settings(BaseSettings): database_url: str = Field(alias="DATABASE_URL") allowed_policy_file: str = Field(alias="ALLOWED_POLICY_FILE") allowed_policy: dict[str, set[str]] = Field(default_factory=dict) + effective_acl_policy: dict[str, set[str]] = Field(default_factory=dict) + opa_url: str | None = Field(default=None, alias="OPA_URL") + opa_decision_path: str = Field( + default="/v1/data/secure_sql/authz/decision", alias="OPA_DECISION_PATH" + ) + opa_timeout_ms: int = Field(default=50, alias="OPA_TIMEOUT_MS", ge=1, le=5000) + opa_fail_closed: bool = Field(default=True, alias="OPA_FAIL_CLOSED") + opa_acl_data_file: str | None = Field(default=None, alias="OPA_ACL_DATA_FILE") max_rows: int = Field(default=100, alias="MAX_ROWS", ge=1, le=10000) query_timeout: int = Field(default=30, alias="QUERY_TIMEOUT", ge=1, le=300) log_level: str = Field(default="INFO", alias="LOG_LEVEL") @@ -48,8 +57,19 @@ def inject_async_driver(cls, value: Any) -> str: def load_allowed_policy(self) -> Settings: """Load strict table:columns policy from file.""" self.allowed_policy = self._parse_allowed_policy_file(self.allowed_policy_file) + self.effective_acl_policy = self._load_effective_acl_policy( + self.allowed_policy, self.opa_acl_data_file + ) return self + @field_validator("opa_decision_path", mode="before") + @classmethod + def normalize_opa_decision_path(cls, value: Any) -> str: + path = str(value).strip() + if not path: + return "/v1/data/secure_sql/authz/decision" + return path if path.startswith("/") else f"/{path}" + @field_validator("log_level", mode="before") @classmethod def normalize_log_level(cls, value: Any) -> str: @@ -105,6 +125,72 @@ def _parse_allowed_policy_file(path: str) -> dict[str, set[str]]: raise ValueError("Allowed policy file is empty. Add at least one table rule.") return policy + @classmethod + def _load_effective_acl_policy( + cls, allowed_policy: dict[str, set[str]], opa_acl_data_file: str | None + ) -> dict[str, set[str]]: + """Load ACL from OPA-native data file when available, else fallback to legacy policy.""" + if not opa_acl_data_file: + return {table: set(columns) for table, columns in allowed_policy.items()} + + acl_path = Path(opa_acl_data_file).expanduser() + if not acl_path.exists(): + raise ValueError(f"OPA ACL data file does not exist: {acl_path}") + if not acl_path.is_file(): + raise ValueError(f"OPA ACL data path is not a file: {acl_path}") + + try: + payload = json.loads(acl_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"OPA ACL data file must be valid JSON: {exc.msg}") from exc + + tables_payload = cls._extract_opa_tables_payload(payload) + parsed: dict[str, set[str]] = {} + for raw_table, raw_rule in tables_payload.items(): + table = str(raw_table).strip().lower() + if not table: + continue + if not isinstance(raw_rule, dict): + raise ValueError(f"OPA ACL rule for table '{table}' must be an object.") + raw_columns = raw_rule.get("columns") + if not isinstance(raw_columns, list) or not raw_columns: + raise ValueError( + f"OPA ACL rule for table '{table}' must include non-empty 'columns' list." + ) + columns = {str(column).strip().lower() for column in raw_columns if str(column).strip()} + if not columns: + raise ValueError( + f"OPA ACL rule for table '{table}' includes no valid column entries." + ) + if "*" in columns and len(columns) > 1: + raise ValueError( + f"OPA ACL wildcard for table '{table}' must be used alone in 'columns'." + ) + parsed[table] = columns + + if not parsed: + raise ValueError("OPA ACL data file resolved to an empty ACL policy.") + return parsed + + @staticmethod + def _extract_opa_tables_payload(payload: Any) -> dict[str, Any]: + if not isinstance(payload, dict): + raise ValueError("OPA ACL data file root must be a JSON object.") + + if "tables" in payload and isinstance(payload["tables"], dict): + return payload["tables"] + + secure_sql = payload.get("secure_sql") + if not isinstance(secure_sql, dict): + raise ValueError("OPA ACL data must define either 'tables' or 'secure_sql.acl.tables'.") + acl = secure_sql.get("acl") + if not isinstance(acl, dict): + raise ValueError("OPA ACL data missing object at 'secure_sql.acl'.") + tables = acl.get("tables") + if not isinstance(tables, dict): + raise ValueError("OPA ACL data missing object at 'secure_sql.acl.tables'.") + return tables + def load_settings() -> Settings: """Load typed settings from environment variables.""" diff --git a/src/secure_sql_mcp/opa_policy.py b/src/secure_sql_mcp/opa_policy.py new file mode 100644 index 0000000..876e1d4 --- /dev/null +++ b/src/secure_sql_mcp/opa_policy.py @@ -0,0 +1,154 @@ +"""OPA policy evaluation helpers.""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass, field +from typing import Any +from urllib import error, request + +from secure_sql_mcp.config import Settings + + +@dataclass(slots=True) +class PolicyDecision: + """Normalized policy decision returned to callers.""" + + allow: bool + deny_reasons: list[str] = field(default_factory=list) + message: str | None = None + raw_result: dict[str, Any] | None = None + + +class OpaPolicyEngine: + """Evaluates policy decisions against a local OPA server.""" + + def __init__(self, settings: Settings) -> None: + self.settings = settings + + async def evaluate(self, payload: dict[str, Any]) -> PolicyDecision: + return await asyncio.to_thread(self._evaluate_sync, payload) + + def evaluate_sync(self, payload: dict[str, Any]) -> PolicyDecision: + return self._evaluate_sync(payload) + + def _evaluate_sync(self, payload: dict[str, Any]) -> PolicyDecision: + if not self.settings.opa_url: + return PolicyDecision( + allow=False, + deny_reasons=["opa_unconfigured"], + message=( + "Authorization service is not configured. " + "Please escalate to a human operator." + ), + ) + + endpoint = f"{self.settings.opa_url.rstrip('/')}{self.settings.opa_decision_path}" + body = json.dumps({"input": payload}).encode("utf-8") + req = request.Request( # noqa: S310 + endpoint, + data=body, + method="POST", + headers={"Content-Type": "application/json"}, + ) + + try: + with request.urlopen(req, timeout=self.settings.opa_timeout_ms / 1000) as response: # noqa: S310 + data = json.loads(response.read().decode("utf-8")) + except (error.URLError, TimeoutError, json.JSONDecodeError) as exc: + if self.settings.opa_fail_closed: + return PolicyDecision( + allow=False, + deny_reasons=["opa_unavailable"], + message=( + "Authorization service is unavailable. " + "Please retry or escalate to a human operator." + ), + ) + return PolicyDecision( + allow=True, + deny_reasons=[], + message=None, + raw_result={"warning": str(exc)}, + ) + + result = self._extract_result(data) + if result is None: + return PolicyDecision( + allow=not self.settings.opa_fail_closed, + deny_reasons=["opa_undefined"], + message=( + "Authorization decision is unavailable. " + "Please retry or escalate to a human operator." + ), + ) + + if isinstance(result, bool): + return PolicyDecision(allow=result, raw_result={"allow": result}) + + if not isinstance(result, dict): + return PolicyDecision( + allow=False, + deny_reasons=["opa_invalid_result"], + message="Authorization decision format is invalid.", + ) + + allow = bool(result.get("allow", False)) + deny_reasons = [str(reason) for reason in result.get("deny_reasons", [])] + message = result.get("message") + if message is not None: + message = str(message) + + if not allow and not message: + message = self._message_for_reasons(deny_reasons) + + return PolicyDecision( + allow=allow, + deny_reasons=deny_reasons, + message=message, + raw_result=result, + ) + + @staticmethod + def _extract_result(response_payload: dict[str, Any]) -> Any | None: + # OPA REST response shape: {"result": ...} + if "result" in response_payload: + return response_payload.get("result") + return None + + @staticmethod + def _message_for_reasons(deny_reasons: list[str]) -> str: + if "multiple_statements" in deny_reasons: + return ( + "Only a single SQL statement is allowed. " + "Please remove additional statements and try again." + ) + if "disallowed_operation" in deny_reasons: + return ( + "This server is configured for read-only access. " + "If you need to modify data, please escalate to a human operator." + ) + if "not_read_query" in deny_reasons: + return "Only read-only SELECT queries are allowed." + if "table_restricted" in deny_reasons: + return ( + "Access to one or more tables is restricted by the server access policy. " + "Please use list_tables/describe_table to view allowed targets." + ) + if "column_restricted" in deny_reasons: + return ( + "Access to one or more selected columns is restricted by policy. " + "Use describe_table to inspect allowed columns." + ) + if "star_not_allowed" in deny_reasons: + return ( + "SELECT * is not allowed under strict policy for one or more tables. " + "Please select explicit allowed columns." + ) + if "unqualified_multi_table_column" in deny_reasons: + return ( + "Unqualified column references are not allowed in multi-table queries " + "under strict mode." + ) + return "Query blocked by policy." diff --git a/src/secure_sql_mcp/query_validator.py b/src/secure_sql_mcp/query_validator.py index 090f5d6..79bece3 100644 --- a/src/secure_sql_mcp/query_validator.py +++ b/src/secure_sql_mcp/query_validator.py @@ -4,11 +4,13 @@ from collections import defaultdict from dataclasses import dataclass +from typing import Any import sqlglot from sqlglot import exp from secure_sql_mcp.config import Settings +from secure_sql_mcp.opa_policy import OpaPolicyEngine @dataclass(slots=True) @@ -39,11 +41,14 @@ class QueryValidator: exp.Command, ) - def __init__(self, settings: Settings) -> None: + def __init__(self, settings: Settings, policy_engine: OpaPolicyEngine | None = None) -> None: self.settings = settings + self.policy_engine = policy_engine or ( + OpaPolicyEngine(settings) if settings.opa_url else None + ) def validate_query(self, sql: str) -> ValidationResult: - """Validate SQL for single statement, read-only, and table ACL rules.""" + """Validate SQL and authorize according to configured policy backend.""" query = sql.strip() if not query: return ValidationResult(ok=False, error="Query is empty.") @@ -56,57 +61,114 @@ def validate_query(self, sql: str) -> ValidationResult: error="Could not parse the SQL query. Please check the syntax and try again.", ) - if len(statements) != 1: + if not statements or statements[0] is None: return ValidationResult( ok=False, - error=( - "Only a single SQL statement is allowed. " - "Please remove additional statements and try again." - ), + error="Could not parse the SQL query. Please check the syntax and try again.", ) statement = statements[0] - if statement is None: - return ValidationResult( - ok=False, - error="Could not parse the SQL query. Please check the syntax and try again.", - ) statement_type = statement.key.upper() if statement.key else "UNKNOWN" + statement_count = len(statements) + has_disallowed_operation = any( + stmt is not None and self._contains_disallowed_operation(stmt) for stmt in statements + ) + is_read_statement = statement_count == 1 and self._is_read_statement(statement) - if self._contains_disallowed_operation(statement): - return ValidationResult( - ok=False, - error=( - "This server is configured for read-only access. " - f"The operation '{statement_type}' is not permitted. " - "If you need to modify data, please escalate to a human operator." - ), - ) - - if not self._is_read_statement(statement): - return ValidationResult( - ok=False, - error=(f"Only read-only SELECT queries are allowed. Received '{statement_type}'."), - ) - - referenced_tables = self.extract_referenced_tables(statement) - table_policy = self._resolve_table_policy(referenced_tables) - if isinstance(table_policy, str): - return ValidationResult(ok=False, error=table_policy) - - columns_result = self.extract_referenced_columns(statement, referenced_tables) - if isinstance(columns_result, str): - return ValidationResult(ok=False, error=columns_result) + referenced_tables: list[str] = [] + referenced_columns: dict[str, set[str]] = {} + star_tables: set[str] = set() + has_unqualified_multi_table_columns = False + + if statement_count == 1: + referenced_tables = self.extract_referenced_tables(statement) + if self.policy_engine is None: + if has_disallowed_operation: + return ValidationResult( + ok=False, + error=( + "This server is configured for read-only access. " + f"The operation '{statement_type}' is not permitted. " + "If you need to modify data, please escalate to a human operator." + ), + ) + if not is_read_statement: + return ValidationResult( + ok=False, + error=( + "Only read-only SELECT queries are allowed. " + f"Received '{statement_type}'." + ), + ) + + table_policy = self._resolve_table_policy(referenced_tables) + if isinstance(table_policy, str): + return ValidationResult(ok=False, error=table_policy) + + columns_result = self.extract_referenced_columns(statement, referenced_tables) + if isinstance(columns_result, str): + return ValidationResult(ok=False, error=columns_result) + + referenced_columns, star_tables = columns_result + columns_error = self._validate_column_access( + table_policy, referenced_columns, star_tables + ) + if columns_error: + return ValidationResult(ok=False, error=columns_error) - referenced_columns, star_tables = columns_result - columns_error = self._validate_column_access(table_policy, referenced_columns, star_tables) - if columns_error: - return ValidationResult(ok=False, error=columns_error) + for table in referenced_tables: + access_error = self.table_access_error(table, table_policy=table_policy) + if access_error: + return ValidationResult(ok=False, error=access_error) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) - for table in referenced_tables: - access_error = self.table_access_error(table, table_policy=table_policy) - if access_error: - return ValidationResult(ok=False, error=access_error) + if self.policy_engine is None: + if statement_count != 1: + return ValidationResult( + ok=False, + error=( + "Only a single SQL statement is allowed. " + "Please remove additional statements and try again." + ), + ) + if has_disallowed_operation: + return ValidationResult( + ok=False, + error=( + "This server is configured for read-only access. " + f"The operation '{statement_type}' is not permitted. " + "If you need to modify data, please escalate to a human operator." + ), + ) + if not is_read_statement: + return ValidationResult( + ok=False, + error=( + "Only read-only SELECT queries are allowed. " + f"Received '{statement_type}'." + ), + ) + else: + decision = self.policy_engine.evaluate_sync( + self._build_query_policy_input( + sql=query, + statement_count=statement_count, + statement_type=statement_type.lower(), + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + ) + ) + if not decision.allow: + return ValidationResult( + ok=False, error=decision.message or "Query blocked by policy." + ) return ValidationResult( ok=True, @@ -191,7 +253,7 @@ def extract_referenced_columns( def _resolve_table_policy(self, tables: list[str]) -> dict[str, set[str]] | str: resolved: dict[str, set[str]] = {} - available = ", ".join(sorted(self.settings.allowed_policy)) + available = ", ".join(sorted(self.settings.effective_acl_policy)) for table in tables: policy_columns = self.lookup_table_policy(table) @@ -272,8 +334,8 @@ def lookup_table_policy(self, table_name: str) -> set[str] | None: normalized = table_name.lower() candidates = (normalized, normalized.split(".")[-1]) for candidate in candidates: - if candidate in self.settings.allowed_policy: - return set(self.settings.allowed_policy[candidate]) + if candidate in self.settings.effective_acl_policy: + return set(self.settings.effective_acl_policy[candidate]) return None def _build_alias_map(self, statement: exp.Expression) -> dict[str, str]: @@ -289,3 +351,80 @@ def _build_alias_map(self, statement: exp.Expression) -> dict[str, str]: alias_name = alias_expr.name.lower() alias_map[alias_name] = table_name return alias_map + + def _extract_referenced_columns_relaxed( + self, statement: exp.Expression, referenced_tables: list[str] + ) -> tuple[dict[str, set[str]], set[str], bool]: + alias_map = self._build_alias_map(statement) + columns_by_table: defaultdict[str, set[str]] = defaultdict(set) + unqualified_columns: set[str] = set() + star_tables: set[str] = set() + + for column in statement.find_all(exp.Column): + if isinstance(column.this, exp.Star): + continue + if not column.name: + continue + + col_name = column.name.lower() + qualifier = (column.table or "").lower() + if qualifier: + table_name = alias_map.get(qualifier, qualifier) + columns_by_table[table_name].add(col_name) + else: + unqualified_columns.add(col_name) + + has_unqualified_multi_table_columns = bool( + unqualified_columns and len(referenced_tables) > 1 + ) + if unqualified_columns and len(referenced_tables) == 1: + columns_by_table[referenced_tables[0]].update(unqualified_columns) + + for select in statement.find_all(exp.Select): + for expression in select.expressions: + if isinstance(expression, exp.Star): + if len(referenced_tables) == 1: + star_tables.add(referenced_tables[0]) + elif len(referenced_tables) > 1: + star_tables.update(referenced_tables) + elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): + qualifier = (expression.table or "").lower() + if qualifier: + star_tables.add(alias_map.get(qualifier, qualifier)) + + return dict(columns_by_table), star_tables, has_unqualified_multi_table_columns + + def _build_query_policy_input( + self, + *, + sql: str, + statement_count: int, + statement_type: str, + has_disallowed_operation: bool, + is_read_statement: bool, + referenced_tables: list[str], + referenced_columns: dict[str, set[str]], + star_tables: set[str], + has_unqualified_multi_table_columns: bool, + ) -> dict[str, Any]: + acl_tables = { + table: {"columns": sorted(columns)} + for table, columns in sorted(self.settings.effective_acl_policy.items()) + } + return { + "tool": {"name": "query"}, + "query": { + "raw_sql": sql, + "statement_count": statement_count, + "statement_type": statement_type, + "has_disallowed_operation": has_disallowed_operation, + "is_read_statement": is_read_statement, + "referenced_tables": referenced_tables, + "referenced_columns": { + table: sorted(columns) for table, columns in sorted(referenced_columns.items()) + }, + "star_tables": sorted(star_tables), + "has_unqualified_multi_table_columns": has_unqualified_multi_table_columns, + }, + "acl": {"tables": acl_tables}, + } diff --git a/src/secure_sql_mcp/server.py b/src/secure_sql_mcp/server.py index 14727d3..91bb47d 100644 --- a/src/secure_sql_mcp/server.py +++ b/src/secure_sql_mcp/server.py @@ -7,11 +7,13 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Any from mcp.server.fastmcp import FastMCP from secure_sql_mcp.config import Settings, load_settings from secure_sql_mcp.database import AsyncDatabase +from secure_sql_mcp.opa_policy import OpaPolicyEngine from secure_sql_mcp.query_validator import QueryValidator LOGGER = logging.getLogger(__name__) @@ -22,6 +24,7 @@ class AppState: settings: Settings db: AsyncDatabase validator: QueryValidator + policy_engine: OpaPolicyEngine | None STATE: AppState | None = None @@ -34,9 +37,10 @@ async def lifespan(_: FastMCP) -> AsyncIterator[None]: settings = load_settings() logging.basicConfig(level=settings.log_level) db = AsyncDatabase(settings) - validator = QueryValidator(settings) + policy_engine = OpaPolicyEngine(settings) if settings.opa_url else None + validator = QueryValidator(settings, policy_engine=policy_engine) await db.connect() - STATE = AppState(settings=settings, db=db, validator=validator) + STATE = AppState(settings=settings, db=db, validator=validator, policy_engine=policy_engine) LOGGER.info("secure-sql-mcp started") try: yield @@ -94,7 +98,14 @@ async def query(sql: str) -> str: async def list_tables() -> str: """List tables the agent is allowed to query, validating existence when possible.""" app = _state() - policy = app.settings.allowed_policy + if app.policy_engine is not None: + decision = await app.policy_engine.evaluate( + _build_tool_policy_input("list_tables", app.settings) + ) + if not decision.allow: + return decision.message or "Operation blocked by policy." + + policy = app.settings.effective_acl_policy policy_tables = sorted(policy) policy_set = {t.lower() for t in policy} discovered: list[str] = [] @@ -139,9 +150,16 @@ async def list_tables() -> str: async def describe_table(table: str) -> str: """Describe columns for an allowed table.""" app = _state() + if app.policy_engine is not None: + payload = _build_tool_policy_input("describe_table", app.settings) + payload["table"] = table.lower() + decision = await app.policy_engine.evaluate(payload) + if not decision.allow: + return decision.message or "Operation blocked by policy." + policy_columns = app.validator.lookup_table_policy(table) if policy_columns is None: - available_tables = ", ".join(sorted(app.settings.allowed_policy)) + available_tables = ", ".join(sorted(app.settings.effective_acl_policy)) return ( f"Access to table '{table}' is restricted by the server access policy. " f"Allowed tables are: {available_tables}. " @@ -178,5 +196,13 @@ def main() -> None: mcp.run(transport="stdio") +def _build_tool_policy_input(tool_name: str, settings: Settings) -> dict[str, Any]: + acl_tables = { + table: {"columns": sorted(columns)} + for table, columns in sorted(settings.effective_acl_policy.items()) + } + return {"tool": {"name": tool_name}, "acl": {"tables": acl_tables}} + + if __name__ == "__main__": main() diff --git a/tests/test_config.py b/tests/test_config.py index 6e2621f..d66b8f6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -172,3 +172,50 @@ def test_normalize_log_level(tmp_path: Path) -> None: } ) assert settings.log_level == "DEBUG" + + +def test_opa_acl_data_file_preferred_over_allowed_policy(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + opa_acl_path = tmp_path / "acl.json" + opa_acl_path.write_text( + """ + { + "secure_sql": { + "acl": { + "tables": { + "orders": {"columns": ["*"]} + } + } + } + } + """, + encoding="utf-8", + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_ACL_DATA_FILE": str(opa_acl_path), + } + ) + + assert settings.allowed_policy == {"customers": {"id"}} + assert settings.effective_acl_policy == {"orders": {"*"}} + + +def test_invalid_opa_acl_json_raises(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + opa_acl_path = tmp_path / "acl.json" + opa_acl_path.write_text("{not-json", encoding="utf-8") + + with pytest.raises(ValidationError): + Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_ACL_DATA_FILE": str(opa_acl_path), + } + ) diff --git a/tests/test_mcp_interface.py b/tests/test_mcp_interface.py index 7b69098..7161d47 100644 --- a/tests/test_mcp_interface.py +++ b/tests/test_mcp_interface.py @@ -42,7 +42,9 @@ def app_state(tmp_path: Path): ) db = AsyncDatabase(settings) asyncio.run(db.connect()) - state = AppState(settings=settings, db=db, validator=QueryValidator(settings)) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) mcp_server.STATE = state try: @@ -91,7 +93,9 @@ def limited_app_state(tmp_path: Path): ) db = AsyncDatabase(settings) asyncio.run(db.connect()) - state = AppState(settings=settings, db=db, validator=QueryValidator(settings)) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) mcp_server.STATE = state try: diff --git a/tests/test_opa_policy.py b/tests/test_opa_policy.py new file mode 100644 index 0000000..0280316 --- /dev/null +++ b/tests/test_opa_policy.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, cast +from urllib import error + +from secure_sql_mcp.config import Settings +from secure_sql_mcp.opa_policy import OpaPolicyEngine, PolicyDecision +from secure_sql_mcp.query_validator import QueryValidator +from tests.conftest import write_policy + + +class _FakeResponse: + def __init__(self, payload: dict[str, object]) -> None: + self._payload = payload + self.status = 200 + + def read(self) -> bytes: + return json.dumps(self._payload).encode("utf-8") + + def __enter__(self) -> _FakeResponse: + return self + + def __exit__(self, *_: object) -> None: + return None + + +class _CaptureEngine: + def __init__(self, decision: PolicyDecision) -> None: + self.decision = decision + self.last_payload: dict[str, object] | None = None + + def evaluate_sync(self, payload: dict[str, object]) -> PolicyDecision: + self.last_payload = payload + return self.decision + + +def _settings(tmp_path: Path) -> Settings: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + return Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "OPA_FAIL_CLOSED": True, + } + ) + + +def test_opa_engine_fail_closed_on_transport_error(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _raise_url_error(*_: object, **__: object): + raise error.URLError("connection refused") + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _raise_url_error) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert "opa_unavailable" in decision.deny_reasons + + +def test_opa_engine_parses_decision_payload(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _ok(*_: object, **__: object): + return _FakeResponse({"result": {"allow": False, "deny_reasons": ["table_restricted"]}}) + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _ok) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert decision.deny_reasons == ["table_restricted"] + assert "restricted" in (decision.message or "") + + +def test_validator_builds_policy_input_for_opa(tmp_path: Path) -> None: + settings = _settings(tmp_path) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + + result = validator.validate_query( + "SELECT c.id, o.total FROM customers c JOIN orders o ON c.id = o.id" + ) + + assert result.ok + assert capture_engine.last_payload is not None + payload = capture_engine.last_payload + assert payload["tool"] == {"name": "query"} + assert payload["query"]["statement_count"] == 1 + assert sorted(payload["query"]["referenced_tables"]) == ["customers", "orders"] From 2c285d32b925967054411e27d90bf9a45fb8acdc Mon Sep 17 00:00:00 2001 From: Joaquin Rodriguez Date: Fri, 13 Mar 2026 16:53:30 -0400 Subject: [PATCH 4/4] feat: add gated write mode with OPA constraints Introduce policy-governed INSERT/UPDATE/DELETE support behind explicit runtime gates and Rego write constraints to preserve deny-by-default behavior. Add coverage across validator, MCP interface, and Docker+OPA integration paths to harden authorization and fail-closed behavior. --- .dockerignore | 1 - README.md | 51 +- docker-compose.test.yml | 59 ++ docker-compose.yml | 7 + docs/POLICY_AUTHORING.md | 233 ++++++++ docs/WRITE_MODE_DESIGN.md | 163 +++++ policy/rego/acl.rego | 93 ++- policy/rego/authz.rego | 4 + policy/rego/default_constraints.rego | 45 +- policy/rego/write_constraints.rego | 58 ++ pyproject.toml | 4 + scripts/run-docker-opa-smoke.sh | 8 + src/secure_sql_mcp/config.py | 7 + src/secure_sql_mcp/database.py | 48 ++ src/secure_sql_mcp/opa_policy.py | 54 +- src/secure_sql_mcp/query_validator.py | 557 ++++++++++++++++-- src/secure_sql_mcp/server.py | 50 +- tests/integration/__init__.py | 1 + tests/integration/docker/__init__.py | 1 + .../docker/acl/restricted_acl.json | 14 + tests/integration/docker/conftest.py | 212 +++++++ tests/integration/docker/db-init/mysql.sql | 24 + tests/integration/docker/db-init/postgres.sql | 27 + .../docker/policies/read_only_strict.txt | 2 + .../docker/policies/wildcard_tables.txt | 2 + .../policies/write_delete_restricted.txt | 2 + .../docker/policies/write_insert_only.txt | 2 + .../policies/write_update_restricted.txt | 2 + .../docker/test_mcp_docker_opa_matrix.py | 281 +++++++++ tests/test_config.py | 43 ++ tests/test_mcp_interface.py | 227 +++++++ tests/test_mcp_stdio_security.py | 58 +- tests/test_opa_policy.py | 97 ++- tests/test_query_validator_security.py | 162 +++++ tests/test_write_facts.py | 136 +++++ 35 files changed, 2683 insertions(+), 52 deletions(-) create mode 100644 docker-compose.test.yml create mode 100644 docs/POLICY_AUTHORING.md create mode 100644 docs/WRITE_MODE_DESIGN.md create mode 100644 policy/rego/write_constraints.rego create mode 100755 scripts/run-docker-opa-smoke.sh create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/docker/__init__.py create mode 100644 tests/integration/docker/acl/restricted_acl.json create mode 100644 tests/integration/docker/conftest.py create mode 100644 tests/integration/docker/db-init/mysql.sql create mode 100644 tests/integration/docker/db-init/postgres.sql create mode 100644 tests/integration/docker/policies/read_only_strict.txt create mode 100644 tests/integration/docker/policies/wildcard_tables.txt create mode 100644 tests/integration/docker/policies/write_delete_restricted.txt create mode 100644 tests/integration/docker/policies/write_insert_only.txt create mode 100644 tests/integration/docker/policies/write_update_restricted.txt create mode 100644 tests/integration/docker/test_mcp_docker_opa_matrix.py create mode 100644 tests/test_write_facts.py diff --git a/.dockerignore b/.dockerignore index f3e0b46..834db77 100644 --- a/.dockerignore +++ b/.dockerignore @@ -13,6 +13,5 @@ SECURITY.md CODE_OF_CONDUCT.md CONTRIBUTING.md docker-compose.yml -policy .pre-commit-config.yaml .gitignore diff --git a/README.md b/README.md index 1ac7f3c..cef8c6c 100644 --- a/README.md +++ b/README.md @@ -77,10 +77,22 @@ The volume mounts the policy directory read-only. Pull the image first: | `OPA_TIMEOUT_MS` | No | `50` | OPA decision timeout in milliseconds. | | `OPA_FAIL_CLOSED` | No | `true` | If `true`, OPA errors/timeouts block access. | | `OPA_ACL_DATA_FILE` | No | unset | Optional JSON ACL file (`secure_sql.acl.tables`) preferred over transformed `ALLOWED_POLICY_FILE`. | +| `WRITE_MODE_ENABLED` | No | `false` | Enables write execution path (`INSERT`/`UPDATE`/`DELETE`) when `true`. | +| `ALLOW_INSERT` | No | `false` | Allows `INSERT` statements when write mode is enabled. | +| `ALLOW_UPDATE` | No | `false` | Allows `UPDATE` statements when write mode is enabled. | +| `ALLOW_DELETE` | No | `false` | Allows `DELETE` statements when write mode is enabled. | +| `REQUIRE_WHERE_FOR_UPDATE` | No | `true` | When `true`, `UPDATE` requires a `WHERE` clause. | +| `REQUIRE_WHERE_FOR_DELETE` | No | `true` | When `true`, `DELETE` requires a `WHERE` clause. | +| `ALLOW_RETURNING` | No | `false` | Allows `RETURNING` on write statements when `true`. | | `MAX_ROWS` | No | 100 | Maximum rows returned per query (1–10000) | | `QUERY_TIMEOUT` | No | 30 | Query timeout in seconds (1–300) | | `LOG_LEVEL` | No | INFO | Logging level (DEBUG, INFO, WARNING, ERROR) | +Write mode guardrails: +- All write-related flags default to `false` (deny-by-default). +- OPA remains the policy decision point, but config gates are enforced first as coarse runtime brakes. +- The server logs a `WARNING` when config gates block a write that policy would otherwise allow. + ## Policy File Format `allowed_policy.txt`: @@ -103,6 +115,8 @@ Rules: - `acl.rego` - `authz.rego` - Example ACL data file: `policy/data/acl.example.json` +- Policy authoring guide: [`docs/POLICY_AUTHORING.md`](docs/POLICY_AUTHORING.md) +- Controlled write mode design: [`docs/WRITE_MODE_DESIGN.md`](docs/WRITE_MODE_DESIGN.md) ACL source precedence at runtime: 1. If `OPA_ACL_DATA_FILE` is set, ACL input is loaded from that JSON file. @@ -120,7 +134,8 @@ The MCP server exposes: - allowed columns for that table from policy - schema metadata from DB when available - `query(sql)`: - - executes only if query is read-only and within table/column policy + - executes read queries by default under table/column policy + - executes write queries only when write mode/action toggles allow them and policy permits ## Quick Start (uv) @@ -261,13 +276,45 @@ python -m pytest -q \ ``` What these suites validate: -- read-only enforcement for mutation/privileged SQL operations +- default read-only enforcement for mutation/privileged SQL operations - single-statement validation and parser hardening - strict deny-by-default table/column ACL checks, including join/union/subquery paths +- write-mode guardrails (`WRITE_MODE_ENABLED` and per-action toggles), including WHERE safety checks - protocol-level behavior over MCP stdio transport - timeout, row cap truncation, and non-leaky actionable DB error responses - OPA fail-closed behavior and ACL source precedence +## Real Docker + OPA Matrix Tests + +Run comprehensive real-server scenarios against Dockerized MCP+OPA across +SQLite, PostgreSQL, and MySQL: + +```bash +python -m pytest -q -m docker_integration tests/integration/docker/test_mcp_docker_opa_matrix.py +``` + +Run a faster smoke subset: + +```bash +bash scripts/run-docker-opa-smoke.sh +``` + +Prerequisites: +- Docker Engine with Compose plugin (`docker compose`) +- ability to pull base images (`postgres:16-alpine`, `mysql:8.4`) + +Troubleshooting: +- if MySQL/PostgreSQL startup is slow, rerun with `-m docker_integration -vv` to inspect per-scenario logs +- if Docker is unavailable, these tests auto-skip and unit/security suites still run normally +- if port/resource contention occurs, remove stale test stacks: `docker compose -f docker-compose.test.yml down -v --remove-orphans` + +What the Docker matrix validates: +- read/write allow/deny behavior with real container runtime and OPA process +- policy-profile variants mounted as read-only files +- write gate toggles (`WRITE_MODE_ENABLED`, `ALLOW_*`) and WHERE/RETURNING controls +- bypass-focused checks (`INSERT ... SELECT`, source `SELECT *`, tautological WHERE) +- OPA fail-closed behavior when decision service is unavailable + ## CI Security Gate Expectations For protected branches, treat these checks as merge blockers: diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..c4687f1 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,59 @@ +services: + secure-sql-mcp: + build: + context: . + dockerfile: Dockerfile + image: secure-sql-mcp:test + environment: + ALLOWED_POLICY_FILE: /run/policy/allowed_policy.txt + OPA_URL: http://127.0.0.1:8181 + OPA_DECISION_PATH: /v1/data/secure_sql/authz/decision + OPA_TIMEOUT_MS: "50" + OPA_FAIL_CLOSED: "true" + WRITE_MODE_ENABLED: "false" + ALLOW_INSERT: "false" + ALLOW_UPDATE: "false" + ALLOW_DELETE: "false" + REQUIRE_WHERE_FOR_UPDATE: "true" + REQUIRE_WHERE_FOR_DELETE: "true" + ALLOW_RETURNING: "false" + MAX_ROWS: "100" + QUERY_TIMEOUT: "30" + LOG_LEVEL: "INFO" + stdin_open: true + tty: false + depends_on: + postgres: + condition: service_healthy + mysql: + condition: service_healthy + + postgres: + image: postgres:16-alpine + environment: + POSTGRES_USER: secure + POSTGRES_PASSWORD: secure + POSTGRES_DB: secure_sql_test + volumes: + - ./tests/integration/docker/db-init/postgres.sql:/docker-entrypoint-initdb.d/01-init.sql:ro + healthcheck: + test: ["CMD-SHELL", "pg_isready -U secure -d secure_sql_test"] + interval: 2s + timeout: 2s + retries: 30 + + mysql: + image: mysql:8.4 + command: ["--default-authentication-plugin=mysql_native_password"] + environment: + MYSQL_ROOT_PASSWORD: root + MYSQL_USER: secure + MYSQL_PASSWORD: secure + MYSQL_DATABASE: secure_sql_test + volumes: + - ./tests/integration/docker/db-init/mysql.sql:/docker-entrypoint-initdb.d/01-init.sql:ro + healthcheck: + test: ["CMD-SHELL", "mysqladmin ping -h 127.0.0.1 -usecure -psecure --silent"] + interval: 2s + timeout: 2s + retries: 60 diff --git a/docker-compose.yml b/docker-compose.yml index 4823db2..1b1edd4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,6 +7,13 @@ services: environment: DATABASE_URL: ${DATABASE_URL:-sqlite+aiosqlite:///./example.db} ALLOWED_POLICY_FILE: ${ALLOWED_POLICY_FILE:-/run/policy/allowed_policy.txt} + OPA_DECISION_PATH: ${OPA_DECISION_PATH:-/v1/data/secure_sql/authz/decision} + OPA_TIMEOUT_MS: ${OPA_TIMEOUT_MS:-50} + OPA_FAIL_CLOSED: ${OPA_FAIL_CLOSED:-true} + WRITE_MODE_ENABLED: ${WRITE_MODE_ENABLED:-false} + ALLOW_INSERT: ${ALLOW_INSERT:-false} + ALLOW_UPDATE: ${ALLOW_UPDATE:-false} + ALLOW_DELETE: ${ALLOW_DELETE:-false} MAX_ROWS: ${MAX_ROWS:-100} QUERY_TIMEOUT: ${QUERY_TIMEOUT:-30} LOG_LEVEL: ${LOG_LEVEL:-INFO} diff --git a/docs/POLICY_AUTHORING.md b/docs/POLICY_AUTHORING.md new file mode 100644 index 0000000..a02b2aa --- /dev/null +++ b/docs/POLICY_AUTHORING.md @@ -0,0 +1,233 @@ +# Policy Authoring Guide (OPA/Rego) + +This guide explains how to write and customize policies for `secure-sql-mcp`. + +It is designed to be: +- practical for humans +- structured enough for agents to generate policy variants + +## Policy model at a glance + +The default bundle composes two policy modules: + +- `default_constraints`: baseline guardrails (statement count, operation class, query shape) +- `acl`: table/column access rules + +Final decision is an AND: + +- allow only when `default_constraints.allow` and `acl.allow` are both true + +See: +- `policy/rego/default_constraints.rego` +- `policy/rego/acl.rego` +- `policy/rego/authz.rego` + +## Runtime architecture + +The server now has split execution paths: + +- read statements execute through `execute_read_query(...)` +- write statements execute through `execute_write_query(...)` + +Write execution is still deny-by-default and requires *both* runtime gates and policy: + +- runtime gates (`WRITE_MODE_ENABLED`, `ALLOW_INSERT/UPDATE/DELETE`) +- policy allow in OPA (`default_constraints`, `acl`, `write_constraints`) + +## Input facts available to policy + +OPA receives `{"input": ...}` payloads from the server. + +### For `query` tool + +```json +{ + "tool": { "name": "query" }, + "query": { + "raw_sql": "SELECT id FROM customers", + "statement_count": 1, + "statement_type": "select", + "is_write_statement": false, + "has_disallowed_operation": false, + "is_read_statement": true, + "referenced_tables": ["customers"], + "referenced_columns": { "customers": ["id"] }, + "star_tables": [], + "has_unqualified_multi_table_columns": false, + "target_table": "", + "insert_columns": [], + "updated_columns": [], + "where_present": false, + "where_tautological": false, + "returning_present": false, + "returning_columns": [], + "has_select_source": false, + "source_tables": [] + }, + "config": { + "write_mode_enabled": false, + "allow_insert": false, + "allow_update": false, + "allow_delete": false, + "require_where_for_update": true, + "require_where_for_delete": true, + "allow_returning": false + }, + "acl": { + "tables": { + "customers": { "columns": ["id", "email"] }, + "orders": { "columns": ["*"] } + } + } +} +``` + +### For `list_tables` tool + +```json +{ + "tool": { "name": "list_tables" }, + "acl": { "tables": { "...": { "columns": ["..."] } } } +} +``` + +### For `describe_table` tool + +```json +{ + "tool": { "name": "describe_table" }, + "table": "customers", + "acl": { "tables": { "...": { "columns": ["..."] } } } +} +``` + +## ACL data sources + +ACL data can come from either: + +1. `OPA_ACL_DATA_FILE` (preferred when set), JSON structure at `secure_sql.acl.tables` +2. transformed legacy `ALLOWED_POLICY_FILE` + +Use `OPA_ACL_DATA_FILE` when you want native OPA-oriented config. + +## Rego patterns you can reuse + +Use these as templates when asking an agent to generate a policy. + +### 1) Keep strict read-only baseline (default behavior) + +```rego +package secure_sql.default_constraints + +default allow := false + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 +deny_reasons["disallowed_operation"] if input.query.has_disallowed_operation +deny_reasons["not_read_query"] if not input.query.is_read_statement + +allow if count(deny_reasons) == 0 +``` + +### 2) Relax baseline to allow inserts only (policy example) + +```rego +package secure_sql.default_constraints + +default allow := false + +is_insert if input.query.statement_type == "insert" + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 +deny_reasons["not_allowed_statement_type"] if { + not input.query.is_read_statement + not is_insert +} + +allow if count(deny_reasons) == 0 +``` + +### 3) Allow updates only to specific tables/columns (policy example) + +```rego +package secure_sql.default_constraints + +default allow := false + +allowed_update_columns := { + "customers": {"email"}, + "profiles": {"display_name", "timezone"}, +} + +is_update if input.query.statement_type == "update" + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 + +deny_reasons["update_column_not_allowed"] if { + is_update + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not allowed_update_columns[table][col] +} + +deny_reasons["not_allowed_statement_type"] if { + not input.query.is_read_statement + not is_update +} + +allow if count(deny_reasons) == 0 +``` + +### 4) Time-window or environment gating + +If you add `context` facts (for example `input.context.environment`), you can gate +by deployment environment or maintenance window. + +```rego +deny_reasons["writes_only_allowed_in_maintenance"] if { + input.query.statement_type == "update" + input.context.environment != "maintenance" +} +``` + +## Agent-friendly prompt template + +Use this prompt with an agent to generate policy: + +```text +Generate Rego policies for secure-sql-mcp. + +Constraints: +- Keep authz composition as default_constraints AND acl. +- Tool names are query, list_tables, describe_table. +- Use deny_reasons for explainability. +- Maintain deny-by-default. + +Desired behavior: +- [Describe exactly which statements are allowed] +- [Describe table/column restrictions] +- [Describe any time/env/principal restrictions] + +Output: +1) default_constraints.rego +2) acl.rego (if ACL behavior changes) +3) authz.rego (if composition changes) +4) short test matrix with allowed/blocked examples +``` + +## Testing checklist for policy changes + +When you relax policy behavior, test all of: + +- single-statement enforcement +- disallowed table/column access +- wildcard behavior (`SELECT *`) +- joins/subqueries/unions +- error hygiene (no sensitive leaks) +- OPA fail-closed behavior + +Run: + +```bash +python -m pytest -q +``` + diff --git a/docs/WRITE_MODE_DESIGN.md b/docs/WRITE_MODE_DESIGN.md new file mode 100644 index 0000000..76202bf --- /dev/null +++ b/docs/WRITE_MODE_DESIGN.md @@ -0,0 +1,163 @@ +# Write Mode Design (Controlled Mutations) + +This document describes how to extend `secure-sql-mcp` from read-only execution to +policy-governed writes while preserving security guarantees. + +It is intentionally conservative: mutation capability is powerful and should be +introduced behind explicit controls, with deny-by-default behavior at each layer. + +## Current state + +Today, policy can theoretically return `allow` for non-read operations, but runtime +execution is still read-only: + +- query execution uses `execute_read_query(...)` +- DB session is configured read-only for PostgreSQL/MySQL/SQLite +- query wrapper enforces select-style row capping logic + +As a result, policy-only changes are not sufficient for write support. + +## Goals + +- Allow tightly scoped mutation scenarios (for example `INSERT` only). +- Keep deny-by-default and fail-closed behavior. +- Preserve clean, actionable error responses for agents. +- Avoid broad privilege escalation in database credentials. + +## Non-goals + +- Full unrestricted SQL write access. +- Multi-statement transaction scripting from agents. +- Bypassing policy checks in application code. + +## Security invariants to preserve + +- Single statement per request unless explicitly designed otherwise. +- Explicit allowlist semantics (tables/columns/actions). +- No sensitive internal error leakage. +- OPA unavailable/timeout behavior remains fail-closed. +- Tool responses remain deterministic and auditable. + +## Proposed architecture changes + +### 1) Split execution paths by statement class + +Introduce separate DB execution methods: + +- `execute_read_query(sql)` (existing) +- `execute_write_query(sql)` (new) + +`execute_write_query` should: + +- run with strict timeout +- return affected row count and optional returning payload +- avoid row-cap wrapper intended for SELECT +- avoid enabling arbitrary transaction control from user SQL + +### 2) Expand policy facts for write authorization + +Current input facts are SELECT-centric. Add mutation-focused facts in +`QueryValidator._build_query_policy_input(...)`, for example: + +- `statement_type` normalized (`insert`, `update`, `delete`, etc.) +- `target_tables` +- `updated_columns` +- `insert_columns` +- `where_present` (for updates/deletes) +- `returning_present` + +These facts should be parser-derived, not regex-derived. + +### 3) Add explicit write mode config gates + +Use coarse-grained runtime toggles in config: + +- `WRITE_MODE_ENABLED=false` by default +- optional action toggles: + - `ALLOW_INSERT=false` + - `ALLOW_UPDATE=false` + - `ALLOW_DELETE=false` +- allow these to be configured with flags from the cli also. + +OPA remains the final decision engine; these toggles are safety brakes. + +### 4) Keep OPA as policy source of truth for permissions + +Model policy in Rego with explicit action constraints: + +- allow read paths as before +- allow writes only when: + - statement type is explicitly permitted + - table is allowed + - affected columns are allowed + - optional contextual constraints pass (tenant/env/user role) + +### 5) Server-level routing + +In `query(...)`, route execution by validated statement class: + +- if read -> `execute_read_query` +- if write and allowed -> `execute_write_query` +- else block with actionable message + +## Example policy patterns for controlled writes + +### Insert-only mode + +- allow `insert` on specific tables +- deny `update`, `delete`, DDL + +### Update-only specific columns + +- allow `update` on table `customers` +- allow only `email` and `phone` +- require `WHERE` clause (no full-table updates) + +### Delete with strict guard + +- allow `delete` only on maintenance tables +- require `WHERE` and additional context flag (e.g. maintenance window) + +## DB credential model + +Policy alone is not enough. Use least-privilege DB credentials: + +- read-only role for read-only deployments +- separate write-capable role for write mode +- grants limited to intended schemas/tables/actions + +Do not rely solely on app-layer checks for write containment. + +## Response contract proposal for writes + +For write operations, return structured JSON: + +```json +{ + "status": "ok", + "operation": "update", + "affected_rows": 3, + "returning": [] +} +``` + +For blocked writes, keep consistent actionable messages and avoid leaking internals. + +## Rollout plan + +1. Add parser-derived write facts and tests (no write execution yet). +2. Add OPA write policy rules in shadow mode (log only). +3. Add `execute_write_query` path behind `WRITE_MODE_ENABLED`. +4. Enable insert-only in non-production. +5. Expand to update/delete only with dedicated tests and DB grants. + +## Test matrix (minimum) + +- parser extraction for write facts by dialect +- blocked/allowed decisions for insert/update/delete +- column-restricted updates +- missing-WHERE safeguards +- fail-closed OPA behavior for writes +- sanitized DB error responses +- stdio MCP contract for write and blocked-write outcomes + diff --git a/policy/rego/acl.rego b/policy/rego/acl.rego index 07217a7..3069f8e 100644 --- a/policy/rego/acl.rego +++ b/policy/rego/acl.rego @@ -1,16 +1,39 @@ package secure_sql.acl -default allow := false +default allow = false acl_tables := object.get(object.get(input, "acl", {}), "tables", {}) is_query_tool if input.tool.name == "query" is_list_tables_tool if input.tool.name == "list_tables" is_describe_table_tool if input.tool.name == "describe_table" +is_write_query if { + is_query_tool + object.get(input.query, "is_write_statement", false) +} -table_allowed(table) if object.get(acl_tables, table, null) != null +normalized_table(table) := lower(table) +short_table_name(table) := name if { + parts := split(normalized_table(table), ".") + idx := count(parts) - 1 + name := parts[idx] +} -allowed_columns(table) := object.get(object.get(acl_tables, table, {}), "columns", []) +table_allowed(table) if object.get(acl_tables, normalized_table(table), null) != null +table_allowed(table) if object.get(acl_tables, short_table_name(table), null) != null + +allowed_columns(table) := cols if { + full := object.get(acl_tables, normalized_table(table), null) + full != null + cols := object.get(full, "columns", []) +} + +allowed_columns(table) := cols if { + full := object.get(acl_tables, normalized_table(table), null) + full == null + short := object.get(acl_tables, short_table_name(table), {}) + cols := object.get(short, "columns", []) +} column_allowed(table, col) if { allowed_columns(table)[_] == "*" @@ -22,12 +45,14 @@ column_allowed(table, col) if { deny_reasons["table_restricted"] if { is_query_tool + not is_write_query table := input.query.referenced_tables[_] not table_allowed(table) } deny_reasons["column_restricted"] if { is_query_tool + not is_write_query table := object.keys(input.query.referenced_columns)[_] col := input.query.referenced_columns[table][_] not column_allowed(table, col) @@ -35,10 +60,72 @@ deny_reasons["column_restricted"] if { deny_reasons["star_not_allowed"] if { is_query_tool + not is_write_query table := input.query.star_tables[_] not column_allowed(table, "*") } +deny_reasons["star_not_allowed"] if { + is_write_query + table := input.query.star_tables[_] + not column_allowed(table, "*") +} + +deny_reasons["table_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + not table_allowed(target) +} + +write_columns[col] if { + is_write_query + col := object.get(input.query, "insert_columns", [])[_] +} + +write_columns[col] if { + is_write_query + col := object.get(input.query, "updated_columns", [])[_] +} + +deny_reasons["write_column_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + col := write_columns[_] + not column_allowed(target, col) +} + +deny_reasons["write_source_table_restricted"] if { + is_write_query + src := object.get(input.query, "source_tables", [])[_] + not table_allowed(src) +} + +deny_reasons["write_column_restricted"] if { + is_write_query + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not column_allowed(table, col) +} + +deny_reasons["write_column_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + col := object.get(input.query, "returning_columns", [])[_] + col != "*" + not column_allowed(target, col) +} + +deny_reasons["star_not_allowed"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + object.get(input.query, "returning_columns", [])[_] == "*" + not column_allowed(target, "*") +} + deny_reasons["table_restricted"] if { is_describe_table_tool not table_allowed(input.table) diff --git a/policy/rego/authz.rego b/policy/rego/authz.rego index 069e4a6..9182453 100644 --- a/policy/rego/authz.rego +++ b/policy/rego/authz.rego @@ -2,9 +2,13 @@ package secure_sql.authz import data.secure_sql.acl import data.secure_sql.default_constraints +import data.secure_sql.write_constraints + +default allow = false deny_reasons[reason] if default_constraints.deny_reasons[reason] deny_reasons[reason] if acl.deny_reasons[reason] +deny_reasons[reason] if write_constraints.deny_reasons[reason] allow if { default_constraints.allow diff --git a/policy/rego/default_constraints.rego b/policy/rego/default_constraints.rego index 2285601..3b7462e 100644 --- a/policy/rego/default_constraints.rego +++ b/policy/rego/default_constraints.rego @@ -1,15 +1,26 @@ package secure_sql.default_constraints -default allow := false +default allow = false is_query_tool if input.tool.name == "query" +is_write_query if { + is_query_tool + object.get(input.query, "is_write_statement", false) +} + +statement_type := lower(object.get(input.query, "statement_type", "")) + +write_mode_enabled := object.get(object.get(input, "config", {}), "write_mode_enabled", false) +allow_insert := object.get(object.get(input, "config", {}), "allow_insert", false) +allow_update := object.get(object.get(input, "config", {}), "allow_update", false) +allow_delete := object.get(object.get(input, "config", {}), "allow_delete", false) deny_reasons["multiple_statements"] if { is_query_tool input.query.statement_count != 1 } -deny_reasons["disallowed_operation"] if { +deny_reasons["ddl_or_privilege_operation"] if { is_query_tool input.query.has_disallowed_operation } @@ -17,6 +28,36 @@ deny_reasons["disallowed_operation"] if { deny_reasons["not_read_query"] if { is_query_tool not input.query.is_read_statement + not is_write_query +} + +deny_reasons["write_not_enabled"] if { + is_write_query + not write_mode_enabled +} + +deny_reasons["insert_not_allowed"] if { + is_write_query + statement_type == "insert" + not allow_insert +} + +deny_reasons["update_not_allowed"] if { + is_write_query + statement_type == "update" + not allow_update +} + +deny_reasons["delete_not_allowed"] if { + is_write_query + statement_type == "delete" + not allow_delete +} + +deny_reasons["insert_columns_missing"] if { + is_write_query + statement_type == "insert" + count(object.get(input.query, "insert_columns", [])) == 0 } deny_reasons["unqualified_multi_table_column"] if { diff --git a/policy/rego/write_constraints.rego b/policy/rego/write_constraints.rego new file mode 100644 index 0000000..b2d42c2 --- /dev/null +++ b/policy/rego/write_constraints.rego @@ -0,0 +1,58 @@ +package secure_sql.write_constraints + +is_query_tool if input.tool.name == "query" +is_write_query if { + is_query_tool + object.get(input.query, "is_write_statement", false) +} + +statement_type := lower(object.get(input.query, "statement_type", "")) +where_present := object.get(input.query, "where_present", false) +where_tautological := object.get(input.query, "where_tautological", false) +returning_present := object.get(input.query, "returning_present", false) + +require_where_for_update := object.get( + object.get(input, "config", {}), + "require_where_for_update", + true, +) +require_where_for_delete := object.get( + object.get(input, "config", {}), + "require_where_for_delete", + true, +) +allow_returning := object.get(object.get(input, "config", {}), "allow_returning", false) + +deny_reasons["missing_where_on_update"] if { + is_write_query + statement_type == "update" + require_where_for_update + not where_present +} + +deny_reasons["missing_where_on_delete"] if { + is_write_query + statement_type == "delete" + require_where_for_delete + not where_present +} + +deny_reasons["tautological_where_clause"] if { + is_write_query + statement_type == "update" + where_present + where_tautological +} + +deny_reasons["tautological_where_clause"] if { + is_write_query + statement_type == "delete" + where_present + where_tautological +} + +deny_reasons["returning_not_allowed"] if { + is_write_query + returning_present + not allow_returning +} diff --git a/pyproject.toml b/pyproject.toml index 8881041..b63a008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,10 @@ dev = [ [tool.pytest.ini_options] testpaths = ["tests"] pythonpath = ["src"] +markers = [ + "docker_integration: real MCP server integration tests in Docker with OPA", + "smoke: fast representative scenario subset", +] [tool.ruff] line-length = 100 diff --git a/scripts/run-docker-opa-smoke.sh b/scripts/run-docker-opa-smoke.sh new file mode 100755 index 0000000..3a4f5ac --- /dev/null +++ b/scripts/run-docker-opa-smoke.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +echo "Running Docker + OPA smoke scenarios (all backends)..." +python -m pytest -q -m "docker_integration and smoke" tests/integration/docker/test_mcp_docker_opa_matrix.py diff --git a/src/secure_sql_mcp/config.py b/src/secure_sql_mcp/config.py index 52677b3..36c2970 100644 --- a/src/secure_sql_mcp/config.py +++ b/src/secure_sql_mcp/config.py @@ -26,6 +26,13 @@ class Settings(BaseSettings): opa_timeout_ms: int = Field(default=50, alias="OPA_TIMEOUT_MS", ge=1, le=5000) opa_fail_closed: bool = Field(default=True, alias="OPA_FAIL_CLOSED") opa_acl_data_file: str | None = Field(default=None, alias="OPA_ACL_DATA_FILE") + write_mode_enabled: bool = Field(default=False, alias="WRITE_MODE_ENABLED") + allow_insert: bool = Field(default=False, alias="ALLOW_INSERT") + allow_update: bool = Field(default=False, alias="ALLOW_UPDATE") + allow_delete: bool = Field(default=False, alias="ALLOW_DELETE") + require_where_for_update: bool = Field(default=True, alias="REQUIRE_WHERE_FOR_UPDATE") + require_where_for_delete: bool = Field(default=True, alias="REQUIRE_WHERE_FOR_DELETE") + allow_returning: bool = Field(default=False, alias="ALLOW_RETURNING") max_rows: int = Field(default=100, alias="MAX_ROWS", ge=1, le=10000) query_timeout: int = Field(default=30, alias="QUERY_TIMEOUT", ge=1, le=300) log_level: str = Field(default="INFO", alias="LOG_LEVEL") diff --git a/src/secure_sql_mcp/database.py b/src/secure_sql_mcp/database.py index 5b9125e..7eeea71 100644 --- a/src/secure_sql_mcp/database.py +++ b/src/secure_sql_mcp/database.py @@ -21,6 +21,15 @@ class QueryExecutionResult: truncated: bool +@dataclass(slots=True) +class WriteExecutionResult: + """Structured result for write statements.""" + + affected_rows: int + returning_columns: list[str] + returning_rows: list[dict[str, Any]] + + class AsyncDatabase: """Async SQLAlchemy wrapper with read-only execution safeguards.""" @@ -61,6 +70,36 @@ async def _run() -> QueryExecutionResult: return await asyncio.wait_for(_run(), timeout=self._settings.query_timeout) + async def execute_write_query(self, sql: str) -> WriteExecutionResult: + """Execute a single write statement with timeout and optional RETURNING payload.""" + if self._engine is None: + raise RuntimeError("Database engine is not initialized.") + + statement = text(sql.strip().rstrip(";")) + + async def _run() -> WriteExecutionResult: + if self._engine is None: + raise RuntimeError("Database engine is not initialized.") + async with self._engine.begin() as conn: + await self._prepare_write_session(conn) + result = await conn.execute(statement) + affected_rows = int(result.rowcount) if result.rowcount is not None else 0 + returning_rows: list[dict[str, Any]] = [] + returning_columns: list[str] = [] + if result.returns_rows: + fetched = result.fetchmany(self._settings.max_rows + 1) + returning_rows = [ + dict(row._mapping) for row in fetched[: self._settings.max_rows] + ] + returning_columns = list(result.keys()) + return WriteExecutionResult( + affected_rows=affected_rows, + returning_columns=returning_columns, + returning_rows=returning_rows, + ) + + return await asyncio.wait_for(_run(), timeout=self._settings.query_timeout) + async def list_tables(self) -> list[str]: """List all visible base tables from the connected database.""" if self._engine is None: @@ -111,6 +150,15 @@ async def _prepare_read_only_session(self, conn: AsyncConnection) -> None: elif self._settings.database_url.startswith("sqlite"): await conn.execute(text("PRAGMA query_only = ON")) + async def _prepare_write_session(self, conn: AsyncConnection) -> None: + """Apply DB-specific timeout settings for write operations.""" + if self._settings.database_url.startswith("postgresql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await conn.execute(text(f"SET LOCAL statement_timeout = {timeout_ms}")) + elif self._settings.database_url.startswith("mysql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {timeout_ms}")) + @staticmethod def _wrap_with_limit(sql: str, limit: int) -> str: query = sql.strip().rstrip(";") diff --git a/src/secure_sql_mcp/opa_policy.py b/src/secure_sql_mcp/opa_policy.py index 876e1d4..06392d4 100644 --- a/src/secure_sql_mcp/opa_policy.py +++ b/src/secure_sql_mcp/opa_policy.py @@ -39,8 +39,7 @@ def _evaluate_sync(self, payload: dict[str, Any]) -> PolicyDecision: allow=False, deny_reasons=["opa_unconfigured"], message=( - "Authorization service is not configured. " - "Please escalate to a human operator." + "Authorization service is not configured. Please escalate to a human operator." ), ) @@ -124,18 +123,69 @@ def _message_for_reasons(deny_reasons: list[str]) -> str: "Only a single SQL statement is allowed. " "Please remove additional statements and try again." ) + if "ddl_or_privilege_operation" in deny_reasons: + return ( + "DDL and privilege operations are not permitted. " + "Please escalate to a human operator." + ) if "disallowed_operation" in deny_reasons: return ( "This server is configured for read-only access. " "If you need to modify data, please escalate to a human operator." ) + if "write_not_enabled" in deny_reasons: + return ( + "Write operations are disabled by server configuration. " + "Please escalate to a human operator." + ) + if "insert_not_allowed" in deny_reasons: + return ( + "INSERT operations are not permitted by server configuration. " + "Please escalate to a human operator." + ) + if "update_not_allowed" in deny_reasons: + return ( + "UPDATE operations are not permitted by server configuration. " + "Please escalate to a human operator." + ) + if "delete_not_allowed" in deny_reasons: + return ( + "DELETE operations are not permitted by server configuration. " + "Please escalate to a human operator." + ) + if "insert_columns_missing" in deny_reasons: + return ( + "INSERT statements must include an explicit column list under strict mode. " + "Please specify target columns explicitly." + ) if "not_read_query" in deny_reasons: return "Only read-only SELECT queries are allowed." + if "missing_where_on_update" in deny_reasons: + return "UPDATE without a WHERE clause is not allowed." + if "missing_where_on_delete" in deny_reasons: + return "DELETE without a WHERE clause is not allowed." + if "tautological_where_clause" in deny_reasons: + return ( + "The WHERE clause appears tautological and may update/delete too broadly. " + "Please provide a restrictive predicate." + ) + if "returning_not_allowed" in deny_reasons: + return "RETURNING is not allowed for this write policy." if "table_restricted" in deny_reasons: return ( "Access to one or more tables is restricted by the server access policy. " "Please use list_tables/describe_table to view allowed targets." ) + if "write_source_table_restricted" in deny_reasons: + return ( + "INSERT ... SELECT references one or more source tables restricted by policy. " + "Please use list_tables/describe_table to view allowed targets." + ) + if "write_column_restricted" in deny_reasons: + return ( + "Write access to one or more target columns is restricted by policy. " + "Use describe_table to inspect allowed columns." + ) if "column_restricted" in deny_reasons: return ( "Access to one or more selected columns is restricted by policy. " diff --git a/src/secure_sql_mcp/query_validator.py b/src/secure_sql_mcp/query_validator.py index 79bece3..d3f9840 100644 --- a/src/secure_sql_mcp/query_validator.py +++ b/src/secure_sql_mcp/query_validator.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from collections import defaultdict from dataclasses import dataclass from typing import Any @@ -12,6 +13,8 @@ from secure_sql_mcp.config import Settings from secure_sql_mcp.opa_policy import OpaPolicyEngine +LOGGER = logging.getLogger(__name__) + @dataclass(slots=True) class ValidationResult: @@ -21,16 +24,30 @@ class ValidationResult: normalized_sql: str | None = None referenced_tables: list[str] | None = None referenced_columns: dict[str, list[str]] | None = None + statement_type: str | None = None error: str | None = None +@dataclass(slots=True) +class WriteFacts: + """Parser-derived facts for write statements.""" + + statement_type: str + target_table: str + insert_columns: list[str] + updated_columns: list[str] + where_present: bool + where_tautological: bool + returning_present: bool + returning_columns: list[str] + has_select_source: bool + source_tables: list[str] + + class QueryValidator: """Validates SQL query safety constraints.""" - _DISALLOWED_EXPRESSIONS = ( - exp.Insert, - exp.Update, - exp.Delete, + _ALWAYS_DISALLOWED = ( exp.Drop, exp.Alter, exp.Create, @@ -40,6 +57,7 @@ class QueryValidator: exp.Merge, exp.Command, ) + _WRITE_EXPRESSIONS = (exp.Insert, exp.Update, exp.Delete) def __init__(self, settings: Settings, policy_engine: OpaPolicyEngine | None = None) -> None: self.settings = settings @@ -68,58 +86,192 @@ def validate_query(self, sql: str) -> ValidationResult: ) statement = statements[0] - statement_type = statement.key.upper() if statement.key else "UNKNOWN" + statement_type = self._statement_type(statement) + statement_type_upper = statement_type.upper() statement_count = len(statements) has_disallowed_operation = any( - stmt is not None and self._contains_disallowed_operation(stmt) for stmt in statements + stmt is not None and self._contains_always_disallowed_operation(stmt) + for stmt in statements ) is_read_statement = statement_count == 1 and self._is_read_statement(statement) + is_write_statement = statement_count == 1 and self._is_write_statement(statement) referenced_tables: list[str] = [] referenced_columns: dict[str, set[str]] = {} star_tables: set[str] = set() has_unqualified_multi_table_columns = False + write_facts: WriteFacts | None = None if statement_count == 1: referenced_tables = self.extract_referenced_tables(statement) - if self.policy_engine is None: - if has_disallowed_operation: + if is_write_statement: + write_facts = self._extract_write_facts(statement) + if write_facts is None: return ValidationResult( ok=False, error=( - "This server is configured for read-only access. " - f"The operation '{statement_type}' is not permitted. " - "If you need to modify data, please escalate to a human operator." + "Could not determine write operation details from SQL. " + "Please use an explicit INSERT/UPDATE/DELETE statement." ), ) - if not is_read_statement: + + if not self.settings.write_mode_enabled: + self._warn_if_policy_would_allow_blocked_write( + sql=query, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + blocked_reason="WRITE_MODE_ENABLED", + ) return ValidationResult( ok=False, + statement_type=statement_type, error=( - "Only read-only SELECT queries are allowed. " - f"Received '{statement_type}'." + "This server is configured for read-only access. " + f"The operation '{statement_type_upper}' is not permitted. " + "Please escalate to a human operator." ), ) - table_policy = self._resolve_table_policy(referenced_tables) - if isinstance(table_policy, str): - return ValidationResult(ok=False, error=table_policy) + if not self._is_write_action_enabled(statement_type): + gate_name = f"ALLOW_{statement_type_upper}" + self._warn_if_policy_would_allow_blocked_write( + sql=query, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + blocked_reason=gate_name, + ) + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + f"{statement_type_upper} operations are disabled " + "by server configuration. " + "Please escalate to a human operator." + ), + ) - columns_result = self.extract_referenced_columns(statement, referenced_tables) - if isinstance(columns_result, str): - return ValidationResult(ok=False, error=columns_result) + if write_facts.statement_type == "insert" and not write_facts.insert_columns: + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "INSERT statements must include an explicit " + "column list under strict mode. " + "Please specify allowed target columns explicitly." + ), + ) + if ( + write_facts.statement_type in {"update", "delete"} + and self.settings.require_where_for_update + and write_facts.statement_type == "update" + and not write_facts.where_present + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=f"{statement_type_upper} without a WHERE clause is not allowed.", + ) + if ( + write_facts.statement_type in {"update", "delete"} + and self.settings.require_where_for_delete + and write_facts.statement_type == "delete" + and not write_facts.where_present + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=f"{statement_type_upper} without a WHERE clause is not allowed.", + ) + if ( + write_facts.statement_type in {"update", "delete"} + and write_facts.where_tautological + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "The WHERE clause appears tautological and may " + "update/delete too broadly. " + "Please provide a restrictive predicate." + ), + ) - referenced_columns, star_tables = columns_result - columns_error = self._validate_column_access( - table_policy, referenced_columns, star_tables - ) - if columns_error: - return ValidationResult(ok=False, error=columns_error) + if write_facts.returning_present and not self.settings.allow_returning: + return ValidationResult( + ok=False, + statement_type=statement_type, + error="RETURNING is not allowed for this write policy.", + ) - for table in referenced_tables: - access_error = self.table_access_error(table, table_policy=table_policy) - if access_error: - return ValidationResult(ok=False, error=access_error) + if self.policy_engine is None: + write_acl_error = self._validate_write_acl(write_facts) + if write_acl_error: + return ValidationResult( + ok=False, statement_type=statement_type, error=write_acl_error + ) + + columns_result = self.extract_referenced_columns(statement, referenced_tables) + if isinstance(columns_result, str): + return ValidationResult( + ok=False, statement_type=statement_type, error=columns_result + ) + referenced_columns, star_tables = columns_result + + table_policy = self._resolve_table_policy(referenced_tables) + if isinstance(table_policy, str): + return ValidationResult( + ok=False, statement_type=statement_type, error=table_policy + ) + columns_error = self._validate_column_access( + table_policy, referenced_columns, star_tables + ) + if columns_error: + return ValidationResult( + ok=False, statement_type=statement_type, error=columns_error + ) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) + elif is_read_statement: + if self.policy_engine is None: + table_policy = self._resolve_table_policy(referenced_tables) + if isinstance(table_policy, str): + return ValidationResult(ok=False, error=table_policy) + + columns_result = self.extract_referenced_columns(statement, referenced_tables) + if isinstance(columns_result, str): + return ValidationResult(ok=False, error=columns_result) + + referenced_columns, star_tables = columns_result + columns_error = self._validate_column_access( + table_policy, referenced_columns, star_tables + ) + if columns_error: + return ValidationResult(ok=False, error=columns_error) + + for table in referenced_tables: + access_error = self.table_access_error(table, table_policy=table_policy) + if access_error: + return ValidationResult(ok=False, error=access_error) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) else: referenced_columns, star_tables, has_unqualified_multi_table_columns = ( self._extract_referenced_columns_relaxed(statement, referenced_tables) @@ -129,6 +281,7 @@ def validate_query(self, sql: str) -> ValidationResult: if statement_count != 1: return ValidationResult( ok=False, + statement_type=statement_type, error=( "Only a single SQL statement is allowed. " "Please remove additional statements and try again." @@ -137,18 +290,21 @@ def validate_query(self, sql: str) -> ValidationResult: if has_disallowed_operation: return ValidationResult( ok=False, + statement_type=statement_type, error=( "This server is configured for read-only access. " - f"The operation '{statement_type}' is not permitted. " + f"The operation '{statement_type_upper}' is not permitted. " "If you need to modify data, please escalate to a human operator." ), ) - if not is_read_statement: + if not (is_read_statement or is_write_statement): return ValidationResult( ok=False, + statement_type=statement_type, error=( - "Only read-only SELECT queries are allowed. " - f"Received '{statement_type}'." + "Only read-only SELECT queries or explicitly enabled " + "write operations are allowed. " + f"Received '{statement_type_upper}'." ), ) else: @@ -156,18 +312,21 @@ def validate_query(self, sql: str) -> ValidationResult: self._build_query_policy_input( sql=query, statement_count=statement_count, - statement_type=statement_type.lower(), + statement_type=statement_type, has_disallowed_operation=has_disallowed_operation, is_read_statement=is_read_statement, referenced_tables=referenced_tables, referenced_columns=referenced_columns, star_tables=star_tables, has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, ) ) if not decision.allow: return ValidationResult( - ok=False, error=decision.message or "Query blocked by policy." + ok=False, + statement_type=statement_type, + error=decision.message or "Query blocked by policy.", ) return ValidationResult( @@ -179,6 +338,7 @@ def validate_query(self, sql: str) -> ValidationResult: referenced_columns={ table: sorted(columns) for table, columns in referenced_columns.items() }, + statement_type=statement_type, ) def table_access_error( @@ -307,8 +467,25 @@ def _dialect(self) -> str | None: return "sqlite" return None - def _contains_disallowed_operation(self, statement: exp.Expression) -> bool: - return any(statement.find(kind) is not None for kind in self._DISALLOWED_EXPRESSIONS) + def _contains_always_disallowed_operation(self, statement: exp.Expression) -> bool: + return any(statement.find(kind) is not None for kind in self._ALWAYS_DISALLOWED) + + def _is_write_statement(self, statement: exp.Expression) -> bool: + return isinstance(statement, self._WRITE_EXPRESSIONS) + + @staticmethod + def _statement_type(statement: exp.Expression) -> str: + key = (statement.key or "unknown").lower() + return key + + def _is_write_action_enabled(self, statement_type: str) -> bool: + if statement_type == "insert": + return self.settings.allow_insert + if statement_type == "update": + return self.settings.allow_update + if statement_type == "delete": + return self.settings.allow_delete + return False @staticmethod def _is_read_statement(statement: exp.Expression) -> bool: @@ -316,7 +493,7 @@ def _is_read_statement(statement: exp.Expression) -> bool: return True if isinstance(statement, (exp.Union, exp.Intersect, exp.Except)): return True - return statement.find(exp.Select) is not None + return False @staticmethod def _table_to_name(table: exp.Table) -> str: @@ -394,6 +571,252 @@ def _extract_referenced_columns_relaxed( return dict(columns_by_table), star_tables, has_unqualified_multi_table_columns + def _extract_write_facts(self, statement: exp.Expression) -> WriteFacts | None: + statement_type = self._statement_type(statement) + if statement_type not in {"insert", "update", "delete"}: + return None + + target_table = self._extract_target_table(statement) + if not target_table: + return None + + insert_columns: list[str] = [] + updated_columns: list[str] = [] + where_present = False + where_tautological = False + returning_present = ( + bool(statement.args.get("returning")) or statement.find(exp.Returning) is not None + ) + returning_columns = self._extract_returning_columns(statement) + source_tables: list[str] = [] + + if isinstance(statement, exp.Insert): + insert_columns = self._extract_insert_columns(statement) + source_expr = statement.args.get("expression") + if isinstance(source_expr, exp.Expression): + source_tables = self.extract_referenced_tables(source_expr) + elif isinstance(statement, exp.Update): + updated_columns = self._extract_update_columns(statement) + where_expr = statement.args.get("where") + where_present = where_expr is not None + if isinstance(where_expr, exp.Expression): + where_tautological = self._is_tautological_where(where_expr) + elif isinstance(statement, exp.Delete): + where_expr = statement.args.get("where") + where_present = where_expr is not None + if isinstance(where_expr, exp.Expression): + where_tautological = self._is_tautological_where(where_expr) + + if not source_tables: + all_tables = self.extract_referenced_tables(statement) + source_tables = sorted(table for table in all_tables if table != target_table) + + return WriteFacts( + statement_type=statement_type, + target_table=target_table, + insert_columns=sorted(set(insert_columns)), + updated_columns=sorted(set(updated_columns)), + where_present=where_present, + where_tautological=where_tautological, + returning_present=returning_present, + returning_columns=returning_columns, + has_select_source=bool(source_tables), + source_tables=sorted(set(source_tables)), + ) + + def _extract_target_table(self, statement: exp.Expression) -> str | None: + target_expr = statement.args.get("this") + if isinstance(target_expr, exp.Schema): + target_expr = target_expr.this + if isinstance(target_expr, exp.Table): + return self._table_to_name(target_expr) + if isinstance(target_expr, exp.Expression): + table = target_expr.find(exp.Table) + if isinstance(table, exp.Table): + return self._table_to_name(table) + return None + + def _extract_insert_columns(self, statement: exp.Insert) -> list[str]: + target_expr = statement.args.get("this") + if isinstance(target_expr, exp.Schema): + columns: list[str] = [] + for column in target_expr.expressions: + if isinstance(column, exp.Column) and column.name: + columns.append(column.name.lower()) + elif isinstance(column, exp.Identifier) and column.this: + columns.append(str(column.this).lower()) + return columns + return [] + + @staticmethod + def _extract_update_columns(statement: exp.Update) -> list[str]: + columns: set[str] = set() + for assignment in statement.expressions: + lhs = assignment.args.get("this") if isinstance(assignment, exp.Expression) else None + if isinstance(lhs, exp.Column) and lhs.name: + columns.add(lhs.name.lower()) + return sorted(columns) + + @staticmethod + def _extract_returning_columns(statement: exp.Expression) -> list[str]: + returning_expr = statement.args.get("returning") + if not isinstance(returning_expr, exp.Returning): + return [] + + columns: set[str] = set() + for expression in returning_expr.expressions: + if isinstance(expression, exp.Star): + columns.add("*") + continue + if isinstance(expression, exp.Column): + if isinstance(expression.this, exp.Star): + columns.add("*") + continue + if expression.name: + columns.add(expression.name.lower()) + continue + for nested_column in expression.find_all(exp.Column): + if isinstance(nested_column.this, exp.Star): + columns.add("*") + continue + if nested_column.name: + columns.add(nested_column.name.lower()) + return sorted(columns) + + def _is_tautological_where(self, where_expr: exp.Expression) -> bool: + expr = where_expr.this if isinstance(where_expr, exp.Where) else where_expr + if isinstance(expr, exp.Paren): + return self._is_tautological_where(expr.this) + if isinstance(expr, exp.Boolean): + return bool(expr.this) + if isinstance(expr, exp.Not): + child = expr.this + return isinstance(child, exp.Boolean) and not bool(child.this) + if isinstance(expr, exp.Literal): + if expr.is_string: + return expr.this.strip().lower() in {"true", "t", "yes", "on", "1"} + return str(expr.this).strip() in {"1"} + if isinstance(expr, exp.Or): + return self._is_tautological_where(expr.left) or self._is_tautological_where(expr.right) + if isinstance(expr, exp.And): + return self._is_tautological_where(expr.left) and self._is_tautological_where( + expr.right + ) + if isinstance(expr, exp.EQ): + left = expr.left + right = expr.right + if isinstance(left, exp.Literal) and isinstance(right, exp.Literal): + return str(left.this) == str(right.this) and left.is_string == right.is_string + if isinstance(left, exp.Column) and isinstance(right, exp.Column): + return ( + left.name.lower() == right.name.lower() + and (left.table or "").lower() == (right.table or "").lower() + ) + return False + + def _validate_write_acl(self, write_facts: WriteFacts) -> str | None: + target_policy = self.lookup_table_policy(write_facts.target_table) + if target_policy is None: + available = ", ".join(sorted(self.settings.effective_acl_policy)) + return ( + f"Access to table '{write_facts.target_table}' " + "is restricted by the server access policy. " + f"Allowed tables are: {available}. " + "Please use list_tables/describe_table or escalate to a human operator." + ) + + changed_columns = set(write_facts.insert_columns or write_facts.updated_columns) + if changed_columns and "*" not in target_policy: + disallowed = sorted(column for column in changed_columns if column not in target_policy) + if disallowed: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"Write access to column(s) {', '.join(disallowed)} " + f"on table '{write_facts.target_table}' " + "is restricted. " + f"Allowed columns: {allowed_text}. " + "Use describe_table to inspect policy or escalate to a human operator." + ) + + if write_facts.returning_present: + if "*" in write_facts.returning_columns and "*" not in target_policy: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"RETURNING * is not allowed for table '{write_facts.target_table}' " + "under strict policy. " + f"Allowed columns: {allowed_text}. " + "Please list explicit allowed RETURNING columns." + ) + if "*" not in target_policy: + disallowed_returning = sorted( + column + for column in write_facts.returning_columns + if column != "*" and column not in target_policy + ) + if disallowed_returning: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"RETURNING column(s) {', '.join(disallowed_returning)} on table " + f"'{write_facts.target_table}' are restricted. " + f"Allowed columns: {allowed_text}. " + "Use describe_table to inspect policy or escalate to a human operator." + ) + + return None + + def _warn_if_policy_would_allow_blocked_write( + self, + *, + sql: str, + statement_count: int, + statement_type: str, + has_disallowed_operation: bool, + is_read_statement: bool, + referenced_tables: list[str], + referenced_columns: dict[str, set[str]], + star_tables: set[str], + has_unqualified_multi_table_columns: bool, + write_facts: WriteFacts | None, + blocked_reason: str, + ) -> None: + if self.policy_engine is None or write_facts is None: + return + + shadow_payload = self._build_query_policy_input( + sql=sql, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + config_overrides={ + "write_mode_enabled": True, + "allow_insert": True, + "allow_update": True, + "allow_delete": True, + "allow_returning": True, + }, + ) + decision = self.policy_engine.evaluate_sync(shadow_payload) + if decision.allow: + operation = statement_type.upper() + gate_name = blocked_reason.upper() + enabled_flag = "true" if self.settings.write_mode_enabled else "false" + action_flag = "true" if self._is_write_action_enabled(statement_type) else "false" + LOGGER.warning( + "Write operation '%s' blocked by config gate " + "(%s, WRITE_MODE_ENABLED=%s, ALLOW_%s=%s).", + operation, + gate_name, + enabled_flag, + operation, + action_flag, + ) + def _build_query_policy_input( self, *, @@ -406,17 +829,55 @@ def _build_query_policy_input( referenced_columns: dict[str, set[str]], star_tables: set[str], has_unqualified_multi_table_columns: bool, + write_facts: WriteFacts | None, + config_overrides: dict[str, bool] | None = None, ) -> dict[str, Any]: acl_tables = { table: {"columns": sorted(columns)} for table, columns in sorted(self.settings.effective_acl_policy.items()) } + write_mode_enabled = ( + config_overrides["write_mode_enabled"] + if config_overrides and "write_mode_enabled" in config_overrides + else self.settings.write_mode_enabled + ) + allow_insert = ( + config_overrides["allow_insert"] + if config_overrides and "allow_insert" in config_overrides + else self.settings.allow_insert + ) + allow_update = ( + config_overrides["allow_update"] + if config_overrides and "allow_update" in config_overrides + else self.settings.allow_update + ) + allow_delete = ( + config_overrides["allow_delete"] + if config_overrides and "allow_delete" in config_overrides + else self.settings.allow_delete + ) + require_where_for_update = ( + config_overrides["require_where_for_update"] + if config_overrides and "require_where_for_update" in config_overrides + else self.settings.require_where_for_update + ) + require_where_for_delete = ( + config_overrides["require_where_for_delete"] + if config_overrides and "require_where_for_delete" in config_overrides + else self.settings.require_where_for_delete + ) + allow_returning = ( + config_overrides["allow_returning"] + if config_overrides and "allow_returning" in config_overrides + else self.settings.allow_returning + ) return { "tool": {"name": "query"}, "query": { "raw_sql": sql, "statement_count": statement_count, "statement_type": statement_type, + "is_write_statement": write_facts is not None, "has_disallowed_operation": has_disallowed_operation, "is_read_statement": is_read_statement, "referenced_tables": referenced_tables, @@ -425,6 +886,24 @@ def _build_query_policy_input( }, "star_tables": sorted(star_tables), "has_unqualified_multi_table_columns": has_unqualified_multi_table_columns, + "target_table": write_facts.target_table if write_facts else "", + "insert_columns": write_facts.insert_columns if write_facts else [], + "updated_columns": write_facts.updated_columns if write_facts else [], + "where_present": write_facts.where_present if write_facts else False, + "where_tautological": write_facts.where_tautological if write_facts else False, + "returning_present": write_facts.returning_present if write_facts else False, + "returning_columns": write_facts.returning_columns if write_facts else [], + "has_select_source": write_facts.has_select_source if write_facts else False, + "source_tables": write_facts.source_tables if write_facts else [], + }, + "config": { + "write_mode_enabled": write_mode_enabled, + "allow_insert": allow_insert, + "allow_update": allow_update, + "allow_delete": allow_delete, + "require_where_for_update": require_where_for_update, + "require_where_for_delete": require_where_for_delete, + "allow_returning": allow_returning, }, "acl": {"tables": acl_tables}, } diff --git a/src/secure_sql_mcp/server.py b/src/secure_sql_mcp/server.py index 91bb47d..35630f3 100644 --- a/src/secure_sql_mcp/server.py +++ b/src/secure_sql_mcp/server.py @@ -2,8 +2,10 @@ from __future__ import annotations +import argparse import json import logging +import os from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass @@ -61,13 +63,27 @@ def _state() -> AppState: @mcp.tool() async def query(sql: str) -> str: - """Run a read-only SQL query and return structured results.""" + """Run a SQL query (read-only by default; writes only when explicitly enabled).""" app = _state() validation = app.validator.validate_query(sql) if not validation.ok: return validation.error or "Query blocked by policy." + statement_type = (validation.statement_type or "").lower() try: + if statement_type in {"insert", "update", "delete"}: + write_result = await app.db.execute_write_query(validation.normalized_sql or sql) + payload = { + "status": "ok", + "operation": statement_type, + "affected_rows": write_result.affected_rows, + "returning_columns": write_result.returning_columns, + "returning": write_result.returning_rows, + "referenced_tables": validation.referenced_tables or [], + "referenced_columns": validation.referenced_columns or {}, + } + return json.dumps(payload, default=str, indent=2) + result = await app.db.execute_read_query(validation.normalized_sql or sql) except TimeoutError: return ( @@ -193,6 +209,38 @@ async def describe_table(table: str) -> str: def main() -> None: """Run the MCP server with stdio transport.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--write-mode", + action="store_true", + help="Enable write-mode execution path (disabled by default).", + ) + parser.add_argument( + "--allow-insert", + action="store_true", + help="Allow INSERT statements when write mode is enabled.", + ) + parser.add_argument( + "--allow-update", + action="store_true", + help="Allow UPDATE statements when write mode is enabled.", + ) + parser.add_argument( + "--allow-delete", + action="store_true", + help="Allow DELETE statements when write mode is enabled.", + ) + args, _ = parser.parse_known_args() + + if args.write_mode: + os.environ["WRITE_MODE_ENABLED"] = "true" + if args.allow_insert: + os.environ["ALLOW_INSERT"] = "true" + if args.allow_update: + os.environ["ALLOW_UPDATE"] = "true" + if args.allow_delete: + os.environ["ALLOW_DELETE"] = "true" + mcp.run(transport="stdio") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..70da13d --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration test package.""" diff --git a/tests/integration/docker/__init__.py b/tests/integration/docker/__init__.py new file mode 100644 index 0000000..ccfa55f --- /dev/null +++ b/tests/integration/docker/__init__.py @@ -0,0 +1 @@ +"""Docker-backed integration tests.""" diff --git a/tests/integration/docker/acl/restricted_acl.json b/tests/integration/docker/acl/restricted_acl.json new file mode 100644 index 0000000..8d53acf --- /dev/null +++ b/tests/integration/docker/acl/restricted_acl.json @@ -0,0 +1,14 @@ +{ + "secure_sql": { + "acl": { + "tables": { + "customers": { + "columns": ["id", "email"] + }, + "orders": { + "columns": ["id", "total"] + } + } + } + } +} diff --git a/tests/integration/docker/conftest.py b/tests/integration/docker/conftest.py new file mode 100644 index 0000000..5a1050b --- /dev/null +++ b/tests/integration/docker/conftest.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import os +import sqlite3 +import subprocess +import uuid +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from pathlib import Path + +import pytest +from mcp.client.stdio import StdioServerParameters + + +@dataclass(frozen=True, slots=True) +class BackendConfig: + name: str + database_url: str + needs_deps: bool + + +ROOT = Path(__file__).resolve().parents[3] +COMPOSE_FILE = ROOT / "docker-compose.test.yml" +POLICY_DIR = ROOT / "tests" / "integration" / "docker" / "policies" +ACL_DIR = ROOT / "tests" / "integration" / "docker" / "acl" + + +def _run(command: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]: + return subprocess.run( # noqa: S603 + command, + cwd=ROOT, + check=check, + text=True, + capture_output=True, + ) + + +@pytest.fixture(scope="session") +def docker_available() -> None: + try: + _run(["docker", "version"]) + _run(["docker", "compose", "version"]) + except (OSError, subprocess.CalledProcessError): + pytest.skip("Docker or docker compose is unavailable on this host.") + + +@pytest.fixture(scope="session") +def compose_project_name() -> str: + return f"secure_sql_it_{uuid.uuid4().hex[:10]}" + + +@pytest.fixture(scope="session") +def docker_stack(docker_available: None, compose_project_name: str) -> Iterator[None]: + compose = ["docker", "compose", "-p", compose_project_name, "-f", str(COMPOSE_FILE)] + _run([*compose, "build", "secure-sql-mcp"]) + _run([*compose, "up", "-d", "postgres", "mysql"]) + try: + yield + finally: + _run([*compose, "down", "-v", "--remove-orphans"], check=False) + + +@pytest.fixture(params=["sqlite", "postgresql", "mysql"]) +def backend(request: pytest.FixtureRequest) -> BackendConfig: + backend_name = str(request.param) + if backend_name == "sqlite": + return BackendConfig( + name="sqlite", + database_url="sqlite+aiosqlite:///run/sqlite/test.db", + needs_deps=False, + ) + if backend_name == "postgresql": + return BackendConfig( + name="postgresql", + database_url="postgresql+asyncpg://secure:secure@postgres:5432/secure_sql_test", + needs_deps=True, + ) + return BackendConfig( + name="mysql", + database_url="mysql+aiomysql://secure:secure@mysql:3306/secure_sql_test", + needs_deps=True, + ) + + +@pytest.fixture +def policy_path() -> Callable[[str], Path]: + def _resolve(policy_name: str) -> Path: + path = POLICY_DIR / f"{policy_name}.txt" + if not path.exists(): + raise FileNotFoundError(f"Policy file not found: {path}") + return path + + return _resolve + + +@pytest.fixture +def acl_path() -> Callable[[str], Path]: + def _resolve(acl_name: str) -> Path: + path = ACL_DIR / f"{acl_name}.json" + if not path.exists(): + raise FileNotFoundError(f"ACL file not found: {path}") + return path + + return _resolve + + +@pytest.fixture +def sqlite_db_dir(tmp_path: Path) -> Path: + db_dir = tmp_path / "sqlite" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / "test.db" + conn = sqlite3.connect(db_path) + try: + conn.executescript( + """ + CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + ssn TEXT + ); + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + total NUMERIC + ); + CREATE TABLE secrets ( + id INTEGER PRIMARY KEY, + token TEXT + ); + INSERT INTO customers (id, email, ssn) VALUES (1, 'a@example.com', '111-22-3333'); + INSERT INTO orders (id, total) VALUES (10, 19.99); + INSERT INTO secrets (id, token) VALUES (99, 'top-secret-token'); + """ + ) + conn.commit() + finally: + conn.close() + return db_dir + + +@pytest.fixture +def docker_server_params_factory( + compose_project_name: str, + sqlite_db_dir: Path, +) -> Callable[..., StdioServerParameters]: + def _factory( + *, + backend: BackendConfig, + policy_file: Path, + write_mode_enabled: bool = False, + allow_insert: bool = False, + allow_update: bool = False, + allow_delete: bool = False, + require_where_for_update: bool = True, + require_where_for_delete: bool = True, + allow_returning: bool = False, + opa_fail_closed: bool = True, + opa_url: str = "http://127.0.0.1:8181", + opa_decision_path: str = "/v1/data/secure_sql/authz/decision", + opa_acl_data_file: Path | None = None, + ) -> StdioServerParameters: + args = ["run", "--rm", "-i"] + if backend.needs_deps: + args.extend(["--network", f"{compose_project_name}_default"]) + + args.extend( + [ + "-e", + f"DATABASE_URL={backend.database_url}", + "-e", + "ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt", + "-e", + f"OPA_URL={opa_url}", + "-e", + f"OPA_DECISION_PATH={opa_decision_path}", + "-e", + f"OPA_FAIL_CLOSED={'true' if opa_fail_closed else 'false'}", + "-e", + f"WRITE_MODE_ENABLED={'true' if write_mode_enabled else 'false'}", + "-e", + f"ALLOW_INSERT={'true' if allow_insert else 'false'}", + "-e", + f"ALLOW_UPDATE={'true' if allow_update else 'false'}", + "-e", + f"ALLOW_DELETE={'true' if allow_delete else 'false'}", + "-e", + (f"REQUIRE_WHERE_FOR_UPDATE={'true' if require_where_for_update else 'false'}"), + "-e", + (f"REQUIRE_WHERE_FOR_DELETE={'true' if require_where_for_delete else 'false'}"), + "-e", + f"ALLOW_RETURNING={'true' if allow_returning else 'false'}", + "-v", + f"{policy_file}:/run/policy/allowed_policy.txt:ro", + ] + ) + + if backend.name == "sqlite": + args.extend(["-v", f"{sqlite_db_dir}:/run/sqlite:rw"]) + + if opa_acl_data_file is not None: + args.extend( + [ + "-e", + "OPA_ACL_DATA_FILE=/run/policy/acl.json", + "-v", + f"{opa_acl_data_file}:/run/policy/acl.json:ro", + ] + ) + + args.append("secure-sql-mcp:test") + return StdioServerParameters(command="docker", args=args, env=os.environ.copy()) + + return _factory diff --git a/tests/integration/docker/db-init/mysql.sql b/tests/integration/docker/db-init/mysql.sql new file mode 100644 index 0000000..358499c --- /dev/null +++ b/tests/integration/docker/db-init/mysql.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS customers ( + id INT PRIMARY KEY, + email VARCHAR(255) NOT NULL, + ssn VARCHAR(64) +); + +CREATE TABLE IF NOT EXISTS orders ( + id INT PRIMARY KEY, + total DECIMAL(10, 2) +); + +CREATE TABLE IF NOT EXISTS secrets ( + id INT PRIMARY KEY, + token VARCHAR(255) +); + +INSERT IGNORE INTO customers (id, email, ssn) +VALUES (1, 'a@example.com', '111-22-3333'); + +INSERT IGNORE INTO orders (id, total) +VALUES (10, 19.99); + +INSERT IGNORE INTO secrets (id, token) +VALUES (99, 'top-secret-token'); diff --git a/tests/integration/docker/db-init/postgres.sql b/tests/integration/docker/db-init/postgres.sql new file mode 100644 index 0000000..f096b1f --- /dev/null +++ b/tests/integration/docker/db-init/postgres.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS customers ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + ssn TEXT +); + +CREATE TABLE IF NOT EXISTS orders ( + id INTEGER PRIMARY KEY, + total NUMERIC +); + +CREATE TABLE IF NOT EXISTS secrets ( + id INTEGER PRIMARY KEY, + token TEXT +); + +INSERT INTO customers (id, email, ssn) +VALUES (1, 'a@example.com', '111-22-3333') +ON CONFLICT (id) DO NOTHING; + +INSERT INTO orders (id, total) +VALUES (10, 19.99) +ON CONFLICT (id) DO NOTHING; + +INSERT INTO secrets (id, token) +VALUES (99, 'top-secret-token') +ON CONFLICT (id) DO NOTHING; diff --git a/tests/integration/docker/policies/read_only_strict.txt b/tests/integration/docker/policies/read_only_strict.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/read_only_strict.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/wildcard_tables.txt b/tests/integration/docker/policies/wildcard_tables.txt new file mode 100644 index 0000000..97b8414 --- /dev/null +++ b/tests/integration/docker/policies/wildcard_tables.txt @@ -0,0 +1,2 @@ +customers:* +orders:* diff --git a/tests/integration/docker/policies/write_delete_restricted.txt b/tests/integration/docker/policies/write_delete_restricted.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/write_delete_restricted.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/write_insert_only.txt b/tests/integration/docker/policies/write_insert_only.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/write_insert_only.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/write_update_restricted.txt b/tests/integration/docker/policies/write_update_restricted.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/write_update_restricted.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/test_mcp_docker_opa_matrix.py b/tests/integration/docker/test_mcp_docker_opa_matrix.py new file mode 100644 index 0000000..349d52e --- /dev/null +++ b/tests/integration/docker/test_mcp_docker_opa_matrix.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import asyncio +import json +import time + +import pytest +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +from .conftest import BackendConfig + + +def _first_text(call_result: object) -> str: + for item in getattr(call_result, "content", []): + text = getattr(item, "text", None) + if text is not None: + return text + return "" + + +async def _call_tool( + server_params: StdioServerParameters, tool: str, payload: dict[str, object] +) -> str: + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool(tool, payload) + return _first_text(result) + + +def _call_tool_with_retries( + server_params: StdioServerParameters, + tool: str, + payload: dict[str, object], + *, + retries: int = 4, + wait_seconds: float = 2.0, +) -> str: + last_response = "" + for attempt in range(retries): + last_response = asyncio.run(_call_tool(server_params, tool, payload)) + if "database error" not in last_response: + return last_response + if attempt < retries - 1: + time.sleep(wait_seconds) + return last_response + + +pytestmark = pytest.mark.docker_integration + + +def test_read_baseline_policy_enforced( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("read_only_strict"), + ) + + if backend.name != "mysql": + allowed = _call_tool_with_retries( + params, "query", {"sql": "SELECT id, email FROM customers"} + ) + payload = json.loads(allowed) + assert payload["status"] == "ok" + else: + list_tables = asyncio.run(_call_tool(params, "list_tables", {})) + list_payload = json.loads(list_tables) + assert list_payload["status"] == "ok" + + blocked_col = asyncio.run(_call_tool(params, "query", {"sql": "SELECT ssn FROM customers"})) + assert "restricted" in blocked_col + + blocked_table = asyncio.run(_call_tool(params, "query", {"sql": "SELECT id FROM secrets"})) + assert "restricted" in blocked_table + + multi = asyncio.run( + _call_tool(params, "query", {"sql": "SELECT id FROM customers; DROP TABLE customers"}) + ) + assert "Only a single SQL statement is allowed" in multi + + +@pytest.mark.smoke +def test_write_disabled_blocks_insert_even_with_policy_allow( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_insert_only"), + write_mode_enabled=False, + allow_insert=True, + ) + blocked = asyncio.run( + _call_tool( + params, + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + ) + assert "read-only access" in blocked + assert "INSERT" in blocked + + +@pytest.mark.smoke +def test_insert_allowed_with_write_mode_and_gate( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + if backend.name != "postgresql": + pytest.skip("Write success-path assertions are validated on PostgreSQL in this matrix.") + + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_insert=True, + ) + allowed = _call_tool_with_retries( + params, + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + payload = json.loads(allowed) + assert payload["status"] == "ok" + assert payload["operation"] == "insert" + assert payload["affected_rows"] == 1 + + +def test_insert_select_source_table_and_star_protections( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_insert_only"), + write_mode_enabled=True, + allow_insert=True, + ) + + disallowed_source = asyncio.run( + _call_tool( + params, + "query", + {"sql": "INSERT INTO orders (id, total) SELECT s.id, s.id FROM secrets AS s"}, + ) + ) + assert "restricted" in disallowed_source + + star_source = asyncio.run( + _call_tool( + params, "query", {"sql": "INSERT INTO orders (id, total) SELECT * FROM customers"} + ) + ) + assert "restricted" in star_source + + +def test_update_delete_where_guards_and_tautology( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + if backend.name != "postgresql": + pytest.skip("Write success-path assertions are validated on PostgreSQL in this matrix.") + + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_update=True, + allow_delete=True, + ) + + missing_where_update = asyncio.run( + _call_tool(params, "query", {"sql": "UPDATE customers SET email = 'x@example.com'"}) + ) + assert "without a WHERE clause is not allowed" in missing_where_update + + tautological_delete = asyncio.run( + _call_tool(params, "query", {"sql": "DELETE FROM orders WHERE 1 = 1"}) + ) + assert "WHERE clause appears tautological" in tautological_delete + + valid_update = _call_tool_with_retries( + params, "query", {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1"} + ) + payload = json.loads(valid_update) + assert payload["status"] == "ok" + assert payload["operation"] == "update" + + +def test_returning_controls_and_column_acl( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + blocked_params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_update=True, + allow_returning=False, + ) + returning_blocked = asyncio.run( + _call_tool( + blocked_params, + "query", + {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email"}, + ) + ) + assert "RETURNING is not allowed" in returning_blocked + + if backend.name != "postgresql": + return + + allowed_params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_update_restricted"), + write_mode_enabled=True, + allow_update=True, + allow_returning=True, + ) + restricted_column = asyncio.run( + _call_tool( + allowed_params, + "query", + {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING ssn"}, + ) + ) + assert "restricted" in restricted_column + + +def test_opa_fail_closed_when_unavailable( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("read_only_strict"), + opa_decision_path="/v1/data/secure_sql/authz/missing", + opa_fail_closed=True, + ) + query_msg = asyncio.run(_call_tool(params, "query", {"sql": "SELECT id FROM customers"})) + assert "Authorization decision is unavailable" in query_msg + + list_msg = asyncio.run(_call_tool(params, "list_tables", {})) + assert "Authorization decision is unavailable" in list_msg + + describe_msg = asyncio.run(_call_tool(params, "describe_table", {"table": "customers"})) + assert "Authorization decision is unavailable" in describe_msg + + +def test_opa_acl_data_file_profile_works( + docker_stack: None, + backend: BackendConfig, + policy_path, + acl_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + opa_acl_data_file=acl_path("restricted_acl"), + ) + + blocked = asyncio.run(_call_tool(params, "query", {"sql": "SELECT ssn FROM customers"})) + assert "restricted" in blocked diff --git a/tests/test_config.py b/tests/test_config.py index d66b8f6..f94d50d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -219,3 +219,46 @@ def test_invalid_opa_acl_json_raises(tmp_path: Path) -> None: "OPA_ACL_DATA_FILE": str(opa_acl_path), } ) + + +def test_write_mode_flags_default_to_false(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + assert settings.write_mode_enabled is False + assert settings.allow_insert is False + assert settings.allow_update is False + assert settings.allow_delete is False + assert settings.require_where_for_update is True + assert settings.require_where_for_delete is True + assert settings.allow_returning is False + + +def test_write_mode_flags_can_be_enabled(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "REQUIRE_WHERE_FOR_UPDATE": False, + "REQUIRE_WHERE_FOR_DELETE": False, + "ALLOW_RETURNING": True, + } + ) + assert settings.write_mode_enabled is True + assert settings.allow_insert is True + assert settings.allow_update is True + assert settings.allow_delete is True + assert settings.require_where_for_update is False + assert settings.require_where_for_delete is False + assert settings.allow_returning is True diff --git a/tests/test_mcp_interface.py b/tests/test_mcp_interface.py index 7161d47..b7a8a3c 100644 --- a/tests/test_mcp_interface.py +++ b/tests/test_mcp_interface.py @@ -2,7 +2,9 @@ import asyncio import json +import os import sqlite3 +import sys from pathlib import Path from unittest.mock import AsyncMock @@ -105,6 +107,89 @@ def limited_app_state(tmp_path: Path): mcp_server.STATE = None +@pytest.fixture() +def write_enabled_app_state(tmp_path: Path): + db_path = tmp_path / "write_enabled.db" + init_sqlite_db(db_path) + + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": f"sqlite+aiosqlite:///{db_path}", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "MAX_ROWS": 100, + "QUERY_TIMEOUT": 30, + "LOG_LEVEL": "INFO", + } + ) + db = AsyncDatabase(settings) + asyncio.run(db.connect()) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) + mcp_server.STATE = state + + try: + yield state + finally: + asyncio.run(db.dispose()) + mcp_server.STATE = None + + +@pytest.fixture() +def write_enabled_returning_app_state(tmp_path: Path): + db_path = tmp_path / "write_enabled_returning.db" + init_sqlite_db(db_path) + + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": f"sqlite+aiosqlite:///{db_path}", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "ALLOW_RETURNING": True, + "MAX_ROWS": 100, + "QUERY_TIMEOUT": 30, + "LOG_LEVEL": "INFO", + } + ) + db = AsyncDatabase(settings) + asyncio.run(db.connect()) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) + mcp_server.STATE = state + + try: + yield state + finally: + asyncio.run(db.dispose()) + mcp_server.STATE = None + + def test_policy_parsing_valid(tmp_path: Path) -> None: policy_path = tmp_path / "policy.txt" write_policy( @@ -171,6 +256,14 @@ def test_query_blocks_mutation_operations(app_state: AppState, sql: str, operati assert operation in response +def test_query_blocks_insert_select_when_write_mode_disabled(app_state: AppState) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO orders (id, total) SELECT id, id FROM orders") + ) + assert "read-only access" in response + assert "INSERT" in response + + def test_query_blocks_multi_statement_payload(app_state: AppState) -> None: response = asyncio.run(mcp_server.query("SELECT id FROM customers; DROP TABLE customers")) assert "Only a single SQL statement is allowed" in response @@ -345,3 +438,137 @@ def test_prepare_read_only_session_mysql_sets_timeout_and_read_only(tmp_path: Pa "SET SESSION MAX_EXECUTION_TIME = 12000", "START TRANSACTION READ ONLY", ] + + +def test_query_allows_insert_when_write_mode_enabled(write_enabled_app_state: AppState) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO customers (id, email) VALUES (2, 'b@example.com')") + ) + payload = json.loads(response) + assert payload["status"] == "ok" + assert payload["operation"] == "insert" + assert payload["affected_rows"] == 1 + + +def test_query_allows_update_with_where_when_write_mode_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'c@example.com' WHERE id = 1") + ) + payload = json.loads(response) + assert payload["status"] == "ok" + assert payload["operation"] == "update" + assert payload["affected_rows"] == 1 + + verify_response = asyncio.run(mcp_server.query("SELECT email FROM customers WHERE id = 1")) + verify_payload = json.loads(verify_response) + assert verify_payload["rows"][0]["email"] == "c@example.com" + + +def test_query_blocks_update_without_where_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run(mcp_server.query("UPDATE customers SET email = 'x@example.com'")) + assert "UPDATE without a WHERE clause is not allowed" in response + + +def test_query_blocks_tautological_where_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run(mcp_server.query("DELETE FROM customers WHERE 1 = 1")) + assert "WHERE clause appears tautological" in response + + +def test_query_blocks_insert_from_disallowed_source_table_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO orders (id, total) SELECT s.id, s.id FROM secrets AS s") + ) + assert "Access to table 'secrets' is restricted" in response + + +def test_query_blocks_returning_by_default_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email" + ) + ) + assert "RETURNING is not allowed" in response + + +def test_query_blocks_restricted_returning_column_when_allowed( + write_enabled_returning_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING ssn") + ) + assert "RETURNING column(s) ssn" in response + + +def test_write_query_timeout_returns_actionable_message( + write_enabled_app_state: AppState, monkeypatch: pytest.MonkeyPatch +) -> None: + async def _raise_timeout(_: str) -> object: + raise TimeoutError() + + monkeypatch.setattr(write_enabled_app_state.db, "execute_write_query", _raise_timeout) + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + ) + assert ( + f"Query exceeded the {write_enabled_app_state.settings.query_timeout}-second timeout" + in response + ) + + +def test_write_query_db_error_message_does_not_leak_sensitive_details( + write_enabled_app_state: AppState, monkeypatch: pytest.MonkeyPatch +) -> None: + async def _raise_db_error(_: str) -> object: + raise RuntimeError("password=supersecret host=internal-db") + + monkeypatch.setattr(write_enabled_app_state.db, "execute_write_query", _raise_db_error) + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + ) + assert "Query execution failed with a database error" in response + assert "supersecret" not in response + assert "internal-db" not in response + + +def test_main_cli_flags_set_write_mode_env(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, str] = {} + keys = ("WRITE_MODE_ENABLED", "ALLOW_INSERT", "ALLOW_UPDATE", "ALLOW_DELETE") + previous = {key: os.environ.get(key) for key in keys} + + def _fake_run(*_: object, **__: object) -> None: + captured["WRITE_MODE_ENABLED"] = os.environ.get("WRITE_MODE_ENABLED", "") + captured["ALLOW_INSERT"] = os.environ.get("ALLOW_INSERT", "") + captured["ALLOW_UPDATE"] = os.environ.get("ALLOW_UPDATE", "") + captured["ALLOW_DELETE"] = os.environ.get("ALLOW_DELETE", "") + + monkeypatch.setattr(mcp_server.mcp, "run", _fake_run) + monkeypatch.setattr( + sys, + "argv", + ["secure-sql-mcp", "--write-mode", "--allow-insert", "--allow-update", "--allow-delete"], + ) + + try: + mcp_server.main() + assert captured == { + "WRITE_MODE_ENABLED": "true", + "ALLOW_INSERT": "true", + "ALLOW_UPDATE": "true", + "ALLOW_DELETE": "true", + } + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/tests/test_mcp_stdio_security.py b/tests/test_mcp_stdio_security.py index 22a3865..f38d54b 100644 --- a/tests/test_mcp_stdio_security.py +++ b/tests/test_mcp_stdio_security.py @@ -20,7 +20,7 @@ def _first_text(call_result: object) -> str: return "" -def _server_params(tmp_path: Path) -> StdioServerParameters: +def _server_params(tmp_path: Path, *, write_mode: bool = False) -> StdioServerParameters: db_path = tmp_path / "test.db" policy_path = tmp_path / "allowed_policy.txt" init_sqlite_db(db_path) @@ -36,6 +36,15 @@ def _server_params(tmp_path: Path) -> StdioServerParameters: "LOG_LEVEL": "INFO", } ) + if write_mode: + env.update( + { + "WRITE_MODE_ENABLED": "true", + "ALLOW_INSERT": "true", + "ALLOW_UPDATE": "true", + "ALLOW_DELETE": "true", + } + ) return StdioServerParameters( command=sys.executable, @@ -98,3 +107,50 @@ async def _run() -> None: assert "Only a single SQL statement is allowed" in message asyncio.run(_run()) + + +def test_mcp_stdio_blocks_insert_select_when_write_disabled(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", + {"sql": "INSERT INTO orders (id, total) SELECT id, id FROM orders"}, + ) + message = _first_text(result) + assert "read-only access" in message + assert "INSERT" in message + + asyncio.run(_run()) + + +def test_mcp_stdio_write_mode_allows_insert(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path, write_mode=True)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + payload = json.loads(_first_text(result)) + assert payload["status"] == "ok" + assert payload["operation"] == "insert" + assert payload["affected_rows"] == 1 + + asyncio.run(_run()) + + +def test_mcp_stdio_write_mode_blocks_tautological_delete(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path, write_mode=True)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", {"sql": "DELETE FROM customers WHERE 1 = 1"} + ) + message = _first_text(result) + assert "WHERE clause appears tautological" in message + + asyncio.run(_run()) diff --git a/tests/test_opa_policy.py b/tests/test_opa_policy.py index 0280316..22537c7 100644 --- a/tests/test_opa_policy.py +++ b/tests/test_opa_policy.py @@ -89,7 +89,102 @@ def test_validator_builds_policy_input_for_opa(tmp_path: Path) -> None: assert result.ok assert capture_engine.last_payload is not None - payload = capture_engine.last_payload + payload = cast(dict[str, Any], capture_engine.last_payload) assert payload["tool"] == {"name": "query"} assert payload["query"]["statement_count"] == 1 assert sorted(payload["query"]["referenced_tables"]) == ["customers", "orders"] + assert payload["config"]["write_mode_enabled"] is False + assert payload["config"]["allow_returning"] is False + assert payload["config"]["require_where_for_update"] is True + assert payload["config"]["require_where_for_delete"] is True + assert payload["query"]["is_write_statement"] is False + + +def test_validator_builds_write_policy_input_for_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "OPA_FAIL_CLOSED": True, + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + + result = validator.validate_query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert payload["query"]["statement_type"] == "update" + assert payload["query"]["target_table"] == "customers" + assert payload["query"]["updated_columns"] == ["email"] + assert payload["query"]["where_present"] is True + assert payload["config"]["write_mode_enabled"] is True + assert payload["config"]["allow_update"] is True + + +def test_validator_marks_insert_select_as_write_for_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT id, id FROM customers") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert payload["query"]["statement_type"] == "insert" + + +def test_validator_includes_star_tables_for_insert_select_star_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT * FROM customers") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert "customers" in payload["query"]["star_tables"] + + +def test_opa_engine_maps_write_deny_reason_to_message(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _ok(*_: object, **__: object): + return _FakeResponse( + {"result": {"allow": False, "deny_reasons": ["missing_where_on_update"]}} + ) + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _ok) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert decision.deny_reasons == ["missing_where_on_update"] + assert decision.message == "UPDATE without a WHERE clause is not allowed." diff --git a/tests/test_query_validator_security.py b/tests/test_query_validator_security.py index db734af..144790c 100644 --- a/tests/test_query_validator_security.py +++ b/tests/test_query_validator_security.py @@ -170,3 +170,165 @@ def test_validator_uses_mysql_dialect_for_mysql_url(tmp_path: Path) -> None: validator = QueryValidator(settings) assert validator._dialect == "mysql" + + +def test_validator_allows_insert_when_write_mode_enabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert result.ok + assert result.statement_type == "insert" + + +def test_validator_blocks_update_without_where_in_write_mode(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("UPDATE customers SET email = 'x@example.com'") + assert not result.ok + assert "UPDATE without a WHERE clause is not allowed" in (result.error or "") + + +def test_validator_blocks_tautological_where_in_write_mode(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_DELETE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("DELETE FROM customers WHERE 1 = 1") + assert not result.ok + assert "WHERE clause appears tautological" in (result.error or "") + + +def test_validator_blocks_insert_select_when_write_mode_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": False, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT id, id FROM customers") + assert not result.ok + assert "configured for read-only access" in (result.error or "") + + +def test_validator_blocks_update_with_subquery_when_update_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": False, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE customers SET email = (SELECT 'x@example.com') WHERE id = 1" + ) + assert not result.ok + assert "UPDATE operations are disabled" in (result.error or "") + + +def test_validator_blocks_delete_with_subquery_when_delete_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_DELETE": False, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "DELETE FROM customers WHERE id IN (SELECT id FROM customers)" + ) + assert not result.ok + assert "DELETE operations are disabled" in (result.error or "") + + +def test_validator_blocks_returning_by_default(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email" + ) + assert not result.ok + assert "RETURNING is not allowed" in (result.error or "") + + +def test_validator_blocks_insert_select_star_from_non_wildcard_source(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT * FROM customers") + assert not result.ok + assert "SELECT * is not allowed for table 'customers'" in (result.error or "") + + +def test_validator_accepts_qualified_target_table_with_short_policy_name(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE main.customers SET email = 'x@example.com' WHERE id = 1" + ) + assert result.ok diff --git a/tests/test_write_facts.py b/tests/test_write_facts.py new file mode 100644 index 0000000..91a81dd --- /dev/null +++ b/tests/test_write_facts.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from pathlib import Path + +import sqlglot + +from secure_sql_mcp.config import Settings +from secure_sql_mcp.query_validator import QueryValidator +from tests.conftest import write_policy + + +def _validator(tmp_path: Path, **overrides: object) -> QueryValidator: + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + payload: dict[str, object] = { + "DATABASE_URL": "sqlite+aiosqlite:///./write-facts.db", + "ALLOWED_POLICY_FILE": str(policy_path), + } + payload.update(overrides) + settings = Settings.model_validate(payload) + return QueryValidator(settings) + + +def test_extract_insert_write_facts(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True) + statement = sqlglot.parse_one( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "insert" + assert facts.target_table == "customers" + assert facts.insert_columns == ["email", "id"] + assert facts.source_tables == [] + + +def test_extract_update_write_facts_with_tautological_where(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_UPDATE=True) + statement = sqlglot.parse_one( + "UPDATE customers SET email = 'x@example.com' WHERE 1 = 1", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "update" + assert facts.where_present is True + assert facts.where_tautological is True + assert facts.updated_columns == ["email"] + + +def test_extract_delete_write_facts(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_DELETE=True) + statement = sqlglot.parse_one( + "DELETE FROM customers WHERE id = 1", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "delete" + assert facts.where_present is True + assert facts.where_tautological is False + + +def test_extract_returning_columns(tmp_path: Path) -> None: + validator = _validator( + tmp_path, + WRITE_MODE_ENABLED=True, + ALLOW_UPDATE=True, + ALLOW_RETURNING=True, + ) + statement = sqlglot.parse_one( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.returning_present is True + assert facts.returning_columns == ["email"] + + +def test_extract_insert_select_source_tables(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True) + statement = sqlglot.parse_one( + "INSERT INTO orders (id, total) SELECT id, total FROM orders", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.has_select_source is True + assert facts.source_tables == ["orders"] + + +def test_write_mode_disabled_blocks_writes(tmp_path: Path) -> None: + validator = _validator(tmp_path) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert not result.ok + assert "configured for read-only access" in (result.error or "") + + +def test_allow_insert_flag_controls_insert(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=False) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert not result.ok + assert "INSERT operations are disabled by server configuration" in (result.error or "") + + +def test_enable_insert_does_not_enable_update(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True, ALLOW_UPDATE=False) + result = validator.validate_query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + assert not result.ok + assert "UPDATE operations are disabled by server configuration" in (result.error or "") + + +def test_ddl_still_blocked_when_write_mode_enabled(tmp_path: Path) -> None: + validator = _validator( + tmp_path, + WRITE_MODE_ENABLED=True, + ALLOW_INSERT=True, + ALLOW_UPDATE=True, + ALLOW_DELETE=True, + ) + result = validator.validate_query("DROP TABLE customers") + assert not result.ok + assert "read-only access" in (result.error or "") + assert "DROP" in (result.error or "")