diff --git a/.gitignore b/.gitignore index 0728338..37bd76f 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,7 @@ target # Added by cargo /target + +# PGLite runtime temporary files +# TODO: Remove once pglite_oxide uses OS tempdir (https://github.com/pglite-dev/pglite-oxide/issues/XXX) +tmp/ diff --git a/src/cli/mcp.rs b/src/cli/mcp.rs new file mode 100644 index 0000000..6edcbfb --- /dev/null +++ b/src/cli/mcp.rs @@ -0,0 +1,95 @@ +//! MCP server command implementation. +//! +//! This module provides the `tern mcp` command that starts an MCP server +//! for AI-assisted migration authoring. + +use std::path::PathBuf; + +use clap::Parser; +use miette::IntoDiagnostic; + +use crate::cli::{ensure_backend_initialized, load_backend}; +use crate::db::state::StateBackend; +use crate::mcp::McpServer; + +/// Start the MCP server for AI-assisted migration authoring. +/// +/// The MCP (Model Context Protocol) server enables AI assistants to interact +/// with Tern's migration system. It uses an in-memory PostgreSQL database +/// (via PGLite) to validate schema changes before generating migrations. +/// +/// # Communication +/// +/// The server communicates over stdio using newline-delimited JSON-RPC 2.0 +/// messages. This is compatible with AI assistants that support the MCP +/// protocol, such as Claude Desktop. +/// +/// # Example Configuration +/// +/// Add to your MCP configuration: +/// +/// ```json +/// { +/// "mcpServers": { +/// "tern": { +/// "command": "tern", +/// "args": ["mcp"], +/// "cwd": "/path/to/project" +/// } +/// } +/// } +/// ``` +/// +/// # Available Tools +/// +/// - `start_session`: Begin a new migration authoring session +/// - `execute_sql`: Execute SQL in the session's in-memory database +/// - `apply_operation`: Apply a structured schema operation +/// - `get_session_schema`: View the current schema state +/// - `get_session_diff`: View pending changes compared to baseline +/// - `generate_migration`: Create a migration from session changes +/// - `cancel_session`: Discard session without creating a migration +/// - `list_sessions`: List all active sessions +/// +/// # Available Resources +/// +/// - `tern://schema`: Current database schema +/// - `tern://migrations`: List of all migrations +/// - `tern://migration/{id}`: Details of a specific migration +#[derive(Debug, Parser, Clone)] +pub struct Mcp { + /// Path to the state directory + /// + /// If not specified, uses `.tern/` in the current directory. + #[arg(long)] + state_path: Option, +} + +impl Mcp { + /// Runs the MCP server. + pub async fn dispatch(self) -> miette::Result<()> { + let backend = load_backend(self.state_path.as_deref()); + + // Ensure backend is initialized + ensure_backend_initialized(&backend).await?; + + // Load current state + let current_state = backend.get_current_state().await.into_diagnostic()?; + + // Create and run the server + let server = McpServer::new(backend, current_state); + server.run_stdio().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mcp_command_parses() { + // Just verify the command can be parsed + use clap::CommandFactory; + let _ = Mcp::command(); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 8626dcf..ca54f0e 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -13,6 +13,7 @@ pub mod history; pub mod import; pub mod init; pub mod inspect; +pub mod mcp; pub mod migrate; pub mod print_migrations; pub mod record; @@ -213,6 +214,13 @@ pub enum CliCommand { /// migration workflow. Export(export::Export), + /// Start the MCP server for AI-assisted migration authoring + /// + /// The MCP (Model Context Protocol) server enables AI assistants to + /// interact with Tern's migration system. It provides tools for + /// creating and validating migrations through a structured protocol. + Mcp(mcp::Mcp), + /// [DEPRECATED] Schema management commands (use 'tern export' instead) /// /// Commands for working with the schema DDL file, which forms the @@ -320,6 +328,7 @@ impl CliCommand { CliCommand::Verify(args) => args.dispatch().await, CliCommand::VerifyChain(args) => args.dispatch().await, CliCommand::Export(args) => args.dispatch().await, + CliCommand::Mcp(args) => args.dispatch().await, CliCommand::Schema(action) => action.dispatch().await, CliCommand::PrintMigrations(args) => args.dispatch().await, CliCommand::Compile(args) => args.dispatch().await, diff --git a/src/lib.rs b/src/lib.rs index e41703b..9d65603 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod cli; pub mod db; +pub mod mcp; pub mod util; // ============================================================================= diff --git a/src/mcp/error.rs b/src/mcp/error.rs new file mode 100644 index 0000000..cad7554 --- /dev/null +++ b/src/mcp/error.rs @@ -0,0 +1,350 @@ +//! Error types for the MCP server. +//! +//! This module defines error types specific to the MCP (Model Context Protocol) +//! server implementation, including protocol errors, session errors, and +//! operation errors. + +// Suppress warnings about enum variant fields that are used by thiserror's Display impl. +#![allow(unused_assignments)] + +use std::fmt; + +/// Error returned by MCP server operations. +#[derive(Debug, thiserror::Error, miette::Diagnostic)] +pub enum McpError { + // ========================================================================= + // Protocol Errors (JSON-RPC standard codes) + // ========================================================================= + /// Parse error: Invalid JSON was received. + #[error("parse error: {0}")] + #[diagnostic(code(mcp::parse_error))] + ParseError(String), + + /// Invalid request: The JSON is not a valid request object. + #[error("invalid request: {0}")] + #[diagnostic(code(mcp::invalid_request))] + InvalidRequest(String), + + /// Method not found: The requested method does not exist. + #[error("method not found: {0}")] + #[diagnostic(code(mcp::method_not_found))] + MethodNotFound(String), + + /// Invalid params: The method parameters are invalid. + #[error("invalid params: {0}")] + #[diagnostic(code(mcp::invalid_params))] + InvalidParams(String), + + /// Internal error: An internal server error occurred. + #[error("internal error: {0}")] + #[diagnostic(code(mcp::internal_error))] + InternalError(String), + + // ========================================================================= + // Session Errors + // ========================================================================= + /// Session not found: The specified session does not exist. + #[error("session not found: {0}")] + #[diagnostic( + code(mcp::session_not_found), + help("Use list_sessions to see active sessions, or start_session to create a new one.") + )] + SessionNotFound(SessionId), + + /// No active session: A session ID is required but none was provided. + #[error("no active session; call start_session first")] + #[diagnostic(code(mcp::no_active_session))] + NoActiveSession, + + /// Session limit reached: Too many concurrent sessions are active. + #[error("session limit reached; cancel an existing session first")] + #[diagnostic( + code(mcp::session_limit_reached), + help("Use cancel_session to close an existing session before starting a new one.") + )] + SessionLimitReached { + /// Maximum number of sessions allowed. + max_sessions: usize, + /// Currently active session count. + current_sessions: usize, + }, + + // ========================================================================= + // SQL Execution Errors + // ========================================================================= + /// SQL execution failed: The SQL statement could not be executed. + #[error("SQL execution failed: {message}")] + #[diagnostic(code(mcp::sql_execution_failed))] + SqlExecutionFailed { + /// Error message from PostgreSQL. + message: String, + /// PostgreSQL error code, if available. + code: Option, + /// Detailed error information, if available. + detail: Option, + /// Hint for fixing the error, if available. + hint: Option, + /// Position in the SQL where the error occurred, if available. + position: Option, + }, + + // ========================================================================= + // Migration Generation Errors + // ========================================================================= + /// No changes: The session has no schema changes to migrate. + #[error("no changes in session")] + #[diagnostic( + code(mcp::no_changes), + help( + "Make schema changes using execute_sql or apply_operation before generating a migration." + ) + )] + NoChanges, + + /// Breaking changes detected: The migration contains breaking changes. + #[error("breaking changes detected; use force=true to proceed")] + #[diagnostic( + code(mcp::breaking_changes), + help("Review the breaking changes and set force=true if you want to proceed.") + )] + BreakingChangesDetected { + /// List of breaking changes detected. + breaking_changes: Vec, + }, + + /// Save failed: The migration could not be saved to disk. + #[error("failed to save migration: {0}")] + #[diagnostic(code(mcp::save_failed))] + SaveFailed(String), + + // ========================================================================= + // Backend Errors + // ========================================================================= + /// Backend not initialized: The state backend is not initialized. + #[error("backend not initialized: {0}")] + #[diagnostic( + code(mcp::backend_not_initialized), + help("Run 'tern init' to initialize the project first.") + )] + BackendNotInitialized(String), + + /// Backend error: A state backend operation failed. + #[error("backend error: {0}")] + #[diagnostic(code(mcp::backend_error))] + BackendError(String), + + /// Migration not found: The specified migration does not exist. + #[error("migration not found: {0}")] + #[diagnostic(code(mcp::migration_not_found))] + MigrationNotFound(String), + + // ========================================================================= + // PGLite Errors + // ========================================================================= + /// PGLite error: An error occurred with the embedded PostgreSQL. + #[error("PGLite error: {0}")] + #[diagnostic(code(mcp::pglite_error))] + PgLiteError(String), + + // ========================================================================= + // Resource Errors + // ========================================================================= + /// Resource not found: The requested resource does not exist. + #[error("resource not found: {0}")] + #[diagnostic(code(mcp::resource_not_found))] + ResourceNotFound(String), + + /// Invalid resource URI: The resource URI is malformed. + #[error("invalid resource URI: {0}")] + #[diagnostic(code(mcp::invalid_resource_uri))] + InvalidResourceUri(String), + + // ========================================================================= + // Transport Errors + // ========================================================================= + /// I/O error: An I/O operation failed. + #[error("I/O error: {0}")] + #[diagnostic(code(mcp::io_error))] + IoError(#[from] std::io::Error), +} + +impl McpError { + /// Returns the JSON-RPC error code for this error. + /// + /// Standard JSON-RPC error codes: + /// - -32700: Parse error + /// - -32600: Invalid request + /// - -32601: Method not found + /// - -32602: Invalid params + /// - -32603: Internal error + /// + /// Application-specific error codes (starting at -32100): + /// - -32100: Session not found + /// - -32101: No active session + /// - -32102: Session limit reached + /// - -32103: SQL execution failed + /// - -32104: No changes + /// - -32105: Breaking changes detected + /// - -32106: Save failed + /// - -32107: Backend not initialized + /// - -32108: PGLite error + /// - -32109: Resource not found + /// - -32110: Invalid resource URI + /// - -32111: Backend error + /// - -32112: Migration not found + #[must_use] + pub fn error_code(&self) -> i32 { + match self { + Self::ParseError(_) => -32700, + Self::InvalidRequest(_) => -32600, + Self::MethodNotFound(_) => -32601, + Self::InvalidParams(_) => -32602, + Self::InternalError(_) => -32603, + Self::SessionNotFound(_) => -32100, + Self::NoActiveSession => -32101, + Self::SessionLimitReached { .. } => -32102, + Self::SqlExecutionFailed { .. } => -32103, + Self::NoChanges => -32104, + Self::BreakingChangesDetected { .. } => -32105, + Self::SaveFailed(_) => -32106, + Self::BackendNotInitialized(_) => -32107, + Self::PgLiteError(_) => -32108, + Self::ResourceNotFound(_) => -32109, + Self::InvalidResourceUri(_) => -32110, + Self::BackendError(_) => -32111, + Self::MigrationNotFound(_) => -32112, + Self::IoError(_) => -32603, // Map to internal error + } + } +} + +/// Unique identifier for a migration authoring session. +/// +/// Session IDs are prefixed with "sess_" followed by 8 alphanumeric characters. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SessionId(String); + +impl SessionId { + /// Creates a new session ID from a string. + /// + /// # Panics + /// + /// Panics if the string is empty. + #[must_use] + pub fn new(id: impl Into) -> Self { + let id = id.into(); + assert!(!id.is_empty(), "session ID must not be empty"); + Self(id) + } + + /// Generates a new random session ID. + /// + /// The format is "sess_" followed by 12 alphanumeric characters. + #[must_use] + pub fn generate() -> Self { + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::{SystemTime, UNIX_EPOCH}; + + // Static counter to ensure uniqueness even when called in quick succession + static COUNTER: AtomicU64 = AtomicU64::new(0); + + // Simple random ID generation without external dependencies + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + + // Combine timestamp with an incrementing counter for uniqueness + // Using SeqCst ordering to ensure strict ordering across threads + let counter = COUNTER.fetch_add(1, Ordering::SeqCst); + + // Include thread ID for additional uniqueness in multi-threaded contexts + let thread_id = std::thread::current().id(); + let thread_hash = format!("{:?}", thread_id).bytes().fold(0u64, |acc, b| { + acc.wrapping_mul(31).wrapping_add(u64::from(b)) + }); + + let combined = (timestamp as u64) + .wrapping_add(counter) + .wrapping_add(thread_hash); + let random_part = format!("{:012x}", combined ^ 0xDEAD_BEEF_CAFE_BABE); + Self(format!("sess_{}", &random_part[..12])) + } + + /// Returns the session ID as a string slice. + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for SessionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl AsRef for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl From for SessionId { + fn from(s: String) -> Self { + Self::new(s) + } +} + +impl From<&str> for SessionId { + fn from(s: &str) -> Self { + Self::new(s.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn session_id_generate() { + let id1 = SessionId::generate(); + let id2 = SessionId::generate(); + + assert!(id1.as_str().starts_with("sess_")); + assert!(id2.as_str().starts_with("sess_")); + // Generated IDs should be different (in most cases) + // Note: This could theoretically fail if generated in the same nanosecond + } + + #[test] + fn session_id_from_string() { + let id = SessionId::from("sess_abc12345".to_string()); + assert_eq!(id.as_str(), "sess_abc12345"); + } + + #[test] + fn session_id_display() { + let id = SessionId::new("sess_test1234"); + assert_eq!(format!("{}", id), "sess_test1234"); + } + + #[test] + fn error_codes_are_correct() { + assert_eq!(McpError::ParseError("test".into()).error_code(), -32700); + assert_eq!(McpError::InvalidRequest("test".into()).error_code(), -32600); + assert_eq!(McpError::MethodNotFound("test".into()).error_code(), -32601); + assert_eq!(McpError::InvalidParams("test".into()).error_code(), -32602); + assert_eq!(McpError::InternalError("test".into()).error_code(), -32603); + assert_eq!( + McpError::SessionNotFound(SessionId::new("sess_test")).error_code(), + -32100 + ); + } + + #[test] + #[should_panic(expected = "session ID must not be empty")] + fn session_id_empty_panics() { + let _id = SessionId::new(""); + } +} diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs new file mode 100644 index 0000000..77198a0 --- /dev/null +++ b/src/mcp/mod.rs @@ -0,0 +1,50 @@ +//! MCP (Model Context Protocol) server for AI-assisted migration authoring. +//! +//! This module implements an MCP server that enables AI assistants to interact +//! with Tern's migration system programmatically. The server runs locally and +//! uses an in-memory PostgreSQL database (via PGLite) for schema validation. +//! +//! # Architecture +//! +//! The MCP server consists of several components: +//! +//! - **Transport**: Handles JSON-RPC communication over stdio +//! - **Protocol**: Defines MCP message types and request/response handling +//! - **Resources**: Read-only data exposed to clients (schema, migrations) +//! - **Tools**: Actions that modify state (session management, SQL execution) +//! - **Sessions**: Manages migration authoring sessions with isolated databases +//! +//! # Usage +//! +//! The server is started via the `tern mcp` command and communicates over stdio: +//! +//! ```text +//! $ tern mcp +//! ``` +//! +//! # Example Session +//! +//! 1. Client reads `tern://schema` resource to understand current schema +//! 2. Client calls `start_session` tool to begin a migration +//! 3. Client makes changes via `execute_sql` or `apply_operation` +//! 4. Client calls `get_session_diff` to review changes +//! 5. Client calls `generate_migration` to create the migration file +//! +//! # Feature Requirements +//! +//! This module requires the `pglite` feature for the embedded PostgreSQL +//! functionality used in session databases. + +// Allow dead code for work-in-progress items that will be used when PGLite integration is complete. +#![allow(dead_code)] + +mod error; +mod protocol; +mod resources; +mod server; +mod session; +mod tools; +mod transport; + +pub use error::{McpError, SessionId}; +pub use server::McpServer; diff --git a/src/mcp/protocol.rs b/src/mcp/protocol.rs new file mode 100644 index 0000000..441a96c --- /dev/null +++ b/src/mcp/protocol.rs @@ -0,0 +1,647 @@ +//! MCP protocol types and JSON-RPC message handling. +//! +//! This module defines the types used for MCP (Model Context Protocol) +//! communication, following the JSON-RPC 2.0 specification. +//! +//! # Protocol Overview +//! +//! MCP uses JSON-RPC 2.0 over stdio for communication. Messages are: +//! - **Requests**: Client asks server to do something +//! - **Responses**: Server replies to a request +//! - **Notifications**: One-way messages (no response expected) + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::mcp::error::McpError; + +/// JSON-RPC version string (always "2.0"). +pub const JSONRPC_VERSION: &str = "2.0"; + +/// MCP protocol version supported by this server. +pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05"; + +/// Server name for MCP initialize response. +pub const SERVER_NAME: &str = "tern"; + +/// Server version for MCP initialize response. +pub const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION"); + +// ============================================================================= +// JSON-RPC Message Types +// ============================================================================= + +/// A JSON-RPC request message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + /// JSON-RPC version (always "2.0"). + pub jsonrpc: String, + /// Request identifier (used to match responses). + pub id: RequestId, + /// Method name to invoke. + pub method: String, + /// Method parameters (optional). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// A JSON-RPC response message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + /// JSON-RPC version (always "2.0"). + pub jsonrpc: String, + /// Request identifier (matches the request). + pub id: RequestId, + /// Result on success. + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Error on failure. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl JsonRpcResponse { + /// Creates a successful response. + #[must_use] + pub fn success(id: RequestId, result: Value) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + id, + result: Some(result), + error: None, + } + } + + /// Creates an error response. + #[must_use] + pub fn error(id: RequestId, error: JsonRpcError) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + id, + result: None, + error: Some(error), + } + } + + /// Creates an error response from an `McpError`. + #[must_use] + pub fn from_mcp_error(id: RequestId, error: &McpError) -> Self { + Self::error(id, JsonRpcError::from_mcp_error(error)) + } +} + +/// A JSON-RPC notification message (no response expected). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcNotification { + /// JSON-RPC version (always "2.0"). + pub jsonrpc: String, + /// Method name. + pub method: String, + /// Method parameters (optional). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// A JSON-RPC error object. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcError { + /// Error code. + pub code: i32, + /// Human-readable error message. + pub message: String, + /// Additional error data (optional). + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl JsonRpcError { + /// Creates a new JSON-RPC error. + #[must_use] + pub fn new(code: i32, message: impl Into) -> Self { + Self { + code, + message: message.into(), + data: None, + } + } + + /// Creates a JSON-RPC error with additional data. + #[must_use] + pub fn with_data(code: i32, message: impl Into, data: Value) -> Self { + Self { + code, + message: message.into(), + data: Some(data), + } + } + + /// Creates a JSON-RPC error from an `McpError`. + #[must_use] + pub fn from_mcp_error(error: &McpError) -> Self { + let code = error.error_code(); + let message = error.to_string(); + + // Add additional data for certain error types + let data = match error { + McpError::SessionNotFound(session_id) => Some(serde_json::json!({ + "sessionId": session_id.as_str(), + })), + McpError::SessionLimitReached { + max_sessions, + current_sessions, + } => Some(serde_json::json!({ + "maxSessions": max_sessions, + "currentSessions": current_sessions, + })), + McpError::SqlExecutionFailed { + code, + detail, + hint, + position, + .. + } => Some(serde_json::json!({ + "code": code, + "detail": detail, + "hint": hint, + "position": position, + })), + McpError::BreakingChangesDetected { breaking_changes } => Some(serde_json::json!({ + "breakingChanges": breaking_changes, + })), + _ => None, + }; + + Self { + code, + message, + data, + } + } + + /// Creates a parse error. + #[must_use] + pub fn parse_error(message: impl Into) -> Self { + Self::new(-32700, message) + } + + /// Creates an invalid request error. + #[must_use] + pub fn invalid_request(message: impl Into) -> Self { + Self::new(-32600, message) + } + + /// Creates a method not found error. + #[must_use] + pub fn method_not_found(method: &str) -> Self { + Self::new(-32601, format!("method not found: {method}")) + } + + /// Creates an invalid params error. + #[must_use] + pub fn invalid_params(message: impl Into) -> Self { + Self::new(-32602, message) + } + + /// Creates an internal error. + #[must_use] + pub fn internal_error(message: impl Into) -> Self { + Self::new(-32603, message) + } +} + +/// A request identifier (can be string, number, or null). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +#[derive(Default)] +pub enum RequestId { + /// String identifier. + String(String), + /// Numeric identifier. + Number(i64), + /// Null identifier (for notifications converted to requests). + #[default] + Null, +} + +impl From for RequestId { + fn from(n: i64) -> Self { + Self::Number(n) + } +} + +impl From for RequestId { + fn from(s: String) -> Self { + Self::String(s) + } +} + +impl From<&str> for RequestId { + fn from(s: &str) -> Self { + Self::String(s.to_string()) + } +} + +// ============================================================================= +// MCP-Specific Types +// ============================================================================= + +/// MCP initialize request parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeParams { + /// Protocol version the client supports. + pub protocol_version: String, + /// Capabilities the client supports. + pub capabilities: ClientCapabilities, + /// Information about the client. + pub client_info: ClientInfo, +} + +/// MCP client capabilities. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClientCapabilities { + /// Whether the client supports roots. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub roots: Option, + /// Whether the client supports sampling. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub sampling: Option, +} + +/// Roots capability configuration. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RootsCapability { + /// Whether the client can list roots. + #[serde(default)] + pub list_changed: bool, +} + +/// Information about the client. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientInfo { + /// Client name. + pub name: String, + /// Client version. + pub version: String, +} + +/// MCP initialize response result. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResult { + /// Protocol version the server supports. + pub protocol_version: String, + /// Capabilities the server supports. + pub capabilities: ServerCapabilities, + /// Information about the server. + pub server_info: ServerInfo, +} + +impl Default for InitializeResult { + fn default() -> Self { + Self { + protocol_version: MCP_PROTOCOL_VERSION.to_string(), + capabilities: ServerCapabilities::default(), + server_info: ServerInfo { + name: SERVER_NAME.to_string(), + version: SERVER_VERSION.to_string(), + }, + } + } +} + +/// MCP server capabilities. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerCapabilities { + /// Resources capability. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub resources: Option, + /// Tools capability. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option, + /// Prompts capability. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prompts: Option, +} + +/// Resources capability configuration. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesCapability { + /// Whether resources are supported. + #[serde(default)] + pub subscribe: bool, + /// Whether resource listing is supported. + #[serde(default)] + pub list_changed: bool, +} + +/// Tools capability configuration. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolsCapability { + /// Whether tool listing changes are supported. + #[serde(default)] + pub list_changed: bool, +} + +/// Prompts capability configuration. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsCapability { + /// Whether prompt listing changes are supported. + #[serde(default)] + pub list_changed: bool, +} + +/// Information about the server. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerInfo { + /// Server name. + pub name: String, + /// Server version. + pub version: String, +} + +// ============================================================================= +// Resource Types +// ============================================================================= + +/// A resource descriptor. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Resource { + /// Resource URI. + pub uri: String, + /// Human-readable name. + pub name: String, + /// Description of the resource. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// MIME type of the resource content. + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +/// Resource content. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceContent { + /// Resource URI. + pub uri: String, + /// MIME type of the content. + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Text content (for text resources). + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + /// Binary content as base64 (for binary resources). + #[serde(skip_serializing_if = "Option::is_none")] + pub blob: Option, +} + +/// List resources request parameters. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListResourcesParams { + /// Cursor for pagination. + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor: Option, +} + +/// List resources response. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListResourcesResult { + /// Available resources. + pub resources: Vec, + /// Cursor for next page, if more results available. + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +/// Read resource request parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadResourceParams { + /// URI of the resource to read. + pub uri: String, +} + +/// Read resource response. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadResourceResult { + /// Resource contents. + pub contents: Vec, +} + +// ============================================================================= +// Tool Types +// ============================================================================= + +/// A tool descriptor. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Tool { + /// Tool name. + pub name: String, + /// Description of what the tool does. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// JSON Schema for the tool's input parameters. + pub input_schema: Value, +} + +/// List tools request parameters. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListToolsParams { + /// Cursor for pagination. + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor: Option, +} + +/// List tools response. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListToolsResult { + /// Available tools. + pub tools: Vec, + /// Cursor for next page, if more results available. + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +/// Call tool request parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolParams { + /// Name of the tool to call. + pub name: String, + /// Arguments to pass to the tool. + #[serde(default)] + pub arguments: Value, +} + +/// Call tool response. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolResult { + /// Tool output content. + pub content: Vec, + /// Whether the tool execution resulted in an error. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub is_error: bool, +} + +impl CallToolResult { + /// Creates a successful text result. + #[must_use] + pub fn text(text: impl Into) -> Self { + Self { + content: vec![ToolContent::Text { text: text.into() }], + is_error: false, + } + } + + /// Creates a successful JSON result. + #[must_use] + pub fn json(value: Value) -> Self { + Self { + content: vec![ToolContent::Text { + text: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()), + }], + is_error: false, + } + } + + /// Creates an error result. + #[must_use] + pub fn error(message: impl Into) -> Self { + Self { + content: vec![ToolContent::Text { + text: message.into(), + }], + is_error: true, + } + } +} + +/// Tool output content. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ToolContent { + /// Text content. + #[serde(rename = "text")] + Text { + /// The text content. + text: String, + }, + /// Image content (base64 encoded). + #[serde(rename = "image")] + Image { + /// Base64 encoded image data. + data: String, + /// MIME type of the image. + mime_type: String, + }, +} + +// ============================================================================= +// Ping +// ============================================================================= + +/// Ping response (empty object). +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct PingResult {} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_id_variants() { + let num_id = RequestId::from(42_i64); + let str_id = RequestId::from("test-id"); + let null_id = RequestId::default(); + + assert_eq!(num_id, RequestId::Number(42)); + assert_eq!(str_id, RequestId::String("test-id".to_string())); + assert_eq!(null_id, RequestId::Null); + } + + #[test] + fn json_rpc_response_success() { + let response = + JsonRpcResponse::success(RequestId::from(1_i64), serde_json::json!({"ok": true})); + + assert!(response.result.is_some()); + assert!(response.error.is_none()); + assert_eq!(response.jsonrpc, JSONRPC_VERSION); + } + + #[test] + fn json_rpc_response_error() { + let error = JsonRpcError::new(-32600, "invalid request"); + let response = JsonRpcResponse::error(RequestId::from(1_i64), error); + + assert!(response.result.is_none()); + assert!(response.error.is_some()); + assert_eq!(response.error.unwrap().code, -32600); + } + + #[test] + fn json_rpc_error_helpers() { + assert_eq!(JsonRpcError::parse_error("test").code, -32700); + assert_eq!(JsonRpcError::invalid_request("test").code, -32600); + assert_eq!(JsonRpcError::method_not_found("test").code, -32601); + assert_eq!(JsonRpcError::invalid_params("test").code, -32602); + assert_eq!(JsonRpcError::internal_error("test").code, -32603); + } + + #[test] + fn initialize_result_defaults() { + let result = InitializeResult::default(); + + assert_eq!(result.protocol_version, MCP_PROTOCOL_VERSION); + assert_eq!(result.server_info.name, SERVER_NAME); + assert_eq!(result.server_info.version, SERVER_VERSION); + } + + #[test] + fn call_tool_result_helpers() { + let text_result = CallToolResult::text("hello"); + assert!(!text_result.is_error); + assert_eq!(text_result.content.len(), 1); + + let json_result = CallToolResult::json(serde_json::json!({"key": "value"})); + assert!(!json_result.is_error); + + let error_result = CallToolResult::error("something went wrong"); + assert!(error_result.is_error); + } + + #[test] + fn request_serialization() { + let request = JsonRpcRequest { + jsonrpc: JSONRPC_VERSION.to_string(), + id: RequestId::from(1_i64), + method: "test/method".to_string(), + params: Some(serde_json::json!({"arg": "value"})), + }; + + let json = serde_json::to_string(&request).unwrap(); + let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.method, "test/method"); + assert_eq!(parsed.id, RequestId::Number(1)); + } +} diff --git a/src/mcp/resources.rs b/src/mcp/resources.rs new file mode 100644 index 0000000..58211b4 --- /dev/null +++ b/src/mcp/resources.rs @@ -0,0 +1,727 @@ +//! MCP resources for schema inspection. +//! +//! This module provides read-only resources that expose schema information +//! to MCP clients. Resources include: +//! +//! - `tern://schema` — Current schema state after all migrations +//! - `tern://migrations` — List of all migrations with metadata +//! - `tern://migration/{id}` — Details of a specific migration + +use serde::Serialize; +use serde_json::Value; + +use crate::db::model::Namespace; +use crate::db::state::{LocalFileBackend, StateBackend}; +use crate::db::state::{Migration, MigrationId, MigrationIndex, StateHash}; +use crate::mcp::error::McpError; +use crate::mcp::protocol::{Resource, ResourceContent}; + +/// URI scheme for tern resources. +pub const TERN_SCHEME: &str = "tern"; + +/// Resource URIs. +pub mod uris { + /// Schema resource URI. + pub const SCHEMA: &str = "tern://schema"; + /// Migrations list resource URI. + pub const MIGRATIONS: &str = "tern://migrations"; + /// Migration detail resource URI prefix. + pub const MIGRATION_PREFIX: &str = "tern://migration/"; +} + +/// Returns the list of available resources. +pub fn list_resources() -> Vec { + vec![ + Resource { + uri: uris::SCHEMA.to_string(), + name: "Current Schema".to_string(), + description: Some( + "The current database schema after all migrations have been applied.".to_string(), + ), + mime_type: Some("application/json".to_string()), + }, + Resource { + uri: uris::MIGRATIONS.to_string(), + name: "Migration History".to_string(), + description: Some("List of all migrations in order of application.".to_string()), + mime_type: Some("application/json".to_string()), + }, + ] +} + +/// Reads a resource by URI. +/// +/// # Errors +/// +/// Returns an error if the resource is not found or cannot be read. +pub async fn read_resource( + uri: &str, + backend: &LocalFileBackend, + current_state: &Namespace, +) -> Result { + if uri == uris::SCHEMA { + read_schema_resource(current_state) + } else if uri == uris::MIGRATIONS { + read_migrations_resource(backend).await + } else if let Some(id_str) = uri.strip_prefix(uris::MIGRATION_PREFIX) { + read_migration_resource(backend, id_str).await + } else { + Err(McpError::ResourceNotFound(uri.to_string())) + } +} + +/// Reads the schema resource. +fn read_schema_resource(namespace: &Namespace) -> Result { + let schema_output = SchemaResource::from(namespace); + let json = serde_json::to_string_pretty(&schema_output) + .map_err(|e| McpError::InternalError(format!("failed to serialize schema: {e}")))?; + + Ok(ResourceContent { + uri: uris::SCHEMA.to_string(), + mime_type: Some("application/json".to_string()), + text: Some(json), + blob: None, + }) +} + +/// Reads the migrations list resource. +async fn read_migrations_resource(backend: &LocalFileBackend) -> Result { + let index = backend + .get_migration_index() + .await + .map_err(|e| McpError::BackendError(e.to_string()))?; + + let migrations = backend + .get_all_migrations() + .await + .map_err(|e| McpError::BackendError(e.to_string()))?; + + let current_state_hash = backend + .get_current_state_hash() + .await + .map_err(|e| McpError::BackendError(e.to_string()))?; + + let output = MigrationsResource::from_migrations(&index, &migrations, current_state_hash); + let json = serde_json::to_string_pretty(&output) + .map_err(|e| McpError::InternalError(format!("failed to serialize migrations: {e}")))?; + + Ok(ResourceContent { + uri: uris::MIGRATIONS.to_string(), + mime_type: Some("application/json".to_string()), + text: Some(json), + blob: None, + }) +} + +/// Reads a specific migration resource. +async fn read_migration_resource( + backend: &LocalFileBackend, + id_str: &str, +) -> Result { + // Parse the migration ID + let migration_id = MigrationId::from_hex(id_str) + .or_else(|| MigrationId::from_hex_prefix(id_str)) + .ok_or_else(|| McpError::InvalidResourceUri(format!("invalid migration ID: {id_str}")))?; + + // Get the migration index to find the sequence number + let index = backend + .get_migration_index() + .await + .map_err(|e| McpError::BackendError(e.to_string()))?; + + let sequence_number = index + .position(&migration_id) + .map(|pos| pos + 1) + .ok_or_else(|| McpError::MigrationNotFound(id_str.to_string()))?; + + // Get the migration + let migration = backend + .get_migration(&migration_id) + .await + .map_err(|e| McpError::MigrationNotFound(e.to_string()))?; + + let output = MigrationDetailResource::from_migration(&migration, sequence_number); + let json = serde_json::to_string_pretty(&output) + .map_err(|e| McpError::InternalError(format!("failed to serialize migration: {e}")))?; + + let uri = format!("{}{}", uris::MIGRATION_PREFIX, id_str); + Ok(ResourceContent { + uri, + mime_type: Some("application/json".to_string()), + text: Some(json), + blob: None, + }) +} + +// ============================================================================= +// Resource Output Types +// ============================================================================= + +/// Output format for the schema resource. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SchemaResource { + /// Schema name. + pub name: String, + /// Tables in the schema. + pub tables: Vec, + /// Views in the schema. + pub views: Vec, + /// Sequences in the schema. + pub sequences: Vec, + /// Enum types in the schema. + pub enums: Vec, +} + +impl From<&Namespace> for SchemaResource { + fn from(ns: &Namespace) -> Self { + Self { + name: ns.name.as_ref().to_string(), + tables: ns.tables.iter().map(TableOutput::from).collect(), + views: ns.views.iter().map(ViewOutput::from).collect(), + sequences: ns.sequences.iter().map(SequenceOutput::from).collect(), + enums: ns.enums.iter().map(EnumOutput::from).collect(), + } + } +} + +/// Output format for a table. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TableOutput { + /// Table name. + pub name: String, + /// Columns in the table. + pub columns: Vec, + /// Primary key constraint, if any. + #[serde(skip_serializing_if = "Option::is_none")] + pub primary_key: Option, + /// Foreign key constraints. + pub foreign_keys: Vec, + /// Unique constraints. + pub unique_constraints: Vec, + /// Check constraints. + pub check_constraints: Vec, + /// Indexes on the table. + pub indexes: Vec, +} + +impl From<&crate::db::model::Table> for TableOutput { + fn from(table: &crate::db::model::Table) -> Self { + use crate::db::model::constraint::ConstraintKind; + + let primary_key = table.constraints.iter().find_map(|c| { + if let ConstraintKind::PrimaryKey(pk) = &c.kind { + Some(PrimaryKeyOutput { + name: c.name.as_ref().to_string(), + columns: pk + .columns + .iter() + .map(|col| col.as_ref().to_string()) + .collect(), + }) + } else { + None + } + }); + + let foreign_keys = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::ForeignKey(fk) = &c.kind { + Some(ForeignKeyOutput { + name: c.name.as_ref().to_string(), + columns: fk + .columns + .iter() + .map(|col| col.as_ref().to_string()) + .collect(), + references_table: fk.referenced_table.name.as_ref().to_string(), + references_columns: fk + .referenced_columns + .iter() + .map(|col| col.as_ref().to_string()) + .collect(), + on_delete: fk.on_delete.as_sql().to_string(), + on_update: fk.on_update.as_sql().to_string(), + }) + } else { + None + } + }) + .collect(); + + let unique_constraints = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::Unique(uq) = &c.kind { + Some(UniqueConstraintOutput { + name: c.name.as_ref().to_string(), + columns: uq + .columns + .iter() + .map(|col| col.as_ref().to_string()) + .collect(), + }) + } else { + None + } + }) + .collect(); + + let check_constraints = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::Check(chk) = &c.kind { + Some(CheckConstraintOutput { + name: c.name.as_ref().to_string(), + expression: chk.expression.as_ref().to_string(), + }) + } else { + None + } + }) + .collect(); + + Self { + name: table.name.as_ref().to_string(), + columns: table.columns.iter().map(ColumnOutput::from).collect(), + primary_key, + foreign_keys, + unique_constraints, + check_constraints, + indexes: table.indexes.iter().map(IndexOutput::from).collect(), + } + } +} + +/// Output format for a column. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ColumnOutput { + /// Column name. + pub name: String, + /// PostgreSQL data type. + #[serde(rename = "type")] + pub data_type: String, + /// Whether the column allows NULL values. + pub nullable: bool, + /// Default value expression, if any. + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + /// Identity column information, if any. + #[serde(skip_serializing_if = "Option::is_none")] + pub identity: Option, +} + +impl From<&crate::db::model::Column> for ColumnOutput { + fn from(col: &crate::db::model::Column) -> Self { + Self { + name: col.name.as_ref().to_string(), + data_type: col.type_info.formatted.clone(), + nullable: col.is_nullable, + default: col.default.as_ref().map(|d| d.as_ref().to_string()), + identity: col.identity.map(|id| IdentityOutput { + kind: format!("{:?}", id), + }), + } + } +} + +/// Output format for identity column info. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct IdentityOutput { + /// Identity kind (ALWAYS or BY DEFAULT). + #[serde(rename = "type")] + pub kind: String, +} + +/// Output format for a primary key. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PrimaryKeyOutput { + /// Constraint name. + pub name: String, + /// Column names. + pub columns: Vec, +} + +/// Output format for a foreign key. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ForeignKeyOutput { + /// Constraint name. + pub name: String, + /// Columns in the referencing table. + pub columns: Vec, + /// Referenced table name. + pub references_table: String, + /// Referenced column names. + pub references_columns: Vec, + /// ON DELETE action. + pub on_delete: String, + /// ON UPDATE action. + pub on_update: String, +} + +/// Output format for a unique constraint. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct UniqueConstraintOutput { + /// Constraint name. + pub name: String, + /// Column names. + pub columns: Vec, +} + +/// Output format for a check constraint. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CheckConstraintOutput { + /// Constraint name. + pub name: String, + /// Check expression. + pub expression: String, +} + +/// Output format for an index. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct IndexOutput { + /// Index name. + pub name: String, + /// Column names (with sort order if specified). + pub columns: Vec, + /// Whether the index is unique. + pub unique: bool, + /// Index method (btree, hash, gin, etc.). + pub method: String, + /// Partial index predicate, if any. + #[serde(skip_serializing_if = "Option::is_none")] + pub predicate: Option, +} + +impl From<&crate::db::model::Index> for IndexOutput { + fn from(idx: &crate::db::model::Index) -> Self { + use crate::db::model::index::SortOrder; + + let columns = idx + .columns + .iter() + .map(|col| { + // Use column name if available, otherwise use expression + let base_name = col + .column + .as_ref() + .map(|c| c.as_ref().to_string()) + .or_else(|| col.expression.as_ref().map(|e| e.as_ref().to_string())) + .unwrap_or_default(); + + // Append sort order if not ascending (the default) + match col.order { + SortOrder::Descending => format!("{} DESC", base_name), + SortOrder::Ascending => base_name, + } + }) + .collect(); + + Self { + name: idx.name.as_ref().to_string(), + columns, + unique: idx.is_unique, + method: idx.method.as_str().to_string(), + predicate: idx.predicate.as_ref().map(|p| p.as_ref().to_string()), + } + } +} + +/// Output format for a view. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ViewOutput { + /// View name. + pub name: String, + /// View definition (SELECT statement). + pub definition: String, + /// Whether this is a materialized view. + pub is_materialized: bool, +} + +impl From<&crate::db::model::View> for ViewOutput { + fn from(view: &crate::db::model::View) -> Self { + Self { + name: view.name.as_ref().to_string(), + definition: view.definition.as_ref().to_string(), + is_materialized: view.is_materialized, + } + } +} + +/// Output format for a sequence. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SequenceOutput { + /// Sequence name. + pub name: String, + /// Data type. + pub data_type: String, + /// Start value. + pub start: i64, + /// Increment. + pub increment: i64, + /// Minimum value. + pub min_value: i64, + /// Maximum value. + pub max_value: i64, + /// Cache size. + pub cache: i64, + /// Whether the sequence cycles. + pub cycle: bool, +} + +impl From<&crate::db::model::Sequence> for SequenceOutput { + fn from(seq: &crate::db::model::Sequence) -> Self { + Self { + name: seq.name.as_ref().to_string(), + data_type: seq.data_type.formatted.clone(), + start: seq.start_value, + increment: seq.increment, + min_value: seq.min_value, + max_value: seq.max_value, + cache: seq.cache_size, + cycle: seq.is_cyclic, + } + } +} + +/// Output format for an enum type. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct EnumOutput { + /// Enum type name. + pub name: String, + /// Enum values in order. + pub values: Vec, +} + +impl From<&crate::db::model::EnumType> for EnumOutput { + fn from(enum_type: &crate::db::model::EnumType) -> Self { + Self { + name: enum_type.name.as_ref().to_string(), + values: enum_type.values.clone(), + } + } +} + +// ============================================================================= +// Migrations Resource Output +// ============================================================================= + +/// Output format for the migrations list resource. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct MigrationsResource { + /// List of migrations. + pub migrations: Vec, + /// Total number of migrations. + pub total_count: usize, + /// Current state hash. + pub current_state_hash: String, +} + +impl MigrationsResource { + /// Creates a migrations resource from the index and migrations. + pub fn from_migrations( + index: &MigrationIndex, + migrations: &[Migration], + current_state_hash: StateHash, + ) -> Self { + let summaries = migrations + .iter() + .enumerate() + .map(|(i, m)| MigrationSummary { + id: m.id.to_hex(), + sequence_number: i + 1, + description: m.description.clone(), + created_at: m.created_at.to_string(), + operation_count: m.up_operations.len(), + has_breaking_changes: m.has_breaking_changes(), + parent_state_hash: m.parent_state_hash.to_hex(), + resulting_state_hash: m.resulting_state_hash.to_hex(), + }) + .collect(); + + Self { + migrations: summaries, + total_count: index.len(), + current_state_hash: current_state_hash.to_hex(), + } + } +} + +/// Summary of a single migration for listing. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct MigrationSummary { + /// Migration ID. + pub id: String, + /// Sequence number (1-indexed). + pub sequence_number: usize, + /// Description. + pub description: String, + /// Creation timestamp. + pub created_at: String, + /// Number of up operations. + pub operation_count: usize, + /// Whether this migration has breaking changes. + pub has_breaking_changes: bool, + /// Parent state hash. + pub parent_state_hash: String, + /// Resulting state hash. + pub resulting_state_hash: String, +} + +// ============================================================================= +// Migration Detail Resource Output +// ============================================================================= + +/// Output format for a specific migration resource. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct MigrationDetailResource { + /// Migration ID. + pub id: String, + /// Sequence number (1-indexed). + pub sequence_number: usize, + /// Description. + pub description: String, + /// Creation timestamp. + pub created_at: String, + /// Parent state hash. + pub parent_state_hash: String, + /// Resulting state hash. + pub resulting_state_hash: String, + /// Up operations. + pub up_operations: Vec, + /// Down operations. + pub down_operations: Vec, + /// Breaking changes. + pub breaking_changes: Vec, + /// Whether this migration is reversible. + pub is_reversible: bool, +} + +impl MigrationDetailResource { + /// Creates a migration detail from a migration. + pub fn from_migration(migration: &Migration, sequence_number: usize) -> Self { + Self { + id: migration.id.to_hex(), + sequence_number, + description: migration.description.clone(), + created_at: migration.created_at.to_string(), + parent_state_hash: migration.parent_state_hash.to_hex(), + resulting_state_hash: migration.resulting_state_hash.to_hex(), + up_operations: migration + .up_operations + .iter() + .map(OperationOutput::from) + .collect(), + down_operations: migration + .down_operations + .iter() + .map(OperationOutput::from) + .collect(), + breaking_changes: migration + .breaking_changes + .iter() + .map(|bc| bc.description.clone()) + .collect(), + is_reversible: migration.is_reversible(), + } + } +} + +/// Output format for an operation. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct OperationOutput { + /// Operation type. + #[serde(rename = "type")] + pub op_type: String, + /// Operation description. + pub description: String, + /// Full operation data (JSON). + pub data: Value, +} + +impl From<&crate::db::migrate::Operation> for OperationOutput { + fn from(op: &crate::db::migrate::Operation) -> Self { + let op_type = match op { + crate::db::migrate::Operation::CreateTable { .. } => "CreateTable", + crate::db::migrate::Operation::DropTable { .. } => "DropTable", + crate::db::migrate::Operation::RenameTable { .. } => "RenameTable", + crate::db::migrate::Operation::AddColumn { .. } => "AddColumn", + crate::db::migrate::Operation::DropColumn { .. } => "DropColumn", + crate::db::migrate::Operation::RenameColumn { .. } => "RenameColumn", + crate::db::migrate::Operation::AlterColumn { .. } => "AlterColumn", + crate::db::migrate::Operation::AddConstraint { .. } => "AddConstraint", + crate::db::migrate::Operation::DropConstraint { .. } => "DropConstraint", + crate::db::migrate::Operation::RenameConstraint { .. } => "RenameConstraint", + crate::db::migrate::Operation::CreateIndex { .. } => "CreateIndex", + crate::db::migrate::Operation::DropIndex { .. } => "DropIndex", + crate::db::migrate::Operation::RenameIndex { .. } => "RenameIndex", + crate::db::migrate::Operation::CreateEnum { .. } => "CreateEnum", + crate::db::migrate::Operation::DropEnum { .. } => "DropEnum", + crate::db::migrate::Operation::RenameEnum { .. } => "RenameEnum", + crate::db::migrate::Operation::AddEnumValue { .. } => "AddEnumValue", + crate::db::migrate::Operation::CreateSequence { .. } => "CreateSequence", + crate::db::migrate::Operation::DropSequence { .. } => "DropSequence", + crate::db::migrate::Operation::RenameSequence { .. } => "RenameSequence", + crate::db::migrate::Operation::AlterSequence { .. } => "AlterSequence", + crate::db::migrate::Operation::CreateView { .. } => "CreateView", + crate::db::migrate::Operation::DropView { .. } => "DropView", + crate::db::migrate::Operation::RenameView { .. } => "RenameView", + crate::db::migrate::Operation::ReplaceView { .. } => "ReplaceView", + crate::db::migrate::Operation::RefreshMaterializedView { .. } => { + "RefreshMaterializedView" + } + crate::db::migrate::Operation::SetComment { .. } => "SetComment", + }; + + Self { + op_type: op_type.to_string(), + description: op.description(), + data: serde_json::to_value(op).unwrap_or(Value::Null), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn list_resources_returns_expected() { + let resources = list_resources(); + + assert_eq!(resources.len(), 2); + assert!(resources.iter().any(|r| r.uri == uris::SCHEMA)); + assert!(resources.iter().any(|r| r.uri == uris::MIGRATIONS)); + } + + #[test] + fn schema_resource_from_namespace() { + let namespace = Namespace::empty("public"); + let output = SchemaResource::from(&namespace); + + assert_eq!(output.name, "public"); + assert!(output.tables.is_empty()); + assert!(output.views.is_empty()); + assert!(output.sequences.is_empty()); + assert!(output.enums.is_empty()); + } +} diff --git a/src/mcp/server.rs b/src/mcp/server.rs new file mode 100644 index 0000000..a8cac9a --- /dev/null +++ b/src/mcp/server.rs @@ -0,0 +1,421 @@ +//! MCP server implementation. +//! +//! This module provides the main `McpServer` struct that handles the MCP +//! protocol, routing requests to the appropriate handlers. + +use std::sync::Arc; + +use serde_json::Value; +use tokio::sync::RwLock; + +use crate::db::model::Namespace; +use crate::db::state::{LocalFileBackend, StateBackend}; +use crate::mcp::error::McpError; +use crate::mcp::protocol::{ + CallToolParams, CallToolResult, InitializeParams, InitializeResult, JsonRpcError, + JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, PingResult, + ReadResourceParams, ReadResourceResult, RequestId, ResourcesCapability, ServerCapabilities, + ToolsCapability, +}; +use crate::mcp::resources; +use crate::mcp::session::SessionManager; +use crate::mcp::tools; +use crate::mcp::transport::StdioTransport; + +/// MCP server for AI-assisted migration authoring. +pub struct McpServer { + /// State backend for reading/writing migrations. + backend: Arc, + /// Current schema state (base state for sessions). + current_state: Arc, + /// Session manager for migration authoring sessions. + session_manager: Arc>, + /// Whether the server has been initialized. + initialized: bool, +} + +impl McpServer { + /// Creates a new MCP server with the given backend and state. + pub fn new(backend: LocalFileBackend, current_state: Namespace) -> Self { + let backend = Arc::new(backend); + let session_manager = Arc::new(RwLock::new(SessionManager::new( + current_state.clone(), + (*backend).clone(), + ))); + + Self { + backend, + current_state: Arc::new(current_state), + session_manager, + initialized: false, + } + } + + /// Initializes the MCP server with the given backend. + /// + /// This method: + /// 1. Verifies the backend is initialized + /// 2. Loads the current schema state + /// 3. Sets up the session manager + /// + /// # Errors + /// + /// Returns an error if the backend is not initialized or the state + /// cannot be loaded. + pub async fn initialize(backend: LocalFileBackend) -> Result { + // Verify backend is initialized + if !backend + .is_initialized() + .await + .map_err(|e| McpError::BackendNotInitialized(e.to_string()))? + { + return Err(McpError::BackendNotInitialized( + "backend not initialized; run 'tern init' first".into(), + )); + } + + // Load current state + let current_state = backend + .get_current_state() + .await + .map_err(|e| McpError::BackendError(e.to_string()))?; + + Ok(Self::new(backend, current_state)) + } + + /// Runs the MCP server over stdio. + /// + /// This method enters a loop reading requests from stdin and writing + /// responses to stdout until EOF is received. + pub async fn run_stdio(mut self) -> miette::Result<()> { + let mut transport = StdioTransport::new(); + + tracing::info!("MCP server started, waiting for requests"); + + loop { + match transport.read_request() { + Ok(Some(request)) => { + let response = self.handle_request(request).await; + if let Err(e) = transport.write_response(&response) { + tracing::error!("Failed to write response: {}", e); + break; + } + } + Ok(None) => { + tracing::info!("Received EOF, shutting down"); + break; + } + Err(e) => { + tracing::error!("Failed to read request: {}", e); + // Try to send an error response + let response = JsonRpcResponse::error( + RequestId::Null, + JsonRpcError::parse_error(e.to_string()), + ); + let _ = transport.write_response(&response); + } + } + } + + Ok(()) + } + + /// Handles a single JSON-RPC request and returns a response. + async fn handle_request(&mut self, request: JsonRpcRequest) -> JsonRpcResponse { + tracing::debug!("Handling request: method={}", request.method); + + let result = match request.method.as_str() { + "initialize" => self.handle_initialize(request.params).await, + "initialized" => { + // Notification, no response needed but we return empty for consistency + Ok(Value::Null) + } + "ping" => self.handle_ping().await, + "resources/list" => self.handle_list_resources(request.params).await, + "resources/read" => self.handle_read_resource(request.params).await, + "tools/list" => self.handle_list_tools(request.params).await, + "tools/call" => self.handle_call_tool(request.params).await, + _ => Err(McpError::MethodNotFound(request.method.clone())), + }; + + match result { + Ok(value) => JsonRpcResponse::success(request.id, value), + Err(error) => JsonRpcResponse::from_mcp_error(request.id, &error), + } + } + + /// Handles the initialize request. + async fn handle_initialize(&mut self, params: Option) -> Result { + let _params: InitializeParams = params + .map(|p| serde_json::from_value(p).map_err(|e| McpError::InvalidParams(e.to_string()))) + .transpose()? + .unwrap_or_else(|| InitializeParams { + protocol_version: "2024-11-05".to_string(), + capabilities: Default::default(), + client_info: crate::mcp::protocol::ClientInfo { + name: "unknown".to_string(), + version: "0.0.0".to_string(), + }, + }); + + self.initialized = true; + + let result = InitializeResult { + capabilities: ServerCapabilities { + resources: Some(ResourcesCapability { + subscribe: false, + list_changed: false, + }), + tools: Some(ToolsCapability { + list_changed: false, + }), + prompts: None, + }, + ..Default::default() + }; + + serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Handles the ping request. + async fn handle_ping(&self) -> Result { + serde_json::to_value(PingResult {}).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Handles the resources/list request. + async fn handle_list_resources(&self, _params: Option) -> Result { + let result = ListResourcesResult { + resources: resources::list_resources(), + next_cursor: None, + }; + + serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Handles the resources/read request. + async fn handle_read_resource(&self, params: Option) -> Result { + let params: ReadResourceParams = params + .map(|p| serde_json::from_value(p).map_err(|e| McpError::InvalidParams(e.to_string()))) + .transpose()? + .ok_or_else(|| McpError::InvalidParams("missing params".into()))?; + + let content = + resources::read_resource(¶ms.uri, &self.backend, &self.current_state).await?; + + let result = ReadResourceResult { + contents: vec![content], + }; + + serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Handles the tools/list request. + async fn handle_list_tools(&self, _params: Option) -> Result { + let result = ListToolsResult { + tools: tools::list_tools(), + next_cursor: None, + }; + + serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Handles the tools/call request. + async fn handle_call_tool(&mut self, params: Option) -> Result { + let params: CallToolParams = params + .map(|p| serde_json::from_value(p).map_err(|e| McpError::InvalidParams(e.to_string()))) + .transpose()? + .ok_or_else(|| McpError::InvalidParams("missing params".into()))?; + + let result = self.dispatch_tool(¶ms.name, params.arguments).await?; + + let tool_result = CallToolResult::json(result); + + serde_json::to_value(tool_result).map_err(|e| McpError::InternalError(e.to_string())) + } + + /// Dispatches a tool call to the appropriate handler. + async fn dispatch_tool(&mut self, name: &str, arguments: Value) -> Result { + match name { + "start_session" => { + let input: tools::StartSessionInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_start_session(input, &mut manager).await + } + "execute_sql" => { + let input: tools::ExecuteSqlInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_execute_sql(input, &mut manager).await + } + "apply_operation" => { + let input: tools::ApplyOperationInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_apply_operation(input, &mut manager).await + } + "get_session_schema" => { + let input: tools::GetSessionSchemaInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_get_session_schema(input, &mut manager).await + } + "get_session_diff" => { + let input: tools::GetSessionDiffInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_get_session_diff(input, &mut manager).await + } + "generate_migration" => { + let input: tools::GenerateMigrationInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_generate_migration(input, &mut manager).await + } + "cancel_session" => { + let input: tools::CancelSessionInput = serde_json::from_value(arguments) + .map_err(|e| McpError::InvalidParams(e.to_string()))?; + let mut manager = self.session_manager.write().await; + tools::handle_cancel_session(input, &mut manager).await + } + "list_sessions" => { + let manager = self.session_manager.read().await; + tools::handle_list_sessions(&manager).await + } + _ => Err(McpError::MethodNotFound(format!("tool not found: {name}"))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + async fn create_test_server() -> (TempDir, McpServer) { + let temp_dir = TempDir::new().unwrap(); + let backend = LocalFileBackend::new(temp_dir.path().join(".tern")); + + // Initialize the backend + backend.initialize().await.unwrap(); + + // Create an empty initial state + let namespace = Namespace::empty("public"); + backend.save_current_state(&namespace).await.unwrap(); + + let server = McpServer::new(backend, namespace); + (temp_dir, server) + } + + #[tokio::test] + async fn server_handles_initialize() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test", + "version": "1.0.0" + } + })), + }; + + let response = server.handle_request(request).await; + assert!(response.result.is_some()); + assert!(response.error.is_none()); + } + + #[tokio::test] + async fn server_handles_ping() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "ping".to_string(), + params: None, + }; + + let response = server.handle_request(request).await; + assert!(response.result.is_some()); + assert!(response.error.is_none()); + } + + #[tokio::test] + async fn server_handles_list_resources() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "resources/list".to_string(), + params: None, + }; + + let response = server.handle_request(request).await; + assert!(response.result.is_some()); + + let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap(); + assert!(!result.resources.is_empty()); + } + + #[tokio::test] + async fn server_handles_list_tools() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "tools/list".to_string(), + params: None, + }; + + let response = server.handle_request(request).await; + assert!(response.result.is_some()); + + let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap(); + assert!(!result.tools.is_empty()); + } + + #[tokio::test] + async fn server_handles_unknown_method() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "unknown/method".to_string(), + params: None, + }; + + let response = server.handle_request(request).await; + assert!(response.error.is_some()); + assert_eq!(response.error.unwrap().code, -32601); + } + + #[tokio::test] + async fn server_handles_start_session() { + let (_temp_dir, mut server) = create_test_server().await; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: RequestId::from(1_i64), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "start_session", + "arguments": { + "description": "Test migration" + } + })), + }; + + let response = server.handle_request(request).await; + assert!(response.result.is_some()); + assert!(response.error.is_none()); + } +} diff --git a/src/mcp/session.rs b/src/mcp/session.rs new file mode 100644 index 0000000..2dde6b0 --- /dev/null +++ b/src/mcp/session.rs @@ -0,0 +1,420 @@ +//! Session management for migration authoring. +//! +//! This module provides the session management functionality for the MCP server, +//! allowing users to create isolated editing sessions where they can make schema +//! changes and eventually generate migrations. + +use std::collections::HashMap; +use std::sync::Arc; + +use jiff::Timestamp; +use tokio::sync::RwLock; + +use crate::db::migrate::Operation; +use crate::db::model::Namespace; +#[cfg(feature = "pglite")] +use crate::db::pglite::PgLiteRuntime; +use crate::db::state::LocalFileBackend; +#[cfg(feature = "pglite")] +use crate::db::state::SchemaExporter; +use crate::mcp::error::{McpError, SessionId}; + +/// Default maximum number of concurrent sessions. +pub const DEFAULT_MAX_SESSIONS: usize = 5; + +/// Record of an operation applied in a session. +#[derive(Debug, Clone)] +pub struct OperationRecord { + /// The structured operation, if this was applied via `apply_operation`. + pub operation: Option, + /// The SQL that was executed. + pub sql: String, + /// When the operation was applied. + pub applied_at: Timestamp, +} + +/// An active migration authoring session. +pub struct Session { + /// Unique identifier for this session. + pub id: SessionId, + /// User-provided description for the migration being created. + pub description: String, + /// When the session was started. + pub started_at: Timestamp, + /// Schema state at session start (immutable, used for diffing). + pub base_state: Namespace, + /// History of SQL statements executed in this session. + pub sql_history: Vec, + /// History of operations applied in this session. + pub operation_history: Vec, + /// The PGLite runtime for this session's database. + #[cfg(feature = "pglite")] + pub runtime: Option, +} + +impl Session { + /// Creates a new session. + #[cfg(feature = "pglite")] + pub fn new(id: SessionId, description: String, base_state: Namespace) -> Self { + Self { + id, + description, + started_at: Timestamp::now(), + base_state, + sql_history: Vec::new(), + operation_history: Vec::new(), + runtime: None, + } + } + + /// Creates a new session (non-pglite version). + #[cfg(not(feature = "pglite"))] + pub fn new(id: SessionId, description: String, base_state: Namespace) -> Self { + Self { + id, + description, + started_at: Timestamp::now(), + base_state, + sql_history: Vec::new(), + operation_history: Vec::new(), + } + } + + /// Gets a PGLite client connection for this session. + /// + /// # Errors + /// + /// Returns an error if the runtime is not initialized or the connection fails. + #[cfg(feature = "pglite")] + pub async fn get_client(&self) -> Result { + let runtime = self + .runtime + .as_ref() + .ok_or_else(|| McpError::PgLiteError("session database not initialized".into()))?; + runtime + .client() + .await + .map_err(|e| McpError::PgLiteError(e.to_string())) + } + + /// Returns the number of operations in this session. + #[must_use] + pub fn operation_count(&self) -> usize { + self.operation_history.len() + } + + /// Returns true if this session has any changes. + #[must_use] + pub fn has_changes(&self) -> bool { + !self.operation_history.is_empty() + } + + /// Records an SQL execution in the session history. + pub fn record_sql(&mut self, sql: String) { + self.sql_history.push(sql.clone()); + self.operation_history.push(OperationRecord { + operation: None, + sql, + applied_at: Timestamp::now(), + }); + } + + /// Records a structured operation in the session history. + pub fn record_operation(&mut self, operation: Operation, sql: String) { + self.sql_history.push(sql.clone()); + self.operation_history.push(OperationRecord { + operation: Some(operation), + sql, + applied_at: Timestamp::now(), + }); + } +} + +/// Summary of a session for listing. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSummary { + /// Session ID. + pub session_id: String, + /// Migration description. + pub description: String, + /// When the session was started. + pub started_at: String, + /// Number of operations applied. + pub operation_count: usize, + /// Whether the session has any changes. + pub has_changes: bool, +} + +impl From<&Session> for SessionSummary { + fn from(session: &Session) -> Self { + Self { + session_id: session.id.to_string(), + description: session.description.clone(), + started_at: session.started_at.to_string(), + operation_count: session.operation_count(), + has_changes: session.has_changes(), + } + } +} + +/// Manages all active migration authoring sessions. +pub struct SessionManager { + /// Active sessions indexed by ID. + sessions: HashMap, + /// Base schema state (shared across all sessions). + base_state: Arc, + /// State backend for saving migrations. + backend: Arc, + /// Maximum number of concurrent sessions. + max_sessions: usize, + /// PGLite runtime for session databases. + #[cfg(feature = "pglite")] + _pglite_runtime: Option, +} + +impl SessionManager { + /// Creates a new session manager. + pub fn new(base_state: Namespace, backend: LocalFileBackend) -> Self { + Self { + sessions: HashMap::new(), + base_state: Arc::new(base_state), + backend: Arc::new(backend), + max_sessions: DEFAULT_MAX_SESSIONS, + #[cfg(feature = "pglite")] + _pglite_runtime: None, + } + } + + /// Sets the maximum number of concurrent sessions. + pub fn with_max_sessions(mut self, max: usize) -> Self { + self.max_sessions = max; + self + } + + /// Returns the base schema state. + #[must_use] + pub fn base_state(&self) -> &Namespace { + &self.base_state + } + + /// Returns a reference to the state backend. + #[must_use] + pub fn backend(&self) -> &LocalFileBackend { + &self.backend + } + + /// Starts a new migration authoring session. + /// + /// Note: PGLite initialization is deferred until the first SQL operation + /// to keep session creation fast for tests and quick operations. + /// + /// # Errors + /// + /// Returns an error if the session limit has been reached. + pub async fn start_session(&mut self, description: String) -> Result { + // Check session limit + if self.sessions.len() >= self.max_sessions { + return Err(McpError::SessionLimitReached { + max_sessions: self.max_sessions, + current_sessions: self.sessions.len(), + }); + } + + // Generate session ID + let id = SessionId::generate(); + + // Create session with base state (PGLite is initialized lazily) + let session = Session::new(id.clone(), description, (*self.base_state).clone()); + + // Store session + self.sessions.insert(id.clone(), session); + + Ok(id) + } + + /// Ensures a session's PGLite database is initialized. + /// + /// This is called lazily when the first SQL operation is performed. + /// + /// # Errors + /// + /// Returns an error if PGLite initialization fails. + #[cfg(feature = "pglite")] + pub async fn ensure_session_initialized(&mut self, id: &SessionId) -> Result<(), McpError> { + let session = self + .sessions + .get_mut(id) + .ok_or_else(|| McpError::SessionNotFound(id.clone()))?; + + // Already initialized? + if session.runtime.is_some() { + return Ok(()); + } + + // Initialize PGLite database for this session + let mut runtime = PgLiteRuntime::new().map_err(|e| McpError::PgLiteError(e.to_string()))?; + runtime + .start() + .await + .map_err(|e| McpError::PgLiteError(e.to_string()))?; + + // Generate DDL from base state and execute it + let ddl = SchemaExporter::export(&session.base_state); + let client = runtime + .client() + .await + .map_err(|e| McpError::PgLiteError(e.to_string()))?; + client.batch_execute(&ddl).await.map_err(|e| { + McpError::PgLiteError(format!("failed to initialize session database: {e}")) + })?; + + session.runtime = Some(runtime); + + Ok(()) + } + + /// Gets a session by ID. + #[must_use] + pub fn get_session(&self, id: &SessionId) -> Option<&Session> { + self.sessions.get(id) + } + + /// Gets a mutable session by ID. + #[must_use] + pub fn get_session_mut(&mut self, id: &SessionId) -> Option<&mut Session> { + self.sessions.get_mut(id) + } + + /// Cancels and removes a session. + /// + /// # Errors + /// + /// Returns an error if the session is not found. + pub async fn cancel_session(&mut self, id: &SessionId) -> Result { + let session = self + .sessions + .remove(id) + .ok_or_else(|| McpError::SessionNotFound(id.clone()))?; + + let discarded = session.operation_count(); + Ok(discarded) + } + + /// Lists all active sessions. + #[must_use] + pub fn list_sessions(&self) -> Vec { + self.sessions.values().map(SessionSummary::from).collect() + } + + /// Returns the number of active sessions. + #[must_use] + pub fn session_count(&self) -> usize { + self.sessions.len() + } + + /// Removes a session from the manager (used after generating a migration). + pub fn remove_session(&mut self, id: &SessionId) -> Option { + self.sessions.remove(id) + } +} + +/// Thread-safe wrapper around `SessionManager`. +#[allow(dead_code)] +pub type SharedSessionManager = Arc>; + +/// Creates a new shared session manager. +#[must_use] +#[allow(dead_code)] +pub fn shared_session_manager( + base_state: Namespace, + backend: LocalFileBackend, +) -> SharedSessionManager { + Arc::new(RwLock::new(SessionManager::new(base_state, backend))) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn create_test_backend() -> (TempDir, LocalFileBackend) { + let temp_dir = TempDir::new().unwrap(); + let backend = LocalFileBackend::new(temp_dir.path().join(".tern")); + (temp_dir, backend) + } + + #[tokio::test] + async fn start_session_generates_id() { + let (_temp_dir, backend) = create_test_backend(); + let base_state = Namespace::empty("public"); + let mut manager = SessionManager::new(base_state, backend); + + let id = manager + .start_session("Test migration".into()) + .await + .unwrap(); + + assert!(id.as_str().starts_with("sess_")); + assert_eq!(manager.session_count(), 1); + } + + #[tokio::test] + async fn get_session_returns_session() { + let (_temp_dir, backend) = create_test_backend(); + let base_state = Namespace::empty("public"); + let mut manager = SessionManager::new(base_state, backend); + + let id = manager + .start_session("Test migration".into()) + .await + .unwrap(); + let session = manager.get_session(&id); + + assert!(session.is_some()); + assert_eq!(session.unwrap().description, "Test migration"); + } + + #[tokio::test] + async fn cancel_session_removes_session() { + let (_temp_dir, backend) = create_test_backend(); + let base_state = Namespace::empty("public"); + let mut manager = SessionManager::new(base_state, backend); + + let id = manager + .start_session("Test migration".into()) + .await + .unwrap(); + assert_eq!(manager.session_count(), 1); + + let discarded = manager.cancel_session(&id).await.unwrap(); + assert_eq!(discarded, 0); + assert_eq!(manager.session_count(), 0); + } + + #[tokio::test] + async fn session_limit_enforced() { + let (_temp_dir, backend) = create_test_backend(); + let base_state = Namespace::empty("public"); + let mut manager = SessionManager::new(base_state, backend).with_max_sessions(2); + + manager.start_session("Session 1".into()).await.unwrap(); + manager.start_session("Session 2".into()).await.unwrap(); + + let result = manager.start_session("Session 3".into()).await; + assert!(matches!(result, Err(McpError::SessionLimitReached { .. }))); + } + + #[tokio::test] + async fn list_sessions_returns_summaries() { + let (_temp_dir, backend) = create_test_backend(); + let base_state = Namespace::empty("public"); + let mut manager = SessionManager::new(base_state, backend); + + manager.start_session("First".into()).await.unwrap(); + manager.start_session("Second".into()).await.unwrap(); + + let sessions = manager.list_sessions(); + assert_eq!(sessions.len(), 2); + } +} diff --git a/src/mcp/tools.rs b/src/mcp/tools.rs new file mode 100644 index 0000000..4845769 --- /dev/null +++ b/src/mcp/tools.rs @@ -0,0 +1,1083 @@ +//! MCP tools for migration authoring. +//! +//! This module provides the tools that allow MCP clients to create and manage +//! migration sessions, execute SQL, and generate migrations. + +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +#[cfg(feature = "pglite")] +use crate::db::diff::diff_namespaces; +#[cfg(feature = "pglite")] +use crate::db::migrate::{MigrationPlan, Operation, PostgresRenderer, RenderConfig, Renderer}; +#[cfg(feature = "pglite")] +use crate::db::query::{PostgresCatalog, load_namespace}; +use crate::db::state::StateHash; +#[cfg(feature = "pglite")] +use crate::db::state::{Migration, StateBackend}; +use crate::mcp::error::{McpError, SessionId}; +use crate::mcp::protocol::Tool; +use crate::mcp::resources::SchemaResource; +use crate::mcp::session::{SessionManager, SessionSummary}; + +/// Returns the list of available tools. +pub fn list_tools() -> Vec { + vec![ + Tool { + name: "start_session".to_string(), + description: Some( + "Initialize a new session for creating a migration. Returns a session ID that must be used for all subsequent operations.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Description for the migration being created (e.g., 'Add user preferences table')" + } + }, + "required": ["description"] + }), + }, + Tool { + name: "execute_sql".to_string(), + description: Some( + "Execute raw SQL statement(s) against the session's in-memory database. Use this for schema changes that aren't covered by structured operations, or when you prefer writing SQL directly.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + }, + "sql": { + "type": "string", + "description": "SQL statement(s) to execute. Multiple statements can be separated by semicolons." + } + }, + "required": ["sessionId", "sql"] + }), + }, + Tool { + name: "apply_operation".to_string(), + description: Some( + "Apply a structured schema modification operation. This is an alternative to raw SQL that provides better validation and error messages.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + }, + "operation": { + "type": "object", + "description": "The operation to apply (create_table, drop_table, add_column, etc.)" + } + }, + "required": ["sessionId", "operation"] + }), + }, + Tool { + name: "get_session_schema".to_string(), + description: Some( + "Returns the current schema state in the session's database, reflecting all changes made since the session started.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + } + }, + "required": ["sessionId"] + }), + }, + Tool { + name: "get_session_diff".to_string(), + description: Some( + "Returns a summary of all changes made in the session compared to the base schema.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + } + }, + "required": ["sessionId"] + }), + }, + Tool { + name: "generate_migration".to_string(), + description: Some( + "Generates a migration from all changes made in the session and saves it to the .tern/migrations directory. This ends the session.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + }, + "description": { + "type": "string", + "description": "Override the migration description (uses session description if not provided)" + }, + "force": { + "type": "boolean", + "default": false, + "description": "Generate migration even if it contains breaking changes" + } + }, + "required": ["sessionId"] + }), + }, + Tool { + name: "cancel_session".to_string(), + description: Some( + "Cancels an active session and discards all changes. No migration is created.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "sessionId": { + "type": "string", + "description": "Session ID from start_session" + } + }, + "required": ["sessionId"] + }), + }, + Tool { + name: "list_sessions".to_string(), + description: Some( + "Returns a list of all currently active sessions. Useful for recovery if a session ID is lost.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": {} + }), + }, + ] +} + +// ============================================================================= +// Tool Input Types +// ============================================================================= + +/// Input for start_session tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct StartSessionInput { + /// Description for the migration being created. + pub description: String, +} + +/// Input for execute_sql tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecuteSqlInput { + /// Session ID. + pub session_id: String, + /// SQL to execute. + pub sql: String, +} + +/// Input for apply_operation tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +pub struct ApplyOperationInput { + /// Session ID. + pub session_id: String, + /// Operation to apply. + pub operation: Value, +} + +/// Input for get_session_schema tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetSessionSchemaInput { + /// Session ID. + pub session_id: String, +} + +/// Input for get_session_diff tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetSessionDiffInput { + /// Session ID. + pub session_id: String, +} + +/// Input for generate_migration tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +pub struct GenerateMigrationInput { + /// Session ID. + pub session_id: String, + /// Optional description override. + pub description: Option, + /// Force generation even with breaking changes. + #[serde(default)] + pub force: bool, +} + +/// Input for cancel_session tool. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelSessionInput { + /// Session ID. + pub session_id: String, +} + +// ============================================================================= +// Tool Output Types +// ============================================================================= + +/// Output for start_session tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct StartSessionOutput { + /// Session ID. + pub session_id: String, + /// Migration description. + pub description: String, + /// Base state hash. + pub base_state_hash: String, + /// When the session was started. + pub started_at: String, + /// Human-readable message. + pub message: String, +} + +/// Output for execute_sql tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecuteSqlOutput { + /// Whether execution succeeded. + pub success: bool, + /// Number of rows affected. + #[serde(skip_serializing_if = "Option::is_none")] + pub rows_affected: Option, + /// Human-readable message. + pub message: String, + /// Warnings, if any. + #[serde(skip_serializing_if = "Vec::is_empty")] + pub warnings: Vec, + /// Error details, if failed. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// SQL error details. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SqlError { + /// Error message. + pub message: String, + /// PostgreSQL error code. + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + /// Detailed error information. + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, + /// Hint for fixing the error. + #[serde(skip_serializing_if = "Option::is_none")] + pub hint: Option, + /// Position in the SQL where the error occurred. + #[serde(skip_serializing_if = "Option::is_none")] + pub position: Option, +} + +/// Output for apply_operation tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ApplyOperationOutput { + /// Whether the operation succeeded. + pub success: bool, + /// SQL that was executed. + #[serde(skip_serializing_if = "Option::is_none")] + pub sql_executed: Option, + /// Human-readable message. + pub message: String, +} + +/// Output for get_session_diff tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionDiffOutput { + /// Whether there are any changes. + pub has_changes: bool, + /// Summary counts. + pub summary: DiffSummary, + /// Detailed changes. + pub details: DiffDetails, + /// Breaking changes detected. + pub breaking_changes: Vec, + /// Whether there are breaking changes. + pub has_breaking_changes: bool, +} + +/// Summary of diff counts. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct DiffSummary { + /// Number of tables added. + pub tables_added: usize, + /// Number of tables removed. + pub tables_removed: usize, + /// Number of tables modified. + pub tables_modified: usize, + /// Number of columns added. + pub columns_added: usize, + /// Number of columns removed. + pub columns_removed: usize, + /// Number of columns modified. + pub columns_modified: usize, + /// Number of indexes added. + pub indexes_added: usize, + /// Number of indexes removed. + pub indexes_removed: usize, + /// Number of constraints added. + pub constraints_added: usize, + /// Number of constraints removed. + pub constraints_removed: usize, + /// Number of enums added. + pub enums_added: usize, + /// Number of enums removed. + pub enums_removed: usize, +} + +/// Detailed diff information. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct DiffDetails { + /// Names of tables added. + pub tables_added: Vec, + /// Names of tables removed. + pub tables_removed: Vec, + /// Details of modified tables. + pub tables_modified: Vec, + /// Indexes added. + pub indexes_added: Vec, + /// Foreign keys added. + pub foreign_keys_added: Vec, +} + +/// Detail of a modified table. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModifiedTableDetail { + /// Table name. + pub table: String, + /// Columns added. + pub columns_added: Vec, + /// Columns removed. + pub columns_removed: Vec, + /// Columns modified. + pub columns_modified: Vec, +} + +/// Detail of an index. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct IndexDetail { + /// Index name. + pub name: String, + /// Table name. + pub table: String, + /// Column names. + pub columns: Vec, +} + +/// Detail of a foreign key. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ForeignKeyDetail { + /// Constraint name. + pub name: String, + /// Referencing table. + pub table: String, + /// Referencing columns. + pub columns: Vec, + /// Referenced table. + pub references_table: String, + /// Referenced columns. + pub references_columns: Vec, +} + +/// Output for generate_migration tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateMigrationOutput { + /// Whether generation succeeded. + pub success: bool, + /// Migration ID. + pub migration_id: String, + /// Sequence number. + pub sequence_number: usize, + /// Migration description. + pub description: String, + /// File path where the migration was saved. + pub file_path: String, + /// Number of operations in the migration. + pub operation_count: usize, + /// Whether the migration has breaking changes. + pub has_breaking_changes: bool, + /// Breaking changes detected. + pub breaking_changes: Vec, + /// Whether the migration is reversible. + pub is_reversible: bool, + /// Whether the session has ended. + pub session_ended: bool, + /// Human-readable message. + pub message: String, +} + +/// Output for cancel_session tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelSessionOutput { + /// Whether cancellation succeeded. + pub success: bool, + /// Session ID that was cancelled. + pub session_id: String, + /// Number of operations discarded. + pub changes_discarded: usize, + /// Human-readable message. + pub message: String, +} + +/// Output for list_sessions tool. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ListSessionsOutput { + /// List of active sessions. + pub sessions: Vec, + /// Total count. + pub total_count: usize, +} + +// ============================================================================= +// Tool Handlers +// ============================================================================= + +/// Handles the start_session tool. +pub async fn handle_start_session( + input: StartSessionInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = session_manager + .start_session(input.description.clone()) + .await?; + + let session = session_manager + .get_session(&session_id) + .ok_or_else(|| McpError::InternalError("session not found after creation".into()))?; + + let base_state = session_manager.base_state(); + let table_count = base_state.tables.len(); + let view_count = base_state.views.len(); + let sequence_count = base_state.sequences.len(); + let enum_count = base_state.enums.len(); + + let output = StartSessionOutput { + session_id: session_id.to_string(), + description: input.description, + base_state_hash: StateHash::from_namespace(base_state).to_hex(), + started_at: session.started_at.to_string(), + message: format!( + "Session started. Current schema has {} tables, {} views, {} sequences, {} enums.", + table_count, view_count, sequence_count, enum_count + ), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the execute_sql tool. +#[cfg(feature = "pglite")] +pub async fn handle_execute_sql( + input: ExecuteSqlInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Ensure the session's PGLite database is initialized + session_manager + .ensure_session_initialized(&session_id) + .await?; + + let session = session_manager + .get_session_mut(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + + // Get the PGLite client for this session + let client = session.get_client().await?; + + // Execute the SQL + match client.batch_execute(&input.sql).await { + Ok(()) => { + // Record the SQL in history + session.record_sql(input.sql.clone()); + + let output = ExecuteSqlOutput { + success: true, + rows_affected: Some(0), + message: "SQL executed successfully".to_string(), + warnings: vec![], + error: None, + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) + } + Err(e) => { + let db_error = e.as_db_error(); + let output = ExecuteSqlOutput { + success: false, + rows_affected: None, + message: e.to_string(), + warnings: vec![], + error: Some(SqlError { + message: e.to_string(), + code: db_error.map(|e| e.code().code().to_string()), + detail: db_error.and_then(|e| e.detail().map(|s| s.to_string())), + hint: db_error.and_then(|e| e.hint().map(|s| s.to_string())), + position: db_error.and_then(|e| { + e.position().map(|p| match p { + tokio_postgres::error::ErrorPosition::Original(pos) + | tokio_postgres::error::ErrorPosition::Internal { + position: pos, + .. + } => *pos as i32, + }) + }), + }), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) + } + } +} + +/// Handles the execute_sql tool (non-pglite version). +#[cfg(not(feature = "pglite"))] +pub async fn handle_execute_sql( + input: ExecuteSqlInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + let session = session_manager + .get_session_mut(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + + // Record the SQL in history + session.record_sql(input.sql.clone()); + + let output = ExecuteSqlOutput { + success: true, + rows_affected: Some(0), + message: "SQL recorded (PGLite not enabled)".to_string(), + warnings: vec![], + error: None, + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the apply_operation tool. +#[cfg(feature = "pglite")] +pub async fn handle_apply_operation( + input: ApplyOperationInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Parse the operation from JSON + let operation: Operation = serde_json::from_value(input.operation.clone()) + .map_err(|e| McpError::InvalidParams(format!("invalid operation: {e}")))?; + + // Ensure the session's PGLite database is initialized + session_manager + .ensure_session_initialized(&session_id) + .await?; + + let session = session_manager + .get_session_mut(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + + // Render the operation to SQL + let renderer = PostgresRenderer::new(RenderConfig::default()); + let rendered = renderer.render(&operation); + let sql = rendered.forward.join(";\n"); + + // Get the PGLite client for this session + let client = session.get_client().await?; + + // Execute the SQL + match client.batch_execute(&sql).await { + Ok(()) => { + // Record the operation in history + session.record_operation(operation, sql.clone()); + + let output = ApplyOperationOutput { + success: true, + sql_executed: Some(sql), + message: "Operation applied successfully".to_string(), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) + } + Err(e) => { + let output = ApplyOperationOutput { + success: false, + sql_executed: Some(sql), + message: format!("Operation failed: {e}"), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) + } + } +} + +/// Handles the apply_operation tool (non-pglite version). +#[cfg(not(feature = "pglite"))] +pub async fn handle_apply_operation( + input: ApplyOperationInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + let _session = session_manager + .get_session_mut(&session_id) + .ok_or(McpError::SessionNotFound(session_id))?; + + let output = ApplyOperationOutput { + success: false, + sql_executed: None, + message: "apply_operation requires PGLite feature to be enabled".to_string(), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the get_session_schema tool. +#[cfg(feature = "pglite")] +pub async fn handle_get_session_schema( + input: GetSessionSchemaInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Ensure the session's PGLite database is initialized + session_manager + .ensure_session_initialized(&session_id) + .await?; + + let session = session_manager + .get_session(&session_id) + .ok_or(McpError::SessionNotFound(session_id))?; + + // Get the current schema from the session's PGLite database + let client = session.get_client().await?; + let catalog = PostgresCatalog::new(&client); + let namespace = load_namespace(&catalog, "public") + .await + .map_err(|e| McpError::InternalError(format!("failed to load schema: {e}")))?; + + let output = SchemaResource::from(&namespace); + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the get_session_schema tool (non-pglite version). +#[cfg(not(feature = "pglite"))] +pub async fn handle_get_session_schema( + input: GetSessionSchemaInput, + session_manager: &SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + let session = session_manager + .get_session(&session_id) + .ok_or(McpError::SessionNotFound(session_id))?; + + // Return the base state when PGLite is not available + let output = SchemaResource::from(&session.base_state); + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the get_session_diff tool. +#[cfg(feature = "pglite")] +pub async fn handle_get_session_diff( + input: GetSessionDiffInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Ensure the session's PGLite database is initialized + session_manager + .ensure_session_initialized(&session_id) + .await?; + + let session = session_manager + .get_session(&session_id) + .ok_or(McpError::SessionNotFound(session_id))?; + + // Get the current schema from the session's PGLite database + let client = session.get_client().await?; + let catalog = PostgresCatalog::new(&client); + let current_state = load_namespace(&catalog, "public") + .await + .map_err(|e| McpError::InternalError(format!("failed to load schema: {e}")))?; + + // Diff against base state + let diff = diff_namespaces(&session.base_state, ¤t_state); + + // Build the output + let mut summary = DiffSummary::default(); + let mut details = DiffDetails::default(); + let mut breaking_changes = Vec::new(); + + // Process added tables + for table in &diff.tables.added { + summary.tables_added += 1; + details.tables_added.push(table.name.as_ref().to_string()); + } + + // Process removed tables (we only have keys) + for table_name in &diff.tables.removed { + summary.tables_removed += 1; + details.tables_removed.push(table_name.as_ref().to_string()); + breaking_changes.push(format!("Table '{}' removed", table_name.as_ref())); + } + + // Process modified tables + for modified in &diff.tables.modified { + summary.tables_modified += 1; + + let mut modified_detail = ModifiedTableDetail { + table: modified.name.as_ref().to_string(), + columns_added: vec![], + columns_removed: vec![], + columns_modified: vec![], + }; + + // Check for column changes + for col in &modified.columns.added { + summary.columns_added += 1; + modified_detail + .columns_added + .push(col.name.as_ref().to_string()); + } + + for col_name in &modified.columns.removed { + summary.columns_removed += 1; + modified_detail + .columns_removed + .push(col_name.as_ref().to_string()); + breaking_changes.push(format!( + "Column '{}.{}' removed", + modified.name.as_ref(), + col_name.as_ref() + )); + } + + for col_mod in &modified.columns.modified { + summary.columns_modified += 1; + modified_detail + .columns_modified + .push(col_mod.name.as_ref().to_string()); + } + + // Check for constraint changes + summary.constraints_added += modified.constraints.added.len(); + summary.constraints_removed += modified.constraints.removed.len(); + + // Check for index changes + summary.indexes_added += modified.indexes.added.len(); + summary.indexes_removed += modified.indexes.removed.len(); + + details.tables_modified.push(modified_detail); + } + + // Process added enums + summary.enums_added = diff.enums.added.len(); + + // Process removed enums + for enum_name in &diff.enums.removed { + summary.enums_removed += 1; + breaking_changes.push(format!("Enum '{}' removed", enum_name.as_ref())); + } + + let has_changes = summary.tables_added > 0 + || summary.tables_removed > 0 + || summary.tables_modified > 0 + || summary.enums_added > 0 + || summary.enums_removed > 0; + + let output = SessionDiffOutput { + has_changes, + summary, + details, + breaking_changes: breaking_changes.clone(), + has_breaking_changes: !breaking_changes.is_empty(), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the get_session_diff tool (non-pglite version). +#[cfg(not(feature = "pglite"))] +pub async fn handle_get_session_diff( + input: GetSessionDiffInput, + session_manager: &SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + let session = session_manager + .get_session(&session_id) + .ok_or(McpError::SessionNotFound(session_id))?; + + // Return an empty diff when PGLite is not available + let output = SessionDiffOutput { + has_changes: session.has_changes(), + summary: DiffSummary::default(), + details: DiffDetails::default(), + breaking_changes: vec![], + has_breaking_changes: false, + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the generate_migration tool. +#[cfg(feature = "pglite")] +pub async fn handle_generate_migration( + input: GenerateMigrationInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Ensure the session's PGLite database is initialized + session_manager + .ensure_session_initialized(&session_id) + .await?; + + // Get the session's base state and description + let (base_state, description, parent_hash) = { + let session = session_manager + .get_session(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + + if !session.has_changes() && !input.force { + return Err(McpError::NoChanges); + } + + let desc = input + .description + .clone() + .unwrap_or_else(|| session.description.clone()); + let parent = StateHash::from_namespace(&session.base_state); + (session.base_state.clone(), desc, parent) + }; + + // Get the current schema from the session's PGLite database + let current_state = { + let session = session_manager + .get_session(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + let client = session.get_client().await?; + let catalog = PostgresCatalog::new(&client); + load_namespace(&catalog, "public") + .await + .map_err(|e| McpError::InternalError(format!("failed to load schema: {e}")))? + }; + + // Diff against base state + let diff = diff_namespaces(&base_state, ¤t_state); + + // Check for breaking changes + let mut breaking_changes = Vec::new(); + for table_name in &diff.tables.removed { + breaking_changes.push(format!("Table '{}' removed", table_name.as_ref())); + } + for modified in &diff.tables.modified { + for col_name in &modified.columns.removed { + breaking_changes.push(format!( + "Column '{}.{}' removed", + modified.name.as_ref(), + col_name.as_ref() + )); + } + } + for enum_name in &diff.enums.removed { + breaking_changes.push(format!("Enum '{}' removed", enum_name.as_ref())); + } + + if !breaking_changes.is_empty() && !input.force { + return Err(McpError::BreakingChangesDetected { breaking_changes }); + } + + // Convert diff to migration plan + let plan = MigrationPlan::from_diff(&diff); + + if plan.is_empty() && !input.force { + return Err(McpError::NoChanges); + } + + // Calculate resulting state hash + let resulting_hash = StateHash::from_namespace(¤t_state); + + // Create the migration + let migration = Migration::new( + &description, + plan.operations.clone(), + vec![], // down_operations would require inverse operation generation + parent_hash, + resulting_hash, + vec![], // breaking_changes as BreakingChange structs (simplified for now) + ); + + // Save the migration + let backend = session_manager.backend(); + let migration_index = backend + .get_migration_index() + .await + .map_err(|e| McpError::InternalError(format!("failed to get migration index: {e}")))?; + let sequence_number = migration_index.len() + 1; + + backend + .save_migration(&migration) + .await + .map_err(|e| McpError::SaveFailed(e.to_string()))?; + + // Save the current state + backend + .save_current_state(¤t_state) + .await + .map_err(|e| McpError::SaveFailed(format!("failed to save current state: {e}")))?; + + // Get the migration file path + let migrations_dir = backend.root().join("migrations"); + let file_path = migrations_dir + .join(format!("{:05}.json", sequence_number)) + .to_string_lossy() + .to_string(); + + // Remove the session + let _ = session_manager.cancel_session(&session_id).await; + + let output = GenerateMigrationOutput { + success: true, + migration_id: migration.id.to_hex(), + sequence_number, + description, + file_path, + operation_count: plan.operations.len(), + has_breaking_changes: !breaking_changes.is_empty(), + breaking_changes, + is_reversible: false, // Would need inverse operation generation + session_ended: true, + message: format!( + "Migration {} created with {} operations", + migration.id.to_hex(), + plan.operations.len() + ), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the generate_migration tool (non-pglite version). +#[cfg(not(feature = "pglite"))] +pub async fn handle_generate_migration( + input: GenerateMigrationInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id); + + // Check session exists + { + let session = session_manager + .get_session(&session_id) + .ok_or_else(|| McpError::SessionNotFound(session_id.clone()))?; + + if !session.has_changes() { + return Err(McpError::NoChanges); + } + } + + let output = GenerateMigrationOutput { + success: false, + migration_id: "".to_string(), + sequence_number: 0, + description: input.description.unwrap_or_default(), + file_path: "".to_string(), + operation_count: 0, + has_breaking_changes: false, + breaking_changes: vec![], + is_reversible: false, + session_ended: false, + message: "generate_migration requires PGLite feature to be enabled".to_string(), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the cancel_session tool. +pub async fn handle_cancel_session( + input: CancelSessionInput, + session_manager: &mut SessionManager, +) -> Result { + let session_id = SessionId::new(input.session_id.clone()); + let discarded = session_manager.cancel_session(&session_id).await?; + + let output = CancelSessionOutput { + success: true, + session_id: input.session_id, + changes_discarded: discarded, + message: format!( + "Session cancelled. {} operations discarded, no migration created.", + discarded + ), + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +/// Handles the list_sessions tool. +pub async fn handle_list_sessions(session_manager: &SessionManager) -> Result { + let sessions = session_manager.list_sessions(); + let total_count = sessions.len(); + + let output = ListSessionsOutput { + sessions, + total_count, + }; + + serde_json::to_value(output).map_err(|e| McpError::InternalError(e.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn list_tools_returns_expected() { + let tools = list_tools(); + + assert!(tools.iter().any(|t| t.name == "start_session")); + assert!(tools.iter().any(|t| t.name == "execute_sql")); + assert!(tools.iter().any(|t| t.name == "apply_operation")); + assert!(tools.iter().any(|t| t.name == "get_session_schema")); + assert!(tools.iter().any(|t| t.name == "get_session_diff")); + assert!(tools.iter().any(|t| t.name == "generate_migration")); + assert!(tools.iter().any(|t| t.name == "cancel_session")); + assert!(tools.iter().any(|t| t.name == "list_sessions")); + } +} diff --git a/src/mcp/transport.rs b/src/mcp/transport.rs new file mode 100644 index 0000000..96d5b86 --- /dev/null +++ b/src/mcp/transport.rs @@ -0,0 +1,119 @@ +//! Stdio transport for MCP server. +//! +//! This module handles reading JSON-RPC messages from stdin and writing +//! responses to stdout. Messages are newline-delimited JSON. + +use std::io::{self, BufRead, Write}; + +use crate::mcp::error::McpError; +use crate::mcp::protocol::{JsonRpcRequest, JsonRpcResponse}; + +/// Transport layer for stdio-based MCP communication. +/// +/// Messages are newline-delimited JSON. Each message is a single line. +pub struct StdioTransport { + /// Buffered stdin reader. + reader: io::BufReader, + /// Stdout writer. + writer: io::Stdout, +} + +impl StdioTransport { + /// Creates a new stdio transport. + #[must_use] + pub fn new() -> Self { + Self { + reader: io::BufReader::new(io::stdin()), + writer: io::stdout(), + } + } + + /// Reads the next JSON-RPC request from stdin. + /// + /// Returns `None` if stdin is closed (EOF). + /// + /// # Errors + /// + /// Returns an error if reading fails or the JSON is invalid. + pub fn read_request(&mut self) -> Result, McpError> { + let mut line = String::new(); + + match self.reader.read_line(&mut line) { + Ok(0) => Ok(None), // EOF + Ok(_) => { + let line = line.trim(); + if line.is_empty() { + // Skip empty lines and try again + return self.read_request(); + } + + tracing::debug!("Received: {}", line); + + serde_json::from_str(line).map(Some).map_err(|e| { + McpError::ParseError(format!("failed to parse JSON-RPC request: {e}")) + }) + } + Err(e) => Err(McpError::IoError(e)), + } + } + + /// Writes a JSON-RPC response to stdout. + /// + /// # Errors + /// + /// Returns an error if writing fails. + pub fn write_response(&mut self, response: &JsonRpcResponse) -> Result<(), McpError> { + let json = serde_json::to_string(response) + .map_err(|e| McpError::InternalError(format!("failed to serialize response: {e}")))?; + + tracing::debug!("Sending: {}", json); + + writeln!(self.writer, "{json}").map_err(McpError::IoError)?; + self.writer.flush().map_err(McpError::IoError)?; + + Ok(()) + } +} + +impl Default for StdioTransport { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mcp::protocol::{JSONRPC_VERSION, JsonRpcError, RequestId}; + + #[test] + fn transport_creation() { + // Just verify we can create a transport + let _transport = StdioTransport::new(); + } + + #[test] + fn response_serialization() { + // Test that responses serialize correctly + let response = + JsonRpcResponse::success(RequestId::from(1_i64), serde_json::json!({"ok": true})); + let json = serde_json::to_string(&response).unwrap(); + + assert!(json.contains("\"jsonrpc\"")); + assert!(json.contains(&format!("\"{}\"", JSONRPC_VERSION))); + assert!(json.contains("\"result\"")); + } + + #[test] + fn error_response_serialization() { + let response = JsonRpcResponse::error( + RequestId::from(1_i64), + JsonRpcError::new(-32600, "test error"), + ); + let json = serde_json::to_string(&response).unwrap(); + + assert!(json.contains("\"error\"")); + assert!(json.contains("-32600")); + assert!(json.contains("test error")); + } +}