diff --git a/Cargo.lock b/Cargo.lock index 796a8cb..a825545 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2863,7 +2863,7 @@ dependencies = [ [[package]] name = "rowdy" -version = "0.16.3" +version = "0.17.0" dependencies = [ "anyhow", "arboard", diff --git a/Cargo.toml b/Cargo.toml index 7ce398b..58a7bb8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rowdy" -version = "0.16.3" +version = "0.17.0" edition = "2024" rust-version = "1.86" license = "MIT" diff --git a/src/action/mod.rs b/src/action/mod.rs index 0e0e2fb..4a51834 100644 --- a/src/action/mod.rs +++ b/src/action/mod.rs @@ -12,10 +12,13 @@ mod llm_settings; mod params_prompt; mod query; mod results; +mod saved_queries; mod schema; mod session; mod update; +pub use saved_queries::SavedQueryAction; + pub(crate) use session::{flush_session, schedule_session_save}; pub use update::try_promote_pending_update; @@ -158,6 +161,11 @@ pub enum Action { /// the action layer replies to the LLM with `{"error": "user /// denied access"}` so the turn keeps moving. ToolApproveDeny, + /// Saved-query overlay / picker interaction. The translation layer + /// keeps `:save` / `:load` / `:run-saved` outside this variant + /// (they're dispatched directly through `dispatch_command`) so this + /// only handles the overlay key flow. + SavedQuery(SavedQueryAction), } /// What a click or scroll-wheel was aimed at. Translated from @@ -495,6 +503,7 @@ pub fn apply(app: &mut App, action: Action) { Action::Session(s) => session::dispatch_session(app, s), Action::ToolApproveAccept => chat::on_tool_approve_accept(app), Action::ToolApproveDeny => chat::on_tool_approve_deny(app), + Action::SavedQuery(a) => saved_queries::apply(app, a), } } @@ -785,6 +794,10 @@ fn dispatch_command(app: &mut App, cmd: command::Command) { Action::Session(session::session_subcommand_to_action(sub)), ), C::Update => apply(app, Action::CheckForUpdate), + C::Save(name) => saved_queries::apply_save(app, name), + C::Load(name) => saved_queries::apply_load(app, name), + C::RunSaved(Some(name)) => saved_queries::apply_run_saved(app, name), + C::RunSaved(None) => saved_queries::open_run_picker(app), } } diff --git a/src/action/saved_queries.rs b/src/action/saved_queries.rs new file mode 100644 index 0000000..7456609 --- /dev/null +++ b/src/action/saved_queries.rs @@ -0,0 +1,220 @@ +//! `:save`, `:load`, `:run-saved` — per-connection named query store. + +use crate::app::App; +use crate::saved_queries; +use crate::state::overlay::Overlay; +use crate::state::saved_query_picker::SavedQueryPickerState; +use crate::state::status::QueryStatus; + +#[derive(Debug, Clone)] +pub enum SavedQueryAction { + /// Picker cursor — up/down step. + PickerMove(i32), + PickerTop, + PickerBottom, + /// Picker Enter — load / run the selected entry. + PickerConfirm, + /// Picker Esc. + PickerCancel, + /// `:save` overwrite prompt — Enter. + ConfirmOverwrite, + /// `:save` overwrite prompt — Esc / n. + CancelOverwrite, +} + +pub fn apply_save(app: &mut App, name: String) { + let Some(conn) = app.active_connection.clone() else { + app.status = QueryStatus::Failed { + error: "no active connection".into(), + }; + return; + }; + if let Err(err) = saved_queries::validate_name(&name) { + app.status = QueryStatus::Failed { error: err }; + return; + } + let Some(sql) = resolve_sql_to_save(app) else { + app.status = QueryStatus::Failed { + error: "no selection or statement under cursor to save".into(), + }; + return; + }; + if saved_queries::exists(&app.data_dir, &conn, &name) { + app.overlay = Some(Overlay::ConfirmSaveOverwrite { name, sql }); + return; + } + write_and_notice(app, &conn, &name, &sql); +} + +pub fn apply_load(app: &mut App, name: String) { + let Some(conn) = app.active_connection.clone() else { + app.status = QueryStatus::Failed { + error: "no active connection".into(), + }; + return; + }; + if let Err(err) = saved_queries::validate_name(&name) { + app.status = QueryStatus::Failed { error: err }; + return; + } + let sql = match saved_queries::load(&app.data_dir, &conn, &name) { + Ok(s) => s, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + app.status = QueryStatus::Failed { + error: format!("no saved query named {name:?}"), + }; + return; + } + Err(err) => { + app.status = QueryStatus::Failed { + error: format!("load {name:?} failed: {err}"), + }; + return; + } + }; + crate::state::editor::insert_text_at_cursor(&mut app.editor.state, &sql); + app.editor_dirty = true; + super::schedule_session_save(app); + app.status = QueryStatus::Notice { + msg: format!("loaded saved query {name:?}"), + }; +} + +pub fn apply_run_saved(app: &mut App, name: String) { + let Some(conn) = app.active_connection.clone() else { + app.status = QueryStatus::Failed { + error: "no active connection".into(), + }; + return; + }; + if let Err(err) = saved_queries::validate_name(&name) { + app.status = QueryStatus::Failed { error: err }; + return; + } + let sql = match saved_queries::load(&app.data_dir, &conn, &name) { + Ok(s) => s, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + app.status = QueryStatus::Failed { + error: format!("no saved query named {name:?}"), + }; + return; + } + Err(err) => { + app.status = QueryStatus::Failed { + error: format!("load {name:?} failed: {err}"), + }; + return; + } + }; + super::query::dispatch_query(app, sql); +} + +pub fn open_run_picker(app: &mut App) { + let Some(conn) = app.active_connection.clone() else { + app.status = QueryStatus::Failed { + error: "no active connection".into(), + }; + return; + }; + let entries = match saved_queries::list(&app.data_dir, &conn) { + Ok(v) => v, + Err(err) => { + app.status = QueryStatus::Failed { + error: format!("list saved queries failed: {err}"), + }; + return; + } + }; + app.overlay = Some(Overlay::SavedQueryPicker(SavedQueryPickerState::new( + entries, + ))); +} + +pub fn apply(app: &mut App, action: SavedQueryAction) { + match action { + SavedQueryAction::PickerMove(delta) => { + if let Some(Overlay::SavedQueryPicker(state)) = app.overlay.as_mut() { + state.move_selection(delta); + } + } + SavedQueryAction::PickerTop => { + if let Some(Overlay::SavedQueryPicker(state)) = app.overlay.as_mut() { + state.jump_top(); + } + } + SavedQueryAction::PickerBottom => { + if let Some(Overlay::SavedQueryPicker(state)) = app.overlay.as_mut() { + state.jump_bottom(); + } + } + SavedQueryAction::PickerCancel => { + if matches!(app.overlay, Some(Overlay::SavedQueryPicker(_))) { + app.overlay = None; + } + } + SavedQueryAction::PickerConfirm => picker_confirm(app), + SavedQueryAction::ConfirmOverwrite => confirm_overwrite(app), + SavedQueryAction::CancelOverwrite => { + if matches!(app.overlay, Some(Overlay::ConfirmSaveOverwrite { .. })) { + app.overlay = None; + app.status = QueryStatus::Notice { + msg: "save cancelled".into(), + }; + } + } + } +} + +fn picker_confirm(app: &mut App) { + let Some(Overlay::SavedQueryPicker(state)) = app.overlay.as_ref() else { + return; + }; + let Some(name) = state.selected_name().map(str::to_string) else { + // Empty list — just close the overlay. + app.overlay = None; + return; + }; + app.overlay = None; + apply_run_saved(app, name); +} + +fn confirm_overwrite(app: &mut App) { + let Some(Overlay::ConfirmSaveOverwrite { name, sql }) = app.overlay.take() else { + return; + }; + let Some(conn) = app.active_connection.clone() else { + app.status = QueryStatus::Failed { + error: "no active connection".into(), + }; + return; + }; + write_and_notice(app, &conn, &name, &sql); +} + +fn resolve_sql_to_save(app: &App) -> Option { + if let Some(text) = crate::state::editor::selection_text(&app.editor.state) { + return Some(text); + } + let range = crate::state::editor::statement_under_cursor(&app.editor.state)?; + let trimmed = range.text.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +fn write_and_notice(app: &mut App, conn: &str, name: &str, sql: &str) { + match saved_queries::save(&app.data_dir, conn, name, sql) { + Ok(()) => { + app.status = QueryStatus::Notice { + msg: format!("saved query {name:?}"), + }; + } + Err(err) => { + app.status = QueryStatus::Failed { + error: format!("save {name:?} failed: {err}"), + }; + } + } +} diff --git a/src/command.rs b/src/command.rs index ad9a3ca..63a8de5 100644 --- a/src/command.rs +++ b/src/command.rs @@ -59,6 +59,15 @@ pub enum Command { /// version → "v0.7.x is the latest" notice; network failure → /// error in the bottom bar. Update, + /// `:save ` — persist the selection (or statement under cursor) + /// as a named query scoped to the active connection. + Save(String), + /// `:load ` — insert a previously saved query at the cursor. + Load(String), + /// `:run-saved [name]` — execute a saved query. Bare form opens a + /// picker overlay; with a name, dispatches straight through the + /// existing query pipeline (placeholders prompt, etc.). + RunSaved(Option), } /// `:chat` subcommands. Bare `:chat` toggles the right panel between @@ -334,6 +343,21 @@ pub static COMMAND_TREE: &[CommandSpec] = &[ aliases: &[], children: &[], }, + CommandSpec { + name: "save", + aliases: &[], + children: &[], + }, + CommandSpec { + name: "load", + aliases: &[], + children: &[], + }, + CommandSpec { + name: "run-saved", + aliases: &[], + children: &[], + }, ]; /// Parse a single `:` line. `Ok(None)` is the empty-line case (treat @@ -365,11 +389,42 @@ pub fn parse(line: &str) -> Result, String> { "chat" => Command::Chat(parse_chat(&args)?), "session" | "sess" => Command::Session(parse_session(&args)?), "update" => Command::Update, + "save" => parse_save(&args)?, + "load" => parse_load(&args)?, + "run-saved" => parse_run_saved(&args), _ => return Err(format!("unknown command: {cmd}")), }; Ok(Some(parsed)) } +fn parse_save(args: &[&str]) -> Result { + let name = args.join(" "); + let name = name.trim(); + if name.is_empty() { + return Err("usage: :save ".to_string()); + } + Ok(Command::Save(name.to_string())) +} + +fn parse_load(args: &[&str]) -> Result { + let name = args.join(" "); + let name = name.trim(); + if name.is_empty() { + return Err("usage: :load ".to_string()); + } + Ok(Command::Load(name.to_string())) +} + +fn parse_run_saved(args: &[&str]) -> Command { + let joined = args.join(" "); + let trimmed = joined.trim(); + if trimmed.is_empty() { + Command::RunSaved(None) + } else { + Command::RunSaved(Some(trimmed.to_string())) + } +} + fn parse_format(args: &[&str]) -> Result { let scope = match args.first().copied() { None => FormatScope::Cursor, @@ -946,6 +1001,32 @@ mod tests { assert!(matches!(parse("session yikes"), Err(msg) if msg.contains("unknown"))); } + #[test] + fn save_requires_name() { + assert_eq!(parse("save daily"), Ok(Some(Command::Save("daily".into())))); + // Names with spaces survive (sanitizer maps them on save). + assert_eq!( + parse("save weekly cohort"), + Ok(Some(Command::Save("weekly cohort".into()))) + ); + assert!(matches!(parse("save"), Err(msg) if msg.contains("usage:"))); + } + + #[test] + fn load_requires_name() { + assert_eq!(parse("load daily"), Ok(Some(Command::Load("daily".into())))); + assert!(matches!(parse("load"), Err(msg) if msg.contains("usage:"))); + } + + #[test] + fn run_saved_optional_name() { + assert_eq!(parse("run-saved"), Ok(Some(Command::RunSaved(None)))); + assert_eq!( + parse("run-saved daily"), + Ok(Some(Command::RunSaved(Some("daily".into())))) + ); + } + #[test] fn chat_unknown_subcommand() { assert!(matches!( diff --git a/src/datasource/sql/mod.rs b/src/datasource/sql/mod.rs index 76614dd..67da37d 100644 --- a/src/datasource/sql/mod.rs +++ b/src/datasource/sql/mod.rs @@ -308,6 +308,60 @@ fn contains_where_keyword(sql: &str) -> bool { false } +/// Effect of executing `sql` on the connection's transaction state. +/// +/// Used by the per-driver `execute()` paths to decide whether to keep +/// the pinned session connection (we're inside an open transaction so +/// subsequent statements must hit the same backend) or release it back +/// to the pool (no transaction is open — the next query can grab any +/// pool conn). Only the first parsed statement is inspected: the worker +/// always splits multi-statement input via [`split_statements`] before +/// calling `execute`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum TxEffect { + /// `BEGIN` / `START TRANSACTION` — opens a transaction. + Begin, + /// `COMMIT` / `ROLLBACK` — closes any open transaction. + End, + /// Anything else (DDL, DML, SAVEPOINT, …) — leaves tx state where it was. + None, +} + +/// Inspect `sql` and classify its effect on transaction state. Mirrors +/// [`is_row_returning`]'s AST-first / keyword-fallback strategy so +/// dialect quirks don't slip past us. The keyword fallback is permissive +/// on the "close tx" side so an unparseable `ROLLBACK foo` still +/// releases the session — staying pinned across a half-parsed close is +/// the worse error (would burn a backend). +pub(crate) fn tx_effect(sql: &str, dialect: &dyn sqlparser::dialect::Dialect) -> TxEffect { + use sqlparser::ast::Statement; + use sqlparser::parser::Parser; + + if let Ok(stmts) = Parser::parse_sql(dialect, sql) + && let Some(first) = stmts.first() + { + return match first { + Statement::StartTransaction { .. } => TxEffect::Begin, + Statement::Commit { .. } | Statement::Rollback { .. } => TxEffect::End, + _ => TxEffect::None, + }; + } + classify_tx_keyword(sql) +} + +fn classify_tx_keyword(sql: &str) -> TxEffect { + let stripped = strip_leading_comments_and_ws(sql); + let head: String = stripped + .chars() + .take_while(|c| c.is_alphanumeric() || *c == '_') + .collect(); + match head.to_ascii_uppercase().as_str() { + "BEGIN" | "START" => TxEffect::Begin, + "COMMIT" | "END" | "ROLLBACK" => TxEffect::End, + _ => TxEffect::None, + } +} + /// Hides the password between `://user:` and `@host` so it doesn't end up in /// the log file. Other URL forms are returned untouched. pub(crate) fn redact_url(url: &str) -> String { @@ -519,6 +573,35 @@ mod tests { ); } + #[test] + fn tx_effect_recognises_begin_commit_rollback() { + let d = PostgreSqlDialect {}; + assert_eq!(tx_effect("BEGIN", &d), TxEffect::Begin); + assert_eq!(tx_effect("START TRANSACTION", &d), TxEffect::Begin); + assert_eq!(tx_effect("COMMIT", &d), TxEffect::End); + assert_eq!(tx_effect("ROLLBACK", &d), TxEffect::End); + assert_eq!(tx_effect("SELECT 1", &d), TxEffect::None); + assert_eq!(tx_effect("INSERT INTO t VALUES (1)", &d), TxEffect::None); + } + + #[test] + fn tx_effect_keyword_fallback() { + // Unparseable input — fallback still classifies the leading kw. + let d = PostgreSqlDialect {}; + assert_eq!(tx_effect("BEGIN garble!!", &d), TxEffect::Begin); + assert_eq!(tx_effect("ROLLBACK garble!!", &d), TxEffect::End); + assert_eq!(tx_effect("garble!!", &d), TxEffect::None); + } + + #[test] + fn tx_effect_handles_sqlite_savepoint_as_none() { + // SAVEPOINT inside a tx doesn't change the in-tx flag — it's + // a nested marker, not a fresh BEGIN. + let d = SQLiteDialect {}; + assert_eq!(tx_effect("SAVEPOINT sp1", &d), TxEffect::None); + assert_eq!(tx_effect("RELEASE SAVEPOINT sp1", &d), TxEffect::None); + } + #[test] fn connection_lost_classifies_transport_errors() { // Database errors (server-side, e.g. syntax) keep the conn — diff --git a/src/datasource/sql/mysql.rs b/src/datasource/sql/mysql.rs index e22bcc1..c4a6005 100644 --- a/src/datasource/sql/mysql.rs +++ b/src/datasource/sql/mysql.rs @@ -1,5 +1,6 @@ +use std::sync::Mutex as StdMutex; use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Instant; +use std::time::{Duration, Instant}; use async_trait::async_trait; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; @@ -14,27 +15,43 @@ use crate::datasource::error::{DatasourceError, DatasourceResult}; use crate::datasource::schema::{ CatalogInfo, ColumnInfo, DefaultSchema, IndexInfo, SchemaInfo, TableInfo, TableKind, }; -use crate::datasource::sql::decode_to; +use crate::datasource::sql::{TxEffect, decode_to, tx_effect}; use crate::datasource::{Column, Datasource, QueryResult, Row as CellRow}; use crate::log::Logger; const DEFAULT_POOL_SIZE: u32 = 3; +const IDLE_TIMEOUT_SECS: u64 = 60; const MARIADB_SCHEME: &str = "mariadb:"; const MYSQL_SCHEME: &str = "mysql:"; const TARGET: &str = "mysql"; +/// Pinned session connection — see `postgres::Session` for the shape; +/// here the cancel handle is the `CONNECTION_ID()` we already mirror to +/// `session_conn_id`. +struct Session { + conn: PoolConnection, +} + pub struct MysqlDatasource { - pool: MySqlPool, + pool: StdMutex, + url: String, log: Logger, // CONNECTION_ID() of the pinned session connection, or 0 when no // session is currently held. Recorded once on acquire and kept // across executes so `cancel()` can `KILL QUERY ` even when // the spawn_query task is mid-await. session_conn_id: AtomicU64, - // Pinned connection across `execute()` calls so BEGIN / COMMIT / - // ROLLBACK survive between statements. Introspection and cancel - // talk to the pool directly. - session: Mutex>>, + /// Pinned connection held only while a transaction is open. + /// Introspection and cancel always talk to the pool. + session: Mutex>, +} + +fn build_pool(url: &str) -> Result { + MySqlPoolOptions::new() + .max_connections(DEFAULT_POOL_SIZE) + .min_connections(0) + .idle_timeout(Some(Duration::from_secs(IDLE_TIMEOUT_SECS))) + .connect_lazy(url) } impl MysqlDatasource { @@ -50,22 +67,29 @@ impl MysqlDatasource { TARGET, format!("connecting to {}", super::redact_url(&normalized)), ); - let pool = MySqlPoolOptions::new() - .max_connections(DEFAULT_POOL_SIZE) - .connect(&normalized) - .await - .map_err(|e| { - log.error(TARGET, format!("connect failed: {e}")); - DatasourceError::Connect(e.to_string()) - })?; + let pool = build_pool(&normalized).map_err(|e| { + log.error(TARGET, format!("pool build failed: {e}")); + DatasourceError::Connect(e.to_string()) + })?; + // Upfront ping so bad URLs surface here, not on the first query. + let verify_conn = pool.acquire().await.map_err(|e| { + log.error(TARGET, format!("connect verify failed: {e}")); + DatasourceError::Connect(e.to_string()) + })?; + drop(verify_conn); log.info(TARGET, "connected"); Ok(Self { - pool, + pool: StdMutex::new(pool), + url: normalized, log, session_conn_id: AtomicU64::new(0), session: Mutex::new(None), }) } + + fn pool(&self) -> MySqlPool { + self.pool.lock().expect("pool mutex poisoned").clone() + } } #[async_trait] @@ -75,6 +99,7 @@ impl Datasource for MysqlDatasource { // database") comes from the connection URL via `DATABASE()`. If the // user connected without selecting a database, we return an empty // string so the caller can decide what to do (skip prime, etc.). + let pool = self.pool(); let row = sqlx::query( "SELECT \ COALESCE(\ @@ -83,7 +108,7 @@ impl Datasource for MysqlDatasource { ) AS catalog, \ COALESCE(DATABASE(), '') AS schema", ) - .fetch_one(&self.pool) + .fetch_one(&pool) .await .map_err(introspect_err)?; let catalog = try_string(&row, "catalog").unwrap_or_else(|| "def".to_string()); @@ -94,9 +119,10 @@ impl Datasource for MysqlDatasource { async fn introspect_catalogs(&self) -> DatasourceResult> { // MySQL exposes a single static catalog (`def`); we read it from // information_schema rather than hard-coding it. + let pool = self.pool(); let rows = sqlx::query("SELECT DISTINCT catalog_name AS name FROM information_schema.schemata") - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -107,6 +133,7 @@ impl Datasource for MysqlDatasource { } async fn introspect_schemas(&self, catalog: &str) -> DatasourceResult> { + let pool = self.pool(); let rows = sqlx::query( "SELECT schema_name AS name FROM information_schema.schemata \ WHERE catalog_name = ? \ @@ -114,7 +141,7 @@ impl Datasource for MysqlDatasource { ORDER BY schema_name", ) .bind(catalog) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -129,6 +156,7 @@ impl Datasource for MysqlDatasource { catalog: &str, schema: &str, ) -> DatasourceResult> { + let pool = self.pool(); let rows = sqlx::query( "SELECT table_name AS name, table_type AS kind \ FROM information_schema.tables \ @@ -137,7 +165,7 @@ impl Datasource for MysqlDatasource { ) .bind(catalog) .bind(schema) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -162,6 +190,7 @@ impl Datasource for MysqlDatasource { ) -> DatasourceResult> { // `column_type` carries the full declared type (e.g. `int(11) unsigned`), // which is more useful for display than the normalised `data_type`. + let pool = self.pool(); let rows = sqlx::query( "SELECT column_name AS name, column_type AS type_name, is_nullable \ FROM information_schema.columns \ @@ -171,7 +200,7 @@ impl Datasource for MysqlDatasource { .bind(catalog) .bind(schema) .bind(table) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -203,6 +232,7 @@ impl Datasource for MysqlDatasource { // information_schema.statistics has one row per index column; collapse // by index_name and take the lowest non_unique value (0 wins, meaning // unique). + let pool = self.pool(); let rows = sqlx::query( "SELECT index_name AS name, MIN(non_unique) AS non_unique \ FROM information_schema.statistics \ @@ -212,7 +242,7 @@ impl Datasource for MysqlDatasource { ) .bind(schema) .bind(table) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -236,8 +266,10 @@ impl Datasource for MysqlDatasource { let started = Instant::now(); let mut guard = self.session.lock().await; - if guard.is_none() { - let mut conn = self.pool.acquire().await.map_err(|e| { + let was_in_tx = guard.is_some(); + if !was_in_tx { + let pool = self.pool(); + let mut conn = pool.acquire().await.map_err(|e| { self.log.error(TARGET, format!("acquire failed: {e}")); execute_err(e) })?; @@ -250,13 +282,13 @@ impl Datasource for MysqlDatasource { execute_err(e) })?; self.session_conn_id.store(conn_id, Ordering::SeqCst); - *guard = Some(conn); + *guard = Some(Session { conn }); } - let conn = guard.as_mut().expect("session conn populated above"); + let session = guard.as_mut().expect("session populated above"); let outcome: Result = if super::is_row_returning(statement, &sqlparser::dialect::MySqlDialect {}) { - match sqlx::query(statement).fetch_all(&mut **conn).await { + match sqlx::query(statement).fetch_all(&mut *session.conn).await { Ok(rows) => { let elapsed = started.elapsed(); let columns = build_columns(&rows); @@ -275,7 +307,7 @@ impl Datasource for MysqlDatasource { Err(e) => Err(e), } } else { - match sqlx::query(statement).execute(&mut **conn).await { + match sqlx::query(statement).execute(&mut *session.conn).await { Ok(outcome) => { let elapsed = started.elapsed(); Ok(QueryResult { @@ -290,6 +322,14 @@ impl Datasource for MysqlDatasource { } }; + let effect = tx_effect(statement, &sqlparser::dialect::MySqlDialect {}); + let new_in_tx = match (was_in_tx, effect) { + (true, TxEffect::End) => false, + (true, _) => true, + (false, TxEffect::Begin) => true, + (false, _) => false, + }; + match outcome { Ok(r) => { self.log.info( @@ -299,6 +339,10 @@ impl Datasource for MysqlDatasource { None => format!("execute ok: {} rows in {:?}", r.rows.len(), r.elapsed), }, ); + if !new_in_tx { + *guard = None; + self.session_conn_id.store(0, Ordering::SeqCst); + } Ok(r) } Err(e) => { @@ -308,6 +352,9 @@ impl Datasource for MysqlDatasource { self.session_conn_id.store(0, Ordering::SeqCst); self.log .warn(TARGET, "session conn dropped after connection loss"); + } else if !was_in_tx { + *guard = None; + self.session_conn_id.store(0, Ordering::SeqCst); } Err(execute_err(e)) } @@ -326,7 +373,8 @@ impl Datasource for MysqlDatasource { // busy session. let sql = format!("KILL QUERY {conn_id}"); self.log.info(TARGET, format!("cancel: {sql}")); - sqlx::query(&sql).execute(&self.pool).await.map_err(|e| { + let pool = self.pool(); + sqlx::query(&sql).execute(&pool).await.map_err(|e| { self.log.warn(TARGET, format!("cancel failed: {e}")); execute_err(e) })?; @@ -334,18 +382,35 @@ impl Datasource for MysqlDatasource { } async fn reset_session(&self) -> DatasourceResult<()> { + // Drop the pinned session conn first (with an explicit ROLLBACK + // so the backend isn't left holding an open tx). let mut guard = self.session.lock().await; - if let Some(mut conn) = guard.take() { - // Drop any open transaction before returning the conn to - // the pool — a stale BEGIN would otherwise be inherited by - // the next checkout. Logged-and-swallowed on failure. - if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await { + if let Some(mut session) = guard.take() { + if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *session.conn).await { self.log.warn(TARGET, format!("reset rollback: {e}")); } - drop(conn); + drop(session); self.log.info(TARGET, "session reset"); } self.session_conn_id.store(0, Ordering::SeqCst); + drop(guard); + + // Swap in a fresh lazy pool and close the old one — that forces + // every existing backend connection this datasource was holding + // (idle or otherwise) to disconnect. + let new_pool = match build_pool(&self.url) { + Ok(p) => p, + Err(e) => { + self.log.error(TARGET, format!("pool rebuild failed: {e}")); + return Err(DatasourceError::Execute(e.to_string())); + } + }; + let old_pool = { + let mut pool_guard = self.pool.lock().expect("pool mutex poisoned"); + std::mem::replace(&mut *pool_guard, new_pool) + }; + old_pool.close().await; + self.log.info(TARGET, "pool drained"); Ok(()) } } diff --git a/src/datasource/sql/postgres.rs b/src/datasource/sql/postgres.rs index 763c342..3ace783 100644 --- a/src/datasource/sql/postgres.rs +++ b/src/datasource/sql/postgres.rs @@ -1,5 +1,6 @@ +use std::sync::Mutex as StdMutex; use std::sync::atomic::{AtomicI32, Ordering}; -use std::time::Instant; +use std::time::{Duration, Instant}; use async_trait::async_trait; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; @@ -15,15 +16,34 @@ use crate::datasource::error::{DatasourceError, DatasourceResult}; use crate::datasource::schema::{ CatalogInfo, ColumnInfo, DefaultSchema, IndexInfo, SchemaInfo, TableInfo, TableKind, }; -use crate::datasource::sql::decode_to; +use crate::datasource::sql::{TxEffect, decode_to, tx_effect}; use crate::datasource::{Column, Datasource, QueryResult, Row as CellRow}; use crate::log::Logger; const DEFAULT_POOL_SIZE: u32 = 3; +const IDLE_TIMEOUT_SECS: u64 = 60; const TARGET: &str = "postgres"; +/// Pinned session connection held only while the user is inside an +/// open transaction. Released back to the pool the moment the +/// transaction closes (and reacquired on the next statement). Outside +/// a transaction the datasource owns no connection — the pool can +/// shrink to zero and idle conns get reaped after `IDLE_TIMEOUT_SECS`. +/// The backend PID is mirrored to `PostgresDatasource::session_pid` so +/// `cancel()` can read it without locking the session. +struct Session { + conn: PoolConnection, +} + pub struct PostgresDatasource { - pool: PgPool, + /// The pool itself behind a sync mutex so `:reset` can swap it + /// (close-and-rebuild) without making every `execute` / introspect + /// caller take an async lock. `PgPool` is internally `Arc<...>` so + /// cloning out is O(1); we never hold the guard across `.await`. + pool: StdMutex, + /// Connection URL kept around so `:reset` can rebuild the pool from + /// scratch. + url: String, log: Logger, // Backend PID of the pinned session connection, or 0 when no session // is currently held. Recorded once when the session is acquired and @@ -32,33 +52,54 @@ pub struct PostgresDatasource { // in a transaction (callers can hit `pg_cancel_backend` to break out // of a stuck wait without first finding a running statement). session_pid: AtomicI32, - // Pinned connection held across `execute()` calls so BEGIN / - // COMMIT / ROLLBACK work the way the user expects. Lazily acquired - // and dropped by `reset_session()`. Introspection and `cancel()` - // still talk to the pool, never to this connection — they need to - // make progress while the session is busy. - session: Mutex>>, + /// Pinned connection held only while a transaction is open. Outside + /// a tx the slot is `None` and every `execute()` lands on a fresh + /// pool checkout that's released as soon as the statement finishes. + /// Introspection and `cancel()` always talk to the pool, never to + /// this connection — they need to make progress while the session + /// is busy. + session: Mutex>, +} + +fn build_pool(url: &str) -> Result { + PgPoolOptions::new() + .max_connections(DEFAULT_POOL_SIZE) + .min_connections(0) + .idle_timeout(Some(Duration::from_secs(IDLE_TIMEOUT_SECS))) + .connect_lazy(url) } impl PostgresDatasource { pub async fn connect(url: &str, log: Logger) -> DatasourceResult { log.info(TARGET, format!("connecting to {}", super::redact_url(url))); - let pool = PgPoolOptions::new() - .max_connections(DEFAULT_POOL_SIZE) - .connect(url) - .await - .map_err(|e| { - log.error(TARGET, format!("connect failed: {e}")); - DatasourceError::Connect(e.to_string()) - })?; + let pool = build_pool(url).map_err(|e| { + log.error(TARGET, format!("pool build failed: {e}")); + DatasourceError::Connect(e.to_string()) + })?; + // Verify connectivity up-front by checking out one conn and + // releasing it back to the pool. With `min_connections(0)` + + // `idle_timeout`, the conn drops shortly after; bad URLs still + // fail here instead of waiting until the first user query. + let verify_conn = pool.acquire().await.map_err(|e| { + log.error(TARGET, format!("connect verify failed: {e}")); + DatasourceError::Connect(e.to_string()) + })?; + drop(verify_conn); log.info(TARGET, "connected"); Ok(Self { - pool, + pool: StdMutex::new(pool), + url: url.to_string(), log, session_pid: AtomicI32::new(0), session: Mutex::new(None), }) } + + /// Cheap clone of the current pool handle (`PgPool` is `Arc`-backed). + /// Never holds the std mutex across an `.await`. + fn pool(&self) -> PgPool { + self.pool.lock().expect("pool mutex poisoned").clone() + } } #[async_trait] @@ -69,11 +110,12 @@ impl Datasource for PostgresDatasource { // unqualified objects land" — what users mean by "default schema". // `public` is the fallback if the search_path is empty (rare but // possible after `SET search_path = ''`). + let pool = self.pool(); let row = sqlx::query( "SELECT current_database() AS catalog, \ COALESCE((current_schemas(false))[1], 'public') AS schema", ) - .fetch_one(&self.pool) + .fetch_one(&pool) .await .map_err(introspect_err)?; let catalog: String = row.try_get("catalog").map_err(introspect_err)?; @@ -84,8 +126,9 @@ impl Datasource for PostgresDatasource { async fn introspect_catalogs(&self) -> DatasourceResult> { // A Postgres connection is bound to a single database; expose it as the // sole catalog so the tree mirrors the rest of the drivers. + let pool = self.pool(); let row = sqlx::query("SELECT current_database() AS name") - .fetch_one(&self.pool) + .fetch_one(&pool) .await .map_err(introspect_err)?; let name: String = row.try_get("name").map_err(introspect_err)?; @@ -93,6 +136,7 @@ impl Datasource for PostgresDatasource { } async fn introspect_schemas(&self, _catalog: &str) -> DatasourceResult> { + let pool = self.pool(); let rows = sqlx::query( "SELECT nspname AS name FROM pg_namespace \ WHERE nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \ @@ -100,7 +144,7 @@ impl Datasource for PostgresDatasource { AND nspname NOT LIKE 'pg_toast_temp_%' \ ORDER BY nspname", ) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -115,6 +159,7 @@ impl Datasource for PostgresDatasource { catalog: &str, schema: &str, ) -> DatasourceResult> { + let pool = self.pool(); let rows = sqlx::query( "SELECT table_name AS name, table_type AS kind \ FROM information_schema.tables \ @@ -123,7 +168,7 @@ impl Datasource for PostgresDatasource { ) .bind(catalog) .bind(schema) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -146,6 +191,7 @@ impl Datasource for PostgresDatasource { schema: &str, table: &str, ) -> DatasourceResult> { + let pool = self.pool(); let rows = sqlx::query( "SELECT column_name AS name, data_type AS type_name, is_nullable \ FROM information_schema.columns \ @@ -155,7 +201,7 @@ impl Datasource for PostgresDatasource { .bind(catalog) .bind(schema) .bind(table) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -186,6 +232,7 @@ impl Datasource for PostgresDatasource { ) -> DatasourceResult> { // pg_indexes doesn't expose `indisunique`, so we walk pg_class/pg_index // directly to get the uniqueness flag in a single round-trip. + let pool = self.pool(); let rows = sqlx::query( "SELECT i.relname AS name, ix.indisunique AS is_unique \ FROM pg_class i \ @@ -197,7 +244,7 @@ impl Datasource for PostgresDatasource { ) .bind(schema) .bind(table) - .fetch_all(&self.pool) + .fetch_all(&pool) .await .map_err(introspect_err)?; Ok(rows @@ -218,8 +265,10 @@ impl Datasource for PostgresDatasource { let started = Instant::now(); let mut guard = self.session.lock().await; - if guard.is_none() { - let mut conn = self.pool.acquire().await.map_err(|e| { + let was_in_tx = guard.is_some(); + if !was_in_tx { + let pool = self.pool(); + let mut conn = pool.acquire().await.map_err(|e| { self.log.error(TARGET, format!("acquire failed: {e}")); execute_err(e) })?; @@ -232,13 +281,13 @@ impl Datasource for PostgresDatasource { execute_err(e) })?; self.session_pid.store(pid, Ordering::SeqCst); - *guard = Some(conn); + *guard = Some(Session { conn }); } - let conn = guard.as_mut().expect("session conn populated above"); + let session = guard.as_mut().expect("session populated above"); let outcome: Result = if super::is_row_returning(statement, &sqlparser::dialect::PostgreSqlDialect {}) { - match sqlx::query(statement).fetch_all(&mut **conn).await { + match sqlx::query(statement).fetch_all(&mut *session.conn).await { Ok(rows) => { let elapsed = started.elapsed(); let columns = build_columns(&rows); @@ -257,7 +306,7 @@ impl Datasource for PostgresDatasource { Err(e) => Err(e), } } else { - match sqlx::query(statement).execute(&mut **conn).await { + match sqlx::query(statement).execute(&mut *session.conn).await { Ok(outcome) => { let elapsed = started.elapsed(); Ok(QueryResult { @@ -272,6 +321,18 @@ impl Datasource for PostgresDatasource { } }; + // Compute the resulting in-tx flag and decide whether to keep + // the pinned conn. Keeping it across an open tx is mandatory; + // releasing it when no tx is open is what lets the pool shrink + // back to zero and reap idle conns after `IDLE_TIMEOUT_SECS`. + let effect = tx_effect(statement, &sqlparser::dialect::PostgreSqlDialect {}); + let new_in_tx = match (was_in_tx, effect) { + (true, TxEffect::End) => false, + (true, _) => true, + (false, TxEffect::Begin) => true, + (false, _) => false, + }; + match outcome { Ok(r) => { self.log.info( @@ -281,6 +342,10 @@ impl Datasource for PostgresDatasource { None => format!("execute ok: {} rows in {:?}", r.rows.len(), r.elapsed), }, ); + if !new_in_tx { + *guard = None; + self.session_pid.store(0, Ordering::SeqCst); + } Ok(r) } Err(e) => { @@ -290,6 +355,12 @@ impl Datasource for PostgresDatasource { self.session_pid.store(0, Ordering::SeqCst); self.log .warn(TARGET, "session conn dropped after connection loss"); + } else if !was_in_tx { + // No tx to preserve — failure means the session we + // just acquired serves no purpose. Release it so + // the pool can reap the conn. + *guard = None; + self.session_pid.store(0, Ordering::SeqCst); } Err(execute_err(e)) } @@ -307,9 +378,10 @@ impl Datasource for PostgresDatasource { // A separate pool connection is used so the cancel doesn't wait on // the busy backend. `pg_cancel_backend` returns false if the target // PID is no longer running anything — best-effort by design. + let pool = self.pool(); let signaled: bool = sqlx::query_scalar("SELECT pg_cancel_backend($1)") .bind(pid) - .fetch_one(&self.pool) + .fetch_one(&pool) .await .map_err(|e| { self.log.warn(TARGET, format!("cancel failed: {e}")); @@ -323,22 +395,36 @@ impl Datasource for PostgresDatasource { } async fn reset_session(&self) -> DatasourceResult<()> { + // Drop the pinned session conn first (with an explicit ROLLBACK + // so the backend isn't left holding an aborted tx). let mut guard = self.session.lock().await; - if let Some(mut conn) = guard.take() { - // Roll back any in-progress (or aborted) transaction - // before the conn returns to the pool. Postgres leaves a - // tx aborted after a cancel, and a pooled hand-off without - // ROLLBACK would surface "current transaction is aborted" - // on the next checkout from a different caller. Failure - // is logged but swallowed — the session is being discarded - // and a fresh acquire will paper over a poisoned conn. - if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await { + if let Some(mut session) = guard.take() { + if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *session.conn).await { self.log.warn(TARGET, format!("reset rollback: {e}")); } - drop(conn); + drop(session); self.log.info(TARGET, "session reset"); } self.session_pid.store(0, Ordering::SeqCst); + drop(guard); + + // Force-close all pool connections by swapping in a fresh lazy + // pool. Closing the old pool waits for in-flight queries to + // return their conns, then disconnects them — equivalent to a + // hard reset of every backend this datasource was talking to. + let new_pool = match build_pool(&self.url) { + Ok(p) => p, + Err(e) => { + self.log.error(TARGET, format!("pool rebuild failed: {e}")); + return Err(DatasourceError::Execute(e.to_string())); + } + }; + let old_pool = { + let mut pool_guard = self.pool.lock().expect("pool mutex poisoned"); + std::mem::replace(&mut *pool_guard, new_pool) + }; + old_pool.close().await; + self.log.info(TARGET, "pool drained"); Ok(()) } } diff --git a/src/datasource/sql/sqlite.rs b/src/datasource/sql/sqlite.rs index 753f101..b46ae54 100644 --- a/src/datasource/sql/sqlite.rs +++ b/src/datasource/sql/sqlite.rs @@ -11,7 +11,7 @@ use crate::datasource::error::{DatasourceError, DatasourceResult}; use crate::datasource::schema::{ CatalogInfo, ColumnInfo, DefaultSchema, IndexInfo, SchemaInfo, TableInfo, TableKind, }; -use crate::datasource::sql::decode_to; +use crate::datasource::sql::{TxEffect, decode_to, tx_effect}; use crate::datasource::{Column, Datasource, QueryResult, Row as CellRow}; use crate::log::Logger; @@ -163,6 +163,14 @@ impl Datasource for SqliteDatasource { .collect()) } + // SQLite intentionally keeps the pool as-is (no idle-timeout) — for + // `sqlite::memory:` the database lives in the connection, so reaping + // idle conns would drop the data. The on-network drivers (postgres, + // mysql) configure `min_connections(0)+idle_timeout`; sqlite holds + // its conns indefinitely. Within that constraint we still release + // the pinned session conn after any non-transactional statement so + // file-backed sqlite users at least don't sit on a single pinned + // checkout forever. async fn execute(&self, statement: &str) -> DatasourceResult { self.log.info( TARGET, @@ -171,7 +179,8 @@ impl Datasource for SqliteDatasource { let started = Instant::now(); let mut guard = self.session.lock().await; - if guard.is_none() { + let was_in_tx = guard.is_some(); + if !was_in_tx { let conn = self.pool.acquire().await.map_err(|e| { self.log.error(TARGET, format!("acquire failed: {e}")); execute_err(e) @@ -216,6 +225,14 @@ impl Datasource for SqliteDatasource { } }; + let effect = tx_effect(statement, &sqlparser::dialect::SQLiteDialect {}); + let new_in_tx = match (was_in_tx, effect) { + (true, TxEffect::End) => false, + (true, _) => true, + (false, TxEffect::Begin) => true, + (false, _) => false, + }; + match outcome { Ok(r) => { self.log.info( @@ -225,6 +242,9 @@ impl Datasource for SqliteDatasource { None => format!("execute ok: {} rows in {:?}", r.rows.len(), r.elapsed), }, ); + if !new_in_tx { + *guard = None; + } Ok(r) } Err(e) => { @@ -237,6 +257,10 @@ impl Datasource for SqliteDatasource { *guard = None; self.log .warn(TARGET, "session conn dropped after connection loss"); + } else if !was_in_tx { + // Nothing to preserve — release so the pool gets the + // conn back instead of pinning it across a failure. + *guard = None; } Err(execute_err(e)) } diff --git a/src/event.rs b/src/event.rs index 5aa6633..25058ae 100644 --- a/src/event.rs +++ b/src/event.rs @@ -8,7 +8,7 @@ use ratatui_textarea::Input; use crate::action::{ Action, AuthAction, ChatAction, CommandAction, CompletionAction, ConnFormAction, ConnListAction, HelpAxis, HelpScrollDelta, LlmSettingsAction, MouseTarget, ParamsPromptAction, - ResultColumnAction, ResultNavAction, SchemaAction, + ResultColumnAction, ResultNavAction, SavedQueryAction, SchemaAction, }; use crate::app::App; use crate::export::ExportFormat; @@ -92,6 +92,8 @@ fn translate_key(app: &App, key: KeyEvent, raw: CtEvent) -> Option { Overlay::UpdateAvailable { .. } => translate_update_key(key), Overlay::ConfirmToolUse { .. } => translate_tool_confirm_key(key), Overlay::ParamsPrompt(_) => translate_params_prompt_key(key), + Overlay::ConfirmSaveOverwrite { .. } => translate_save_overwrite_key(key), + Overlay::SavedQueryPicker(_) => translate_saved_query_picker_key(key), }; } match &app.screen { @@ -318,6 +320,40 @@ fn translate_tool_confirm_key(key: KeyEvent) -> Option { } } +/// `:save` overwrite prompt: Enter / `y` accept, Esc / `n` cancel. +fn translate_save_overwrite_key(key: KeyEvent) -> Option { + match key.code { + KeyCode::Char('y') | KeyCode::Char('Y') | KeyCode::Enter => { + Some(Action::SavedQuery(SavedQueryAction::ConfirmOverwrite)) + } + KeyCode::Char('n') | KeyCode::Char('N') | KeyCode::Esc => { + Some(Action::SavedQuery(SavedQueryAction::CancelOverwrite)) + } + _ => None, + } +} + +/// Saved-query picker: vim-style nav, Enter to confirm, Esc/q to dismiss. +fn translate_saved_query_picker_key(key: KeyEvent) -> Option { + use SavedQueryAction as A; + let action = match (key.code, key.modifiers) { + (KeyCode::Char('j') | KeyCode::Down, _) => A::PickerMove(1), + (KeyCode::Char('k') | KeyCode::Up, _) => A::PickerMove(-1), + (KeyCode::Char('n'), m) if m.contains(KeyModifiers::CONTROL) => A::PickerMove(1), + (KeyCode::Char('p'), m) if m.contains(KeyModifiers::CONTROL) => A::PickerMove(-1), + (KeyCode::Char('g'), _) => A::PickerTop, + (KeyCode::Char('G'), _) => A::PickerBottom, + (KeyCode::PageDown, _) => A::PickerMove(10), + (KeyCode::PageUp, _) => A::PickerMove(-10), + (KeyCode::Home, _) => A::PickerTop, + (KeyCode::End, _) => A::PickerBottom, + (KeyCode::Enter, _) => A::PickerConfirm, + (KeyCode::Esc | KeyCode::Char('q'), _) => A::PickerCancel, + _ => return None, + }; + Some(Action::SavedQuery(action)) +} + fn panic_quit(key: KeyEvent) -> Option { // Plain Ctrl+C only — `Ctrl+Shift+C` and `Cmd+C` are clipboard shortcuts // and must never accidentally quit. diff --git a/src/main.rs b/src/main.rs index 9626c84..983230a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ mod keybindings; mod llm; mod log; mod param_history; +mod saved_queries; mod session; mod sql_infer; mod sql_quote; diff --git a/src/saved_queries.rs b/src/saved_queries.rs new file mode 100644 index 0000000..58fac8e --- /dev/null +++ b/src/saved_queries.rs @@ -0,0 +1,208 @@ +//! Per-connection named query storage. +//! +//! Each saved query is one `.sql` file at +//! `/queries//.sql`. The layout +//! mirrors `src/session.rs` so the file tree stays predictable and a +//! human can grep / edit queries externally if they want. + +use std::ffi::OsStr; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +const QUERIES_DIR: &str = "queries"; + +/// Directory holding one connection's saved queries. Caller is +/// responsible for creating it before writing. +pub fn dir_for(data_dir: &Path, connection_name: &str) -> PathBuf { + data_dir.join(QUERIES_DIR).join(sanitize(connection_name)) +} + +/// Path to a single saved query file. Creates no directories. +pub fn path_for(data_dir: &Path, connection_name: &str, name: &str) -> PathBuf { + dir_for(data_dir, connection_name).join(format!("{}.sql", sanitize(name))) +} + +/// Validate a user-supplied query name before persistence. Catches the +/// obviously-bad shapes (empty, `..`, path separators) so the on-disk +/// sanitizer never has to silently collapse them into something that +/// could collide with an unrelated entry. +pub fn validate_name(name: &str) -> Result<(), String> { + let trimmed = name.trim(); + if trimmed.is_empty() { + return Err("query name is empty".to_string()); + } + if trimmed.contains('/') || trimmed.contains('\\') { + return Err("query name may not contain path separators".to_string()); + } + if trimmed == "." || trimmed == ".." { + return Err("query name may not be '.' or '..'".to_string()); + } + Ok(()) +} + +/// Write `sql` to the file for `(connection, name)`, creating parents. +pub fn save(data_dir: &Path, connection_name: &str, name: &str, sql: &str) -> io::Result<()> { + let path = path_for(data_dir, connection_name, name); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(&path, sql) +} + +/// Read the saved query body. `NotFound` propagates so callers can +/// surface a clear "no saved query named X" message. +pub fn load(data_dir: &Path, connection_name: &str, name: &str) -> io::Result { + let path = path_for(data_dir, connection_name, name); + fs::read_to_string(path) +} + +/// True iff the file exists on disk. +pub fn exists(data_dir: &Path, connection_name: &str, name: &str) -> bool { + path_for(data_dir, connection_name, name).is_file() +} + +/// Alphabetically sorted list of saved query names for `connection`. +/// Missing directory → empty vec (a fresh connection just has nothing +/// saved yet). +pub fn list(data_dir: &Path, connection_name: &str) -> io::Result> { + let dir = dir_for(data_dir, connection_name); + let entries = match fs::read_dir(&dir) { + Ok(it) => it, + Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(Vec::new()), + Err(err) => return Err(err), + }; + let mut names: Vec = entries + .flatten() + .filter_map(|e| name_from_filename(&e.file_name())) + .collect(); + names.sort_unstable(); + names.dedup(); + Ok(names) +} + +/// Best-effort delete — idempotent on a missing file. Currently +/// unused; kept for future `:delete-saved` symmetry with `session::delete`. +#[allow(dead_code)] +pub fn delete(data_dir: &Path, connection_name: &str, name: &str) -> io::Result<()> { + let path = path_for(data_dir, connection_name, name); + match fs::remove_file(&path) { + Ok(()) => Ok(()), + Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()), + Err(err) => Err(err), + } +} + +fn name_from_filename(name: &OsStr) -> Option { + let s = name.to_str()?; + let stem = s.strip_suffix(".sql")?; + if stem.is_empty() { + return None; + } + Some(stem.to_string()) +} + +/// Same rule as `session::sanitize` — keep alnum / `_-.`, replace the +/// rest with `_`, prefix `_` if the result would be `.` / `..` / empty. +fn sanitize(name: &str) -> String { + let mut out = String::with_capacity(name.len()); + for ch in name.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' || ch == '.' { + out.push(ch); + } else { + out.push('_'); + } + } + if out.is_empty() || out == "." || out == ".." { + out.insert(0, '_'); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tempdir() -> PathBuf { + let base = std::env::temp_dir().join(format!( + "rowdy-saved-queries-{}-{}", + std::process::id(), + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&base).unwrap(); + base + } + + #[test] + fn round_trip_save_load() { + let dir = tempdir(); + save(&dir, "conn", "daily", "SELECT 1;\n").unwrap(); + assert_eq!(load(&dir, "conn", "daily").unwrap(), "SELECT 1;\n"); + } + + #[test] + fn exists_reports_disk_state() { + let dir = tempdir(); + assert!(!exists(&dir, "conn", "x")); + save(&dir, "conn", "x", "SELECT 1;").unwrap(); + assert!(exists(&dir, "conn", "x")); + } + + #[test] + fn list_returns_sorted_names() { + let dir = tempdir(); + save(&dir, "c", "beta", "1").unwrap(); + save(&dir, "c", "alpha", "1").unwrap(); + save(&dir, "c", "gamma", "1").unwrap(); + // Stray non-sql file is ignored. + let stray = dir_for(&dir, "c").join("notes.txt"); + fs::write(&stray, "hi").unwrap(); + assert_eq!(list(&dir, "c").unwrap(), vec!["alpha", "beta", "gamma"]); + } + + #[test] + fn list_returns_empty_when_dir_missing() { + let dir = tempdir(); + assert!(list(&dir, "fresh").unwrap().is_empty()); + } + + #[test] + fn per_connection_isolation() { + let dir = tempdir(); + save(&dir, "prod", "same", "SELECT 1;").unwrap(); + save(&dir, "stage", "same", "SELECT 2;").unwrap(); + assert_eq!(load(&dir, "prod", "same").unwrap(), "SELECT 1;"); + assert_eq!(load(&dir, "stage", "same").unwrap(), "SELECT 2;"); + } + + #[test] + fn delete_is_idempotent() { + let dir = tempdir(); + delete(&dir, "c", "missing").unwrap(); + save(&dir, "c", "x", "1").unwrap(); + delete(&dir, "c", "x").unwrap(); + delete(&dir, "c", "x").unwrap(); + assert!(!exists(&dir, "c", "x")); + } + + #[test] + fn sanitize_replaces_unsafe_chars() { + // Names with spaces / slashes are caught up-front by validate_name, + // but the on-disk sanitizer must still produce a safe path so any + // bypass (e.g. connection names supplied externally) can't escape. + assert_eq!(sanitize("a b/c"), "a_b_c"); + assert_eq!(sanitize(".."), "_.."); + assert_eq!(sanitize(""), "_"); + } + + #[test] + fn validate_name_rejects_bad_shapes() { + assert!(validate_name("").is_err()); + assert!(validate_name(" ").is_err()); + assert!(validate_name("a/b").is_err()); + assert!(validate_name("a\\b").is_err()); + assert!(validate_name(".").is_err()); + assert!(validate_name("..").is_err()); + assert!(validate_name("ok-name_1.2").is_ok()); + } +} diff --git a/src/state/editor.rs b/src/state/editor.rs index e1a4b3f..2b78333 100644 --- a/src/state/editor.rs +++ b/src/state/editor.rs @@ -221,6 +221,33 @@ pub fn replace_selection_text(state: &mut EditorState, replacement: &str) -> boo true } +/// Insert `text` at the current cursor position. Splits the buffer at +/// the cursor, splices in `text`, and reparses through `Lines::from` +/// so embedded newlines become real rows. The cursor lands at the +/// end of the inserted text. Selection is dropped; mode is left at +/// Normal so the user can immediately keep navigating. +/// +/// edtui's undo stack only tracks edits driven through its own action +/// handler; this primitive mutates `lines` directly, so a single `u` +/// won't unwind a `:load`. Same caveat as `replace_buffer_text`. +pub fn insert_text_at_cursor(state: &mut EditorState, text: &str) { + let chars: Vec = state.lines.flatten(&Some('\n')); + let cursor_off = cursor_to_offset(state).min(chars.len()); + + let mut next = String::with_capacity(chars.len() + text.len()); + next.extend(chars[..cursor_off].iter()); + next.push_str(text); + next.extend(chars[cursor_off..].iter()); + + state.lines = Lines::from(next.as_str()); + state.selection = None; + state.mode = EditorMode::Normal; + + let new_chars: Vec = state.lines.flatten(&Some('\n')); + let after_off = (cursor_off + text.chars().count()).min(new_chars.len()); + state.cursor = clamp_index(&state.lines, offset_to_index(&new_chars, after_off)); +} + pub fn cursor_to_offset(state: &EditorState) -> usize { let mut offset = 0; for row in 0..state.cursor.row { @@ -544,6 +571,27 @@ mod tests { assert_eq!(cur.statement, "SELECT * FROM t WHERE x = ';' AND y = 1"); } + #[test] + fn insert_text_at_cursor_splices_inline() { + let mut state = EditorState::new(Lines::from("AB")); + state.cursor = Index2::new(0, 1); + insert_text_at_cursor(&mut state, "x"); + assert_eq!(flatten(&state), "AxB"); + // Cursor sits at the end of the inserted text. + assert_eq!(state.cursor, Index2::new(0, 2)); + } + + #[test] + fn insert_text_at_cursor_handles_multiline_payload() { + let mut state = EditorState::new(Lines::from("ab")); + state.cursor = Index2::new(0, 1); + insert_text_at_cursor(&mut state, "X\nY"); + // a + "X\nY" + b → "aX\nYb" + assert_eq!(flatten(&state), "aX\nYb"); + // Cursor lands right after the inserted text (row 1, col 1 = after Y). + assert_eq!(state.cursor, Index2::new(1, 1)); + } + #[test] fn clamp_index_keeps_position_inside_buffer() { let lines = Lines::from("ab\ncde"); diff --git a/src/state/mod.rs b/src/state/mod.rs index 1d4cccf..4304f80 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -12,6 +12,7 @@ pub mod overlay; pub mod params_prompt; pub mod results; pub mod right_panel; +pub mod saved_query_picker; pub mod schema; pub mod screen; pub mod status; diff --git a/src/state/overlay.rs b/src/state/overlay.rs index d63d55b..8a5ae8d 100644 --- a/src/state/overlay.rs +++ b/src/state/overlay.rs @@ -10,6 +10,7 @@ use crate::state::command::CommandBuffer; use crate::state::llm_settings::LlmSettingsState; use crate::state::params_prompt::ParamsPromptState; +use crate::state::saved_query_picker::SavedQueryPickerState; /// The layer that floats over the current [`crate::state::screen::Screen`]. // @@ -67,6 +68,14 @@ pub enum Overlay { name: String, args_json: String, }, + /// `:save ` collided with an existing entry. Enter overwrites + /// the file with `sql`; Esc cancels and leaves the previous entry + /// untouched. + ConfirmSaveOverwrite { name: String, sql: String }, + /// Modal list of saved query names — opens via bare `:load` / + /// `:run-saved`. `purpose` decides whether Enter inserts the body + /// at the cursor or dispatches it through the query pipeline. + SavedQueryPicker(SavedQueryPickerState), } /// Why the confirm-run overlay opened. Drives the headline at the top diff --git a/src/state/saved_query_picker.rs b/src/state/saved_query_picker.rs new file mode 100644 index 0000000..f185755 --- /dev/null +++ b/src/state/saved_query_picker.rs @@ -0,0 +1,45 @@ +//! Modal list of saved queries — opens via bare `:run-saved`. +//! +//! Same shape as `state::conn_list` but lives as an overlay (not a +//! screen) so the editor stays visible underneath; Esc returns the user +//! to whatever they were typing without disturbing it. Enter dispatches +//! the chosen query through the usual `dispatch_query` path so the +//! params prompt still fires for placeholder-bearing SQL. + +#[derive(Debug)] +pub struct SavedQueryPickerState { + pub entries: Vec, + pub selected: usize, +} + +impl SavedQueryPickerState { + pub fn new(entries: Vec) -> Self { + Self { + entries, + selected: 0, + } + } + + pub fn selected_name(&self) -> Option<&str> { + self.entries.get(self.selected).map(String::as_str) + } + + pub fn move_selection(&mut self, delta: i32) { + if self.entries.is_empty() { + return; + } + let max = self.entries.len() as i32 - 1; + let next = (self.selected as i32 + delta).clamp(0, max); + self.selected = next as usize; + } + + pub fn jump_top(&mut self) { + self.selected = 0; + } + + pub fn jump_bottom(&mut self) { + if !self.entries.is_empty() { + self.selected = self.entries.len() - 1; + } + } +} diff --git a/src/ui/bottom_bar.rs b/src/ui/bottom_bar.rs index 730927a..74db6ce 100644 --- a/src/ui/bottom_bar.rs +++ b/src/ui/bottom_bar.rs @@ -66,6 +66,14 @@ impl Widget for BottomBar<'_> { render_tool_confirm(name, args_json, area, buf, &self.app.theme); return; } + Some(Overlay::ConfirmSaveOverwrite { name, .. }) => { + render_save_overwrite(name, area, buf, &self.app.theme); + return; + } + Some(Overlay::SavedQueryPicker(_)) => { + // Picker owns its own footer line. + return; + } None => {} } match &self.app.screen { @@ -234,6 +242,24 @@ fn summarise_tool_args(name: &str, args_json: &str) -> String { } } +fn render_save_overwrite(name: &str, area: Rect, buf: &mut Buffer, theme: &Theme) { + let line = Line::from(vec![ + Span::styled("⚠ ", Style::default().fg(theme.status_error).bg(theme.bg)), + Span::styled( + format!("overwrite saved query {name:?}?"), + Style::default() + .fg(theme.fg) + .bg(theme.bg) + .add_modifier(Modifier::BOLD), + ), + Span::styled( + " y/Enter to overwrite · n/Esc to cancel", + Style::default().fg(theme.fg_dim).bg(theme.bg), + ), + ]); + line.render(area, buf); +} + fn paint_background(area: Rect, buf: &mut Buffer, theme: &Theme) { for x in area.x..area.x + area.width { if let Some(cell) = buf.cell_mut((x, area.y)) { diff --git a/src/ui/help_view.rs b/src/ui/help_view.rs index 8bbbde8..142a552 100644 --- a/src/ui/help_view.rs +++ b/src/ui/help_view.rs @@ -755,6 +755,18 @@ const HELP_SECTIONS: &[HelpSection] = &[ keys: ":update", desc: "Check GitHub for a new release (manual; bypasses 24h throttle)", }, + HelpEntry { + keys: ":save ", + desc: "Save selection (or statement under cursor) as a named query", + }, + HelpEntry { + keys: ":load ", + desc: "Insert a saved query at the cursor", + }, + HelpEntry { + keys: ":run-saved [name]", + desc: "Run a saved query (bare form opens a picker)", + }, ], }, ]; diff --git a/src/ui/mod.rs b/src/ui/mod.rs index 31ae9aa..82ffed1 100644 --- a/src/ui/mod.rs +++ b/src/ui/mod.rs @@ -10,6 +10,7 @@ pub mod help_view; pub mod llm_settings_view; pub mod params_prompt_view; pub mod results_view; +pub mod saved_query_picker_view; pub mod schema_view; pub mod theme; pub mod theme_picker_view; @@ -36,6 +37,7 @@ use help_view::HelpPopover; use llm_settings_view::LlmSettingsForm; use params_prompt_view::ParamsPrompt; use results_view::{ExpandedResult, InlineResult}; +use saved_query_picker_view::SavedQueryPicker; use schema_view::SchemaPane; const INLINE_RESULT_HEIGHT: u16 = 10; @@ -77,6 +79,14 @@ pub fn render(app: &mut App, frame: &mut Frame) { }; frame.render_widget(widget, area); } + if let Some(Overlay::SavedQueryPicker(state)) = &app.overlay { + let widget = SavedQueryPicker { + state, + connection: app.active_connection.as_deref(), + theme: &app.theme, + }; + frame.render_widget(widget, area); + } } fn render_theme_picker(app: &mut App, frame: &mut Frame, full: Rect, bottom_area: Rect) { diff --git a/src/ui/saved_query_picker_view.rs b/src/ui/saved_query_picker_view.rs new file mode 100644 index 0000000..c0845dd --- /dev/null +++ b/src/ui/saved_query_picker_view.rs @@ -0,0 +1,109 @@ +use ratatui::buffer::Buffer; +use ratatui::layout::{Constraint, Layout, Rect}; +use ratatui::style::{Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Borders, Paragraph, Widget, Wrap}; + +use crate::state::saved_query_picker::SavedQueryPickerState; +use crate::ui::theme::Theme; + +pub struct SavedQueryPicker<'a> { + pub state: &'a SavedQueryPickerState, + pub connection: Option<&'a str>, + pub theme: &'a Theme, +} + +impl Widget for SavedQueryPicker<'_> { + fn render(self, area: Rect, buf: &mut Buffer) { + let Some(box_area) = inner_box(area, self.state.entries.len()) else { + return; + }; + let title = match self.connection { + Some(c) => format!(" run saved query — {c} "), + None => " run saved query ".to_string(), + }; + let block = Block::default() + .borders(Borders::ALL) + .border_style( + Style::default() + .fg(self.theme.border_focus) + .bg(self.theme.bg), + ) + .title(title) + .title_style( + Style::default() + .fg(self.theme.fg) + .bg(self.theme.bg) + .add_modifier(Modifier::BOLD), + ) + .style(Style::default().bg(self.theme.bg)); + let inner = block.inner(box_area); + block.render(box_area, buf); + + let entries_h = (self.state.entries.len() as u16).max(1); + let chunks = Layout::vertical([ + Constraint::Length(entries_h), + Constraint::Length(1), + Constraint::Length(2), + ]) + .split(inner); + + if self.state.entries.is_empty() { + Paragraph::new("(no saved queries)") + .style(Style::default().fg(self.theme.fg_dim).bg(self.theme.bg)) + .render(chunks[0], buf); + } else { + let lines: Vec = self + .state + .entries + .iter() + .enumerate() + .map(|(i, name)| entry_line(name, i == self.state.selected, self.theme)) + .collect(); + Paragraph::new(lines) + .style(Style::default().fg(self.theme.fg).bg(self.theme.bg)) + .render(chunks[0], buf); + } + + let footer = "j/k move · Enter pick · Esc close"; + Paragraph::new(footer) + .style(Style::default().fg(self.theme.fg_dim).bg(self.theme.bg)) + .wrap(Wrap { trim: true }) + .render(chunks[2], buf); + } +} + +fn entry_line<'a>(name: &str, selected: bool, theme: &Theme) -> Line<'a> { + let bg = if selected { + theme.selection_bg + } else { + theme.bg + }; + let fg = if selected { + theme.selection_fg + } else { + theme.fg + }; + Line::from(vec![ + Span::styled(" ".to_string(), Style::default().fg(fg).bg(bg)), + Span::styled(name.to_string(), Style::default().fg(fg).bg(bg)), + ]) +} + +pub fn inner_box(area: Rect, entry_count: usize) -> Option { + let width = area.width.min(72); + let needed_inner = (entry_count.max(1) as u16).saturating_add(3); + let needed = needed_inner.saturating_add(2); + let height = needed.clamp(8, 24).min(area.height); + if width < 40 || height < 8 { + return None; + } + let x = area.x + (area.width.saturating_sub(width)) / 2; + let y = area.y + (area.height.saturating_sub(height)) / 2; + Some(Rect { + x, + y, + width, + height, + }) +}