diff --git a/.github/workflows/build-multiplatform.yml b/.github/workflows/build-multiplatform.yml index 79c4d0c8..fc21034d 100644 --- a/.github/workflows/build-multiplatform.yml +++ b/.github/workflows/build-multiplatform.yml @@ -30,7 +30,7 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 with: - bun-version: latest + bun-version: '1.3.10' - name: Cache Bun dependencies uses: actions/cache@v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 29a31554..b612f621 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 with: - bun-version: latest + bun-version: '1.3.10' - name: Cache Bun dependencies uses: actions/cache@v4 @@ -76,7 +76,7 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 with: - bun-version: latest + bun-version: '1.3.10' - name: Cache Bun dependencies uses: actions/cache@v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 167de31a..999b49df 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -120,7 +120,7 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 with: - bun-version: latest + bun-version: '1.3.10' - name: Cache Bun dependencies uses: actions/cache@v4 @@ -392,32 +392,3 @@ jobs: merged/latest.json env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - upgradelink-upload: - needs: release - permissions: - contents: read - runs-on: ubuntu-22.04 - steps: - - name: Resolve release tag - id: release_meta - shell: bash - env: - TAG_NAME: ${{ inputs.tag || github.ref_name }} - run: | - set -euo pipefail - echo "release_tag=${TAG_NAME}" >> "$GITHUB_OUTPUT" - - - name: Upload latest.json to UpgradeLink - uses: toolsetlink/upgradelink-action@3.0.2 - with: - access_key: ${{ secrets.UPGRADE_LINK_ACCESS_KEY }} - access_secret: ${{ secrets.UPGRADE_LINK_ACCESS_SECRET }} - config: | - { - "app_type": "tauri", - "request": { - "app_key": "${{ secrets.UPGRADE_LINK_TAURI_KEY }}", - "latest_json_url": "https://github.com/${{ github.repository }}/releases/download/${{ steps.release_meta.outputs.release_tag }}/latest.json" - } - } diff --git a/.github/workflows/upgradelink-upload.yml b/.github/workflows/upgradelink-upload.yml new file mode 100644 index 00000000..1e856da5 --- /dev/null +++ b/.github/workflows/upgradelink-upload.yml @@ -0,0 +1,29 @@ +name: UpgradeLink Upload + +on: + workflow_dispatch: + inputs: + tag: + description: "Release tag to upload (e.g. v1.2.3)" + required: true + type: string + +jobs: + upgradelink-upload: + permissions: + contents: read + runs-on: ubuntu-22.04 + steps: + - name: Upload latest.json to UpgradeLink + uses: toolsetlink/upgradelink-action@3.0.2 + with: + access_key: ${{ secrets.UPGRADE_LINK_ACCESS_KEY }} + access_secret: ${{ secrets.UPGRADE_LINK_ACCESS_SECRET }} + config: | + { + "app_type": "tauri", + "request": { + "app_key": "${{ secrets.UPGRADE_LINK_TAURI_KEY }}", + "latest_json_url": "https://github.com/${{ github.repository }}/releases/download/${{ inputs.tag }}/latest.json" + } + } diff --git a/package.json b/package.json index 6331011d..1f9a343b 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "url": "git+https://github.com/codeErrorSleep/dbpaw.git" }, "private": true, - "version": "0.3.2", + "version": "0.3.3", "type": "module", "scripts": { "dev": "vite", diff --git a/scripts/test-integration.sh b/scripts/test-integration.sh index b6079307..6a4ccdad 100755 --- a/scripts/test-integration.sh +++ b/scripts/test-integration.sh @@ -39,6 +39,10 @@ case "${it_db}" in run_integration_test "mysql_command_integration" run_integration_test "mysql_stateful_command_integration" ;; + starrocks) + run_integration_test "starrocks_integration" + run_integration_test "starrocks_command_integration" + ;; mariadb) run_integration_test "mariadb_integration" run_integration_test "mariadb_command_integration" @@ -89,7 +93,7 @@ case "${it_db}" in run_integration_test "oracle_command_integration" ;; *) - echo "[error] Invalid IT_DB='${it_db}'. Expected one of: mysql|mariadb|postgres|clickhouse|mssql|duckdb|sqlite|oracle|all" + echo "[error] Invalid IT_DB='${it_db}'. Expected one of: mysql|starrocks|mariadb|postgres|clickhouse|mssql|duckdb|sqlite|oracle|all" exit 1 ;; esac diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 76ff1e4c..0fe96189 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -18,7 +18,7 @@ crate-type = ["staticlib", "cdylib", "rlib"] tauri-build = { version = "2", features = [] } [dependencies] -tauri = { version = "2", features = [] } +tauri = { version = "2", features = ["devtools"] } tauri-plugin-opener = "2" tauri-plugin-dialog = "2" tauri-plugin-fs = "2" diff --git a/src-tauri/src/commands/connection.rs b/src-tauri/src/commands/connection.rs index 452f34e0..e65aa03d 100644 --- a/src-tauri/src/commands/connection.rs +++ b/src-tauri/src/commands/connection.rs @@ -267,7 +267,7 @@ pub async fn create_database_by_id( } let exec_res = match driver.as_str() { - "mysql" | "mariadb" | "tidb" => { + driver if crate::db::drivers::is_mysql_family_driver(driver) => { let sql = build_mysql_create_database_sql(&payload, &db_name)?; super::execute_with_retry(&state, id, None, |driver| { let sql_clone = sql.clone(); @@ -348,7 +348,7 @@ pub async fn create_database_by_id_direct( } let exec_res = match driver.as_str() { - "mysql" | "mariadb" | "tidb" => { + driver if crate::db::drivers::is_mysql_family_driver(driver) => { let sql = build_mysql_create_database_sql(&payload, &db_name)?; super::execute_with_retry_from_app_state(state, id, None, |driver| { let sql_clone = sql.clone(); @@ -419,6 +419,122 @@ pub async fn test_connection_ephemeral( }) } +#[tauri::command] +pub async fn get_mysql_charsets_by_id( + state: State<'_, AppState>, + id: i64, +) -> Result, String> { + super::execute_with_retry(&state, id, None, |driver| async move { + let result = driver + .execute_query("SHOW CHARACTER SET".to_string()) + .await?; + let mut charsets: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Charset") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + charsets.sort(); + Ok(charsets) + }) + .await +} + +#[tauri::command] +pub async fn get_mysql_collations_by_id( + state: State<'_, AppState>, + id: i64, + charset: Option, +) -> Result, String> { + let sql = match &charset { + Some(cs) if is_safe_option_token(cs) => { + format!("SHOW COLLATION WHERE Charset = '{}'", cs) + } + Some(cs) => { + return Err(format!("[VALIDATION_ERROR] Invalid charset: {}", cs)); + } + None => "SHOW COLLATION".to_string(), + }; + super::execute_with_retry(&state, id, None, |driver| { + let sql = sql.clone(); + async move { + let result = driver.execute_query(sql).await?; + let mut collations: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Collation") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + collations.sort(); + Ok(collations) + } + }) + .await +} + +pub async fn get_mysql_charsets_by_id_direct( + state: &AppState, + id: i64, +) -> Result, String> { + super::execute_with_retry_from_app_state(state, id, None, |driver| async move { + let result = driver + .execute_query("SHOW CHARACTER SET".to_string()) + .await?; + let mut charsets: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Charset") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + charsets.sort(); + Ok(charsets) + }) + .await +} + +pub async fn get_mysql_collations_by_id_direct( + state: &AppState, + id: i64, + charset: Option, +) -> Result, String> { + let sql = match &charset { + Some(cs) if is_safe_option_token(cs) => { + format!("SHOW COLLATION WHERE Charset = '{}'", cs) + } + Some(cs) => { + return Err(format!("[VALIDATION_ERROR] Invalid charset: {}", cs)); + } + None => "SHOW COLLATION".to_string(), + }; + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let sql = sql.clone(); + async move { + let result = driver.execute_query(sql).await?; + let mut collations: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Collation") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + collations.sort(); + Ok(collations) + } + }) + .await +} + #[tauri::command] pub async fn get_connections(state: State<'_, AppState>) -> Result, String> { let local_db = { @@ -553,8 +669,8 @@ mod tests { validate_database_name, CreateDatabasePayload, }; use super::{ - normalize_create_database_error, normalize_option_token, quote_clickhouse_ident, - quote_mssql_ident, quote_mysql_ident, quote_pg_ident, + is_safe_option_token, normalize_create_database_error, normalize_option_token, + quote_clickhouse_ident, quote_mssql_ident, quote_mysql_ident, quote_pg_ident, }; use crate::connection_input::normalize_connection_form; use crate::models::ConnectionForm; @@ -734,4 +850,52 @@ mod tests { .unwrap_err(); assert!(err.contains("does not support charset option")); } + + #[test] + fn get_mysql_collations_charset_validation_rejects_unsafe_tokens() { + // Verify the validation logic used by get_mysql_collations_by_id/_direct. + // A charset with spaces or semicolons must be rejected. + assert!(!is_safe_option_token("utf8 mb4")); + assert!(!is_safe_option_token("utf8;drop")); + assert!(!is_safe_option_token("")); + } + + #[test] + fn get_mysql_collations_charset_validation_accepts_valid_charsets() { + // All standard MySQL charset names must pass the token check. + let valid = [ + "utf8mb4", + "utf8", + "latin1", + "gbk", + "gb18030", + "ascii", + "binary", + "utf8mb4_0900_ai_ci", + ]; + for cs in valid { + assert!(is_safe_option_token(cs), "expected '{}' to be accepted", cs); + } + } + + #[test] + fn mysql_create_database_sql_is_reusable_for_starrocks_connections() { + assert!(crate::db::drivers::is_mysql_family_driver("starrocks")); + + let sql = build_mysql_create_database_sql( + &CreateDatabasePayload { + name: "analytics".to_string(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }, + "analytics", + ) + .unwrap(); + + assert_eq!(sql, "CREATE DATABASE IF NOT EXISTS `analytics`"); + } } diff --git a/src-tauri/src/commands/query.rs b/src-tauri/src/commands/query.rs index 196a61f6..35aa835a 100644 --- a/src-tauri/src/commands/query.rs +++ b/src-tauri/src/commands/query.rs @@ -76,6 +76,39 @@ fn skip_backtick_quote(bytes: &[u8], mut i: usize) -> usize { i } +fn parse_dollar_quote_tag(bytes: &[u8], start: usize) -> Option { + if bytes.get(start) != Some(&b'$') { + return None; + } + let mut i = start + 1; + while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') { + i += 1; + } + if bytes.get(i) == Some(&b'$') { + Some(i) + } else { + None + } +} + +fn skip_dollar_quote(bytes: &[u8], start: usize) -> usize { + let Some(tag_end) = parse_dollar_quote_tag(bytes, start) else { + return start + 1; + }; + let tag = &bytes[start..=tag_end]; + let tag_len = tag.len(); + let mut i = tag_end + 1; + + while i + tag_len <= bytes.len() { + if &bytes[i..i + tag_len] == tag { + return i + tag_len; + } + i += 1; + } + + bytes.len() +} + fn skip_line_comment(bytes: &[u8], mut i: usize) -> usize { i += 2; while i < bytes.len() && bytes[i] != b'\n' { @@ -148,6 +181,13 @@ fn is_single_statement(sql: &str) -> bool { i = skip_backtick_quote(bytes, i); continue; } + if b == b'$' { + let next = skip_dollar_quote(bytes, i); + if next != i + 1 { + i = next; + continue; + } + } if b == b'(' { depth += 1; i += 1; @@ -216,6 +256,13 @@ fn collect_top_level_keywords(sql: &str) -> Vec { i = skip_backtick_quote(bytes, i); continue; } + if b == b'$' { + let next = skip_dollar_quote(bytes, i); + if next != i + 1 { + i = next; + continue; + } + } if b == b'(' { depth += 1; i += 1; @@ -338,6 +385,13 @@ fn has_top_level_limit(sql: &str) -> bool { i = skip_backtick_quote(bytes, i); continue; } + if b == b'$' { + let next = skip_dollar_quote(bytes, i); + if next != i + 1 { + i = next; + continue; + } + } if b == b'(' { depth += 1; i += 1; @@ -1215,6 +1269,12 @@ mod tests { assert!(is_single_statement("SELECT 'a; b'")); assert!(is_single_statement("SELECT \"a; b\"")); assert!(is_single_statement("SELECT `a; b`")); + assert!(is_single_statement( + "CREATE FUNCTION f() RETURNS void AS $$ BEGIN PERFORM 1; END; $$ LANGUAGE plpgsql;" + )); + assert!(is_single_statement( + "CREATE FUNCTION f() RETURNS text AS $tag$ BEGIN RETURN ';'; END; $tag$ LANGUAGE plpgsql;" + )); assert!(!is_single_statement("SELECT 1; SELECT 2")); } @@ -1234,6 +1294,18 @@ mod tests { assert!(tokens.contains(&"from".to_string())); } + #[test] + fn collect_top_level_keywords_skips_dollar_quoted_bodies() { + let tokens = collect_top_level_keywords( + "CREATE FUNCTION f() RETURNS void AS $$ BEGIN SELECT 1; END; $$ LANGUAGE plpgsql", + ); + assert_eq!(tokens.first().map(String::as_str), Some("create")); + assert!(tokens.contains(&"function".to_string())); + assert!(!tokens + .iter() + .any(|token| token == "begin" || token == "end")); + } + #[test] fn statement_kind_for_limit_guard_classifies_with_queries() { assert_eq!( diff --git a/src-tauri/src/commands/transfer.rs b/src-tauri/src/commands/transfer.rs index fbbe0f5f..8428a5ac 100644 --- a/src-tauri/src/commands/transfer.rs +++ b/src-tauri/src/commands/transfer.rs @@ -1,9 +1,11 @@ +use crate::db::drivers::DatabaseDriver; use crate::state::AppState; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::fs::{self, File}; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; +use std::sync::Arc; use tauri::State; const DEFAULT_CHUNK_SIZE: i64 = 2000; @@ -15,7 +17,9 @@ const MAX_IMPORT_STATEMENTS: usize = 50_000; pub enum ExportFormat { Csv, Json, - Sql, + SqlDml, + SqlDdl, + SqlFull, } #[derive(Debug, Clone, Deserialize)] @@ -60,6 +64,102 @@ struct PreparedImportPlan { script_managed_transaction: bool, } +async fn do_table_export( + db_driver: Arc, + output_path: PathBuf, + schema: String, + table: String, + driver: String, + format: ExportFormat, + scope: ExportScope, + filter: Option, + order_by: Option, + sort_column: Option, + sort_direction: Option, + page: Option, + limit: Option, + chunk: i64, +) -> Result { + let mut writer = ExportWriter::new(output_path.clone(), format.clone())?; + let mut exported = 0i64; + + if matches!(format, ExportFormat::SqlDdl | ExportFormat::SqlFull) { + let ddl = db_driver + .get_table_ddl(schema.clone(), table.clone()) + .await?; + writer.write_ddl(&ddl)?; + } + + if !matches!(format, ExportFormat::SqlDdl) { + let columns: Vec = db_driver + .get_table_metadata(schema.clone(), table.clone()) + .await? + .columns + .into_iter() + .map(|c| c.name) + .collect(); + + writer.write_csv_header(&columns)?; + + match scope { + ExportScope::CurrentPage => { + let resp = db_driver + .get_table_data_chunk( + schema.clone(), + table.clone(), + page.unwrap_or(1).max(1), + limit.unwrap_or(50).max(1), + sort_column, + sort_direction, + filter, + order_by, + ) + .await?; + exported += + writer.write_rows(&resp.data, &columns, Some(&schema), &table, &driver)?; + } + ExportScope::Filtered | ExportScope::FullTable => { + let (eff_filter, eff_order, eff_sort_col, eff_sort_dir) = + if matches!(scope, ExportScope::Filtered) { + (filter, order_by, sort_column, sort_direction) + } else { + (None, None, None, None) + }; + let mut current_page = 1; + loop { + let resp = db_driver + .get_table_data_chunk( + schema.clone(), + table.clone(), + current_page, + chunk, + eff_sort_col.clone(), + eff_sort_dir.clone(), + eff_filter.clone(), + eff_order.clone(), + ) + .await?; + if resp.data.is_empty() { + break; + } + exported += + writer.write_rows(&resp.data, &columns, Some(&schema), &table, &driver)?; + if exported >= resp.total { + break; + } + current_page += 1; + } + } + } + } + + writer.finish()?; + Ok(ExportResult { + file_path: output_path.to_string_lossy().to_string(), + row_count: exported, + }) +} + #[tauri::command] pub async fn export_table_data( state: State<'_, AppState>, @@ -80,9 +180,7 @@ pub async fn export_table_data( chunk_size: Option, ) -> Result { let output_path = resolve_output_path(file_path, &table, extension_for_format(&format))?; - let chunk = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE).max(1); - super::execute_with_retry(&state, id, database, |db_driver| { let output_path = output_path.clone(); let schema = schema.clone(); @@ -95,97 +193,23 @@ pub async fn export_table_data( let scope = scope.clone(); let format = format.clone(); async move { - let columns = db_driver - .get_table_metadata(schema.clone(), table.clone()) - .await? - .columns - .into_iter() - .map(|c| c.name) - .collect::>(); - - let mut writer = - ExportWriter::new(output_path.clone(), format.clone(), columns.clone())?; - let mut exported = 0i64; - - match scope { - ExportScope::CurrentPage => { - let use_page = page.unwrap_or(1).max(1); - let use_limit = limit.unwrap_or(50).max(1); - let resp = db_driver - .get_table_data_chunk( - schema.clone(), - table.clone(), - use_page, - use_limit, - sort_column.clone(), - sort_direction.clone(), - filter.clone(), - order_by.clone(), - ) - .await?; - exported += - writer.write_rows(&resp.data, &columns, Some(&schema), &table, &driver)?; - } - ExportScope::Filtered | ExportScope::FullTable => { - let filter_for_scope = if matches!(scope, ExportScope::Filtered) { - filter.clone() - } else { - None - }; - let order_for_scope = if matches!(scope, ExportScope::Filtered) { - order_by.clone() - } else { - None - }; - let sort_col_for_scope = if matches!(scope, ExportScope::Filtered) { - sort_column.clone() - } else { - None - }; - let sort_dir_for_scope = if matches!(scope, ExportScope::Filtered) { - sort_direction.clone() - } else { - None - }; - - let mut current_page = 1; - loop { - let resp = db_driver - .get_table_data_chunk( - schema.clone(), - table.clone(), - current_page, - chunk, - sort_col_for_scope.clone(), - sort_dir_for_scope.clone(), - filter_for_scope.clone(), - order_for_scope.clone(), - ) - .await?; - if resp.data.is_empty() { - break; - } - - exported += writer.write_rows( - &resp.data, - &columns, - Some(&schema), - &table, - &driver, - )?; - if exported >= resp.total { - break; - } - current_page += 1; - } - } - } - - writer.finish()?; - Ok(ExportResult { - file_path: output_path.to_string_lossy().to_string(), - row_count: exported, - }) + do_table_export( + db_driver, + output_path, + schema, + table, + driver, + format, + scope, + filter, + order_by, + sort_column, + sort_direction, + page, + limit, + chunk, + ) + .await } }) .await @@ -211,7 +235,6 @@ pub async fn export_table_data_direct( ) -> Result { let output_path = resolve_output_path(file_path, &table, extension_for_format(&format))?; let chunk = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE).max(1); - super::execute_with_retry_from_app_state(state, id, database, |db_driver| { let output_path = output_path.clone(); let schema = schema.clone(); @@ -224,97 +247,23 @@ pub async fn export_table_data_direct( let scope = scope.clone(); let format = format.clone(); async move { - let columns = db_driver - .get_table_metadata(schema.clone(), table.clone()) - .await? - .columns - .into_iter() - .map(|c| c.name) - .collect::>(); - - let mut writer = - ExportWriter::new(output_path.clone(), format.clone(), columns.clone())?; - let mut exported = 0i64; - - match scope { - ExportScope::CurrentPage => { - let use_page = page.unwrap_or(1).max(1); - let use_limit = limit.unwrap_or(50).max(1); - let resp = db_driver - .get_table_data_chunk( - schema.clone(), - table.clone(), - use_page, - use_limit, - sort_column.clone(), - sort_direction.clone(), - filter.clone(), - order_by.clone(), - ) - .await?; - exported += - writer.write_rows(&resp.data, &columns, Some(&schema), &table, &driver)?; - } - ExportScope::Filtered | ExportScope::FullTable => { - let filter_for_scope = if matches!(scope, ExportScope::Filtered) { - filter.clone() - } else { - None - }; - let order_for_scope = if matches!(scope, ExportScope::Filtered) { - order_by.clone() - } else { - None - }; - let sort_col_for_scope = if matches!(scope, ExportScope::Filtered) { - sort_column.clone() - } else { - None - }; - let sort_dir_for_scope = if matches!(scope, ExportScope::Filtered) { - sort_direction.clone() - } else { - None - }; - - let mut current_page = 1; - loop { - let resp = db_driver - .get_table_data_chunk( - schema.clone(), - table.clone(), - current_page, - chunk, - sort_col_for_scope.clone(), - sort_dir_for_scope.clone(), - filter_for_scope.clone(), - order_for_scope.clone(), - ) - .await?; - if resp.data.is_empty() { - break; - } - - exported += writer.write_rows( - &resp.data, - &columns, - Some(&schema), - &table, - &driver, - )?; - if exported >= resp.total { - break; - } - current_page += 1; - } - } - } - - writer.finish()?; - Ok(ExportResult { - file_path: output_path.to_string_lossy().to_string(), - row_count: exported, - }) + do_table_export( + db_driver, + output_path, + schema, + table, + driver, + format, + scope, + filter, + order_by, + sort_column, + sort_direction, + page, + limit, + chunk, + ) + .await } }) .await @@ -345,7 +294,8 @@ pub async fn export_query_result( .into_iter() .map(|c| c.name) .collect::>(); - let mut writer = ExportWriter::new(output_path.clone(), format, columns.clone())?; + let mut writer = ExportWriter::new(output_path.clone(), format)?; + writer.write_csv_header(&columns)?; let exported = writer.write_rows(&result.data, &columns, None, "query_result", &driver)?; writer.finish()?; @@ -382,7 +332,8 @@ pub async fn export_query_result_direct( .into_iter() .map(|c| c.name) .collect::>(); - let mut writer = ExportWriter::new(output_path.clone(), format, columns.clone())?; + let mut writer = ExportWriter::new(output_path.clone(), format)?; + writer.write_csv_header(&columns)?; let exported = writer.write_rows(&result.data, &columns, None, "query_result", &driver)?; writer.finish()?; @@ -614,6 +565,10 @@ fn import_transaction_sql<'a>( ) -> Result<(&'a str, &'a str, &'a str), String> { match normalized_driver { "mysql" | "mariadb" | "tidb" => Ok(("START TRANSACTION", "COMMIT", "ROLLBACK")), + "starrocks" => Err( + "[UNSUPPORTED] Driver starrocks does not support transactional SQL import in this flow" + .to_string(), + ), "postgres" | "sqlite" | "duckdb" => Ok(("BEGIN", "COMMIT", "ROLLBACK")), "mssql" => Ok(( "BEGIN TRANSACTION", @@ -906,7 +861,7 @@ fn extension_for_format(format: &ExportFormat) -> &'static str { match format { ExportFormat::Csv => "csv", ExportFormat::Json => "json", - ExportFormat::Sql => "sql", + ExportFormat::SqlDml | ExportFormat::SqlDdl | ExportFormat::SqlFull => "sql", } } @@ -976,19 +931,474 @@ enum SqlScanState { BlockComment, } +fn starts_with_chars(chars: &[char], idx: usize, needle: &[char]) -> bool { + if idx + needle.len() > chars.len() { + return false; + } + for (offset, ch) in needle.iter().enumerate() { + if chars[idx + offset] != *ch { + return false; + } + } + true +} + +fn line_start_index(chars: &[char], idx: usize) -> usize { + let mut start = idx; + while start > 0 && chars[start - 1] != '\n' { + start -= 1; + } + start +} + +fn parse_mysql_delimiter_command(chars: &[char], idx: usize) -> Option<(String, usize)> { + let line_start = line_start_index(chars, idx); + let mut cursor = line_start; + while cursor < chars.len() && matches!(chars[cursor], ' ' | '\t' | '\r') { + cursor += 1; + } + if cursor != idx { + return None; + } + + let keyword: Vec = "DELIMITER".chars().collect(); + if !starts_with_chars(chars, cursor, &keyword) { + return None; + } + + let mut after_keyword = cursor + keyword.len(); + if after_keyword < chars.len() && chars[after_keyword] != ' ' && chars[after_keyword] != '\t' { + return None; + } + while after_keyword < chars.len() && matches!(chars[after_keyword], ' ' | '\t') { + after_keyword += 1; + } + if after_keyword >= chars.len() || matches!(chars[after_keyword], '\n' | '\r') { + return None; + } + + let mut line_end = after_keyword; + while line_end < chars.len() && !matches!(chars[line_end], '\n' | '\r') { + line_end += 1; + } + + let delimiter: String = chars[after_keyword..line_end] + .iter() + .collect::() + .trim() + .to_string(); + if delimiter.is_empty() { + return None; + } + + let mut next_idx = line_end; + if next_idx < chars.len() && chars[next_idx] == '\r' { + next_idx += 1; + } + if next_idx < chars.len() && chars[next_idx] == '\n' { + next_idx += 1; + } + + Some((delimiter, next_idx)) +} + +fn sqlite_trigger_state(sql: &str) -> (bool, bool) { + let chars: Vec = sql.chars().collect(); + let mut state = SqlScanState::Normal; + let mut i = 0usize; + let mut tokens = Vec::new(); + let mut trigger_begin_seen = false; + let mut trigger_block_depth = 0i32; + let mut case_depth = 0i32; + let mut last_word: Option = None; + + while i < chars.len() { + match &state { + SqlScanState::Normal => { + let ch = chars[i]; + let next = chars.get(i + 1).copied(); + if ch == '-' && next == Some('-') { + state = SqlScanState::LineComment; + i += 2; + continue; + } + if ch == '/' && next == Some('*') { + state = SqlScanState::BlockComment; + i += 2; + continue; + } + if ch == '\'' { + state = SqlScanState::SingleQuoted; + i += 1; + continue; + } + if ch == '"' { + state = SqlScanState::DoubleQuoted; + i += 1; + continue; + } + if ch == '`' { + state = SqlScanState::BacktickQuoted; + i += 1; + continue; + } + if ch.is_ascii_alphabetic() || ch == '_' { + let start = i; + i += 1; + while i < chars.len() && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') { + i += 1; + } + let token = chars[start..i] + .iter() + .collect::() + .to_ascii_lowercase(); + tokens.push(token.clone()); + + if trigger_begin_seen { + match token.as_str() { + "case" => case_depth += 1, + "begin" => trigger_block_depth += 1, + "end" => { + if case_depth > 0 { + case_depth -= 1; + } else if trigger_block_depth > 0 { + trigger_block_depth -= 1; + } + } + _ => {} + } + } else if token == "begin" { + let is_create_trigger = matches!( + tokens.as_slice(), + [first, second, ..] if first == "create" + && (second == "trigger" + || ((second == "temp" || second == "temporary") + && tokens.get(2).map(String::as_str) == Some("trigger"))) + ); + if is_create_trigger { + trigger_begin_seen = true; + trigger_block_depth = 1; + } + } + + last_word = Some(token); + continue; + } + i += 1; + } + SqlScanState::SingleQuoted => { + if chars[i] == '\\' && chars.get(i + 1).is_some() { + i += 2; + continue; + } + if chars[i] == '\'' { + if chars.get(i + 1) == Some(&'\'') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::DoubleQuoted => { + if chars[i] == '"' { + if chars.get(i + 1) == Some(&'"') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BacktickQuoted => { + if chars[i] == '`' { + if chars.get(i + 1) == Some(&'`') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::LineComment => { + if chars[i] == '\n' { + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BlockComment => { + if chars[i] == '*' && chars.get(i + 1) == Some(&'/') { + state = SqlScanState::Normal; + i += 2; + } else { + i += 1; + } + } + SqlScanState::DollarQuoted(_) => { + state = SqlScanState::Normal; + } + } + } + + let is_trigger = trigger_begin_seen; + let ready_to_terminate = is_trigger + && trigger_block_depth == 0 + && case_depth == 0 + && last_word.as_deref() == Some("end"); + (is_trigger, ready_to_terminate) +} + +fn oracle_plsql_state(sql: &str) -> (bool, bool) { + let chars: Vec = sql.chars().collect(); + let mut state = SqlScanState::Normal; + let mut i = 0usize; + let mut tokens = Vec::new(); + let mut block_depth = 0i32; + let mut case_depth = 0i32; + let mut last_word: Option = None; + let mut is_oracle_block = false; + + while i < chars.len() { + match &state { + SqlScanState::Normal => { + let ch = chars[i]; + let next = chars.get(i + 1).copied(); + if ch == '-' && next == Some('-') { + state = SqlScanState::LineComment; + i += 2; + continue; + } + if ch == '/' && next == Some('*') { + state = SqlScanState::BlockComment; + i += 2; + continue; + } + if ch == '\'' { + state = SqlScanState::SingleQuoted; + i += 1; + continue; + } + if ch == '"' { + state = SqlScanState::DoubleQuoted; + i += 1; + continue; + } + if ch == '`' { + state = SqlScanState::BacktickQuoted; + i += 1; + continue; + } + if ch.is_ascii_alphabetic() || ch == '_' { + let start = i; + i += 1; + while i < chars.len() && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') { + i += 1; + } + let token = chars[start..i] + .iter() + .collect::() + .to_ascii_lowercase(); + tokens.push(token.clone()); + + if !is_oracle_block { + let second = tokens.get(1).map(String::as_str); + let third = tokens.get(2).map(String::as_str); + let fourth = tokens.get(3).map(String::as_str); + is_oracle_block = matches!( + tokens.first().map(String::as_str), + Some("declare") | Some("begin") + ) || (tokens.first().map(String::as_str) + == Some("create") + && second == Some("or") + && third == Some("replace") + && matches!( + fourth, + Some("function") + | Some("procedure") + | Some("trigger") + | Some("package") + | Some("type") + )); + } + + if is_oracle_block { + match token.as_str() { + "case" => case_depth += 1, + "begin" => block_depth += 1, + "end" => { + if case_depth > 0 { + case_depth -= 1; + } else if block_depth > 0 { + block_depth -= 1; + } + } + _ => {} + } + } + + last_word = Some(token); + continue; + } + i += 1; + } + SqlScanState::SingleQuoted => { + if chars[i] == '\\' && chars.get(i + 1).is_some() { + i += 2; + continue; + } + if chars[i] == '\'' { + if chars.get(i + 1) == Some(&'\'') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::DoubleQuoted => { + if chars[i] == '"' { + if chars.get(i + 1) == Some(&'"') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BacktickQuoted => { + if chars[i] == '`' { + if chars.get(i + 1) == Some(&'`') { + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::LineComment => { + if chars[i] == '\n' { + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BlockComment => { + if chars[i] == '*' && chars.get(i + 1) == Some(&'/') { + state = SqlScanState::Normal; + i += 2; + } else { + i += 1; + } + } + SqlScanState::DollarQuoted(_) => { + state = SqlScanState::Normal; + } + } + } + + let ready_to_terminate = is_oracle_block + && block_depth == 0 + && case_depth == 0 + && last_word.as_deref() == Some("end"); + (is_oracle_block, ready_to_terminate) +} + +fn parse_oracle_slash_terminator(chars: &[char], idx: usize) -> Option { + let line_start = line_start_index(chars, idx); + let mut cursor = line_start; + while cursor < chars.len() && matches!(chars[cursor], ' ' | '\t' | '\r') { + cursor += 1; + } + if cursor != idx || chars.get(idx) != Some(&'/') { + return None; + } + + let mut line_end = idx + 1; + while line_end < chars.len() && !matches!(chars[line_end], '\n' | '\r') { + if !matches!(chars[line_end], ' ' | '\t') { + return None; + } + line_end += 1; + } + + let mut next_idx = line_end; + if next_idx < chars.len() && chars[next_idx] == '\r' { + next_idx += 1; + } + if next_idx < chars.len() && chars[next_idx] == '\n' { + next_idx += 1; + } + + Some(next_idx) +} + fn parse_sql_statements(sql: &str, driver: &str) -> Result, String> { let mysql_style_hash_comment = matches!(driver, "mysql" | "mariadb" | "tidb"); + let mysql_style_delimiter = mysql_style_hash_comment; + let sqlite_style_trigger = driver == "sqlite"; + let oracle_style_block = driver == "oracle"; let chars: Vec = sql.chars().collect(); let mut out = Vec::new(); let mut current = String::new(); let mut state = SqlScanState::Normal; + let mut delimiter = ";".to_string(); let mut i = 0usize; while i < chars.len() { match &state { SqlScanState::Normal => { + if mysql_style_delimiter { + if let Some((next_delimiter, next_idx)) = + parse_mysql_delimiter_command(&chars, i) + { + delimiter = next_delimiter; + i = next_idx; + continue; + } + } + if oracle_style_block { + if let Some(next_idx) = parse_oracle_slash_terminator(&chars, i) { + let (is_block, ready_to_terminate) = oracle_plsql_state(current.trim()); + if is_block && ready_to_terminate { + let statement = current.trim(); + if !statement.is_empty() { + out.push(statement.to_string()); + } + current.clear(); + i = next_idx; + continue; + } + } + } + let ch = chars[i]; let next = chars.get(i + 1).copied(); + let delimiter_chars: Vec = delimiter.chars().collect(); + + if starts_with_chars(&chars, i, &delimiter_chars) { + if sqlite_style_trigger && delimiter == ";" { + let (is_trigger, ready_to_terminate) = sqlite_trigger_state(current.trim()); + if is_trigger && !ready_to_terminate { + current.push(ch); + i += delimiter_chars.len(); + continue; + } + } + if oracle_style_block && delimiter == ";" { + let (is_block, _) = oracle_plsql_state(current.trim()); + if is_block { + current.push(ch); + i += delimiter_chars.len(); + continue; + } + } + let statement = current.trim(); + if !statement.is_empty() { + out.push(statement.to_string()); + } + current.clear(); + i += delimiter_chars.len(); + continue; + } if ch == '-' && next == Some('-') { state = SqlScanState::LineComment; @@ -1031,15 +1441,6 @@ fn parse_sql_statements(sql: &str, driver: &str) -> Result, String> continue; } } - if ch == ';' { - let statement = current.trim(); - if !statement.is_empty() { - out.push(statement.to_string()); - } - current.clear(); - i += 1; - continue; - } current.push(ch); i += 1; } @@ -1233,28 +1634,15 @@ struct ExportWriter { } impl ExportWriter { - fn new(path: PathBuf, format: ExportFormat, columns: Vec) -> Result { + fn new(path: PathBuf, format: ExportFormat) -> Result { let file = File::create(path).map_err(|e| format!("[EXPORT_ERROR] create file failed: {e}"))?; let mut writer = BufWriter::new(file); - match format { - ExportFormat::Csv => { - let header = columns - .iter() - .map(|c| csv_escape(c)) - .collect::>() - .join(","); - writer - .write_all(format!("{header}\n").as_bytes()) - .map_err(|e| format!("[EXPORT_ERROR] write csv header failed: {e}"))?; - } - ExportFormat::Json => { - writer - .write_all(b"[\n") - .map_err(|e| format!("[EXPORT_ERROR] write json header failed: {e}"))?; - } - ExportFormat::Sql => {} + if matches!(format, ExportFormat::Json) { + writer + .write_all(b"[\n") + .map_err(|e| format!("[EXPORT_ERROR] write json header failed: {e}"))?; } Ok(Self { @@ -1264,6 +1652,20 @@ impl ExportWriter { }) } + fn write_csv_header(&mut self, columns: &[String]) -> Result<(), String> { + if !matches!(self.format, ExportFormat::Csv) { + return Ok(()); + } + let header = columns + .iter() + .map(|c| csv_escape(c)) + .collect::>() + .join(","); + self.writer + .write_all(format!("{header}\n").as_bytes()) + .map_err(|e| format!("[EXPORT_ERROR] write csv header failed: {e}")) + } + fn write_rows( &mut self, rows: &[Value], @@ -1315,7 +1717,7 @@ impl ExportWriter { .write_all(text.as_bytes()) .map_err(|e| format!("[EXPORT_ERROR] write json row failed: {e}"))?; } - ExportFormat::Sql => { + ExportFormat::SqlDml | ExportFormat::SqlFull => { let quoted_cols = columns .iter() .map(|c| quote_ident(c, driver)) @@ -1340,10 +1742,19 @@ impl ExportWriter { .write_all(statement.as_bytes()) .map_err(|e| format!("[EXPORT_ERROR] write sql row failed: {e}"))?; } + ExportFormat::SqlDdl => unreachable!("SqlDdl rows are never written"), } Ok(()) } + fn write_ddl(&mut self, ddl: &str) -> Result<(), String> { + let content = format!("{}\n\n", ddl.trim_end()); + self.writer + .write_all(content.as_bytes()) + .map_err(|e| format!("[EXPORT_ERROR] write ddl failed: {e}"))?; + Ok(()) + } + fn finish(&mut self) -> Result<(), String> { if matches!(self.format, ExportFormat::Json) { self.writer @@ -1548,8 +1959,7 @@ mod tests { .unwrap() .as_nanos(); let path = std::env::temp_dir().join(format!("dbpaw-transfer-test-{unique}.json")); - let mut writer = - ExportWriter::new(path.clone(), ExportFormat::Json, vec!["a".to_string()]).unwrap(); + let mut writer = ExportWriter::new(path.clone(), ExportFormat::Json).unwrap(); let err = writer .write_rows( &[Value::String("not-object".to_string())], @@ -1596,6 +2006,120 @@ mod tests { assert_eq!(statements[1], "SELECT '#not_comment'"); } + #[test] + fn parse_sql_statements_supports_mysql_delimiter_blocks() { + let sql = r#" + DELIMITER $$ + CREATE PROCEDURE p_demo() + BEGIN + SELECT 1; + SELECT 'semi;inside'; + END$$ + DELIMITER ; + SELECT 2; + "#; + + let statements = parse_sql_statements(sql, "mysql").unwrap(); + assert_eq!(statements.len(), 2); + assert!(statements[0].starts_with("CREATE PROCEDURE p_demo()")); + assert!(statements[0].contains("SELECT 1;")); + assert!(statements[0].contains("SELECT 'semi;inside';")); + assert_eq!(statements[1], "SELECT 2"); + } + + #[test] + fn parse_sql_statements_ignores_mysql_delimiter_inside_strings() { + let sql = r#" + DELIMITER // + CREATE TRIGGER trg_demo BEFORE INSERT ON demo + FOR EACH ROW + BEGIN + SET @note = 'DELIMITER // should stay'; + END// + DELIMITER ; + "#; + + let statements = parse_sql_statements(sql, "mysql").unwrap(); + assert_eq!(statements.len(), 1); + assert!(statements[0].contains("DELIMITER // should stay")); + assert!(statements[0].contains("END")); + } + + #[test] + fn parse_sql_statements_supports_sqlite_trigger_blocks() { + let sql = r#" + CREATE TRIGGER trg_demo + AFTER INSERT ON demo + BEGIN + INSERT INTO audit_log(message) VALUES ('first;value'); + UPDATE demo SET touched_at = CURRENT_TIMESTAMP WHERE rowid = NEW.rowid; + END; + SELECT 1; + "#; + + let statements = parse_sql_statements(sql, "sqlite").unwrap(); + assert_eq!(statements.len(), 2); + assert!(statements[0].starts_with("CREATE TRIGGER trg_demo")); + assert!(statements[0].contains("VALUES ('first;value');")); + assert!(statements[0].contains("UPDATE demo SET touched_at = CURRENT_TIMESTAMP")); + assert_eq!(statements[1], "SELECT 1"); + } + + #[test] + fn parse_sql_statements_keeps_sqlite_case_end_inside_trigger_body() { + let sql = r#" + CREATE TRIGGER trg_case + AFTER UPDATE ON demo + BEGIN + UPDATE demo + SET status = CASE WHEN NEW.id > 10 THEN 'big' ELSE 'small' END; + END; + "#; + + let statements = parse_sql_statements(sql, "sqlite").unwrap(); + assert_eq!(statements.len(), 1); + assert!(statements[0].contains("CASE WHEN NEW.id > 10 THEN 'big' ELSE 'small' END;")); + assert!(statements[0].ends_with("END")); + } + + #[test] + fn parse_sql_statements_supports_oracle_create_or_replace_blocks() { + let sql = r#" + CREATE OR REPLACE PROCEDURE p_demo IS + BEGIN + INSERT INTO audit_log(message) VALUES ('first;value'); + UPDATE audit_log SET message = 'done' WHERE message = 'first;value'; + END; + / + SELECT 1 FROM DUAL; + "#; + + let statements = parse_sql_statements(sql, "oracle").unwrap(); + assert_eq!(statements.len(), 2); + assert!(statements[0].starts_with("CREATE OR REPLACE PROCEDURE p_demo IS")); + assert!(statements[0].contains("VALUES ('first;value');")); + assert!(statements[0].contains("END;")); + assert_eq!(statements[1], "SELECT 1 FROM DUAL"); + } + + #[test] + fn parse_sql_statements_supports_oracle_case_end_inside_block() { + let sql = r#" + CREATE OR REPLACE FUNCTION f_demo RETURN VARCHAR2 IS + v_result VARCHAR2(10); + BEGIN + v_result := CASE WHEN 1 = 1 THEN 'yes' ELSE 'no' END; + RETURN v_result; + END; + / + "#; + + let statements = parse_sql_statements(sql, "oracle").unwrap(); + assert_eq!(statements.len(), 1); + assert!(statements[0].contains("CASE WHEN 1 = 1 THEN 'yes' ELSE 'no' END;")); + assert!(statements[0].ends_with("END;")); + } + #[test] fn parse_mssql_batches_splits_on_go_lines_only() { let sql = r#" @@ -1677,6 +2201,7 @@ mod tests { ("SELECT 1 FROM DUAL", "COMMIT", "ROLLBACK") ); assert!(import_transaction_sql("clickhouse", "clickhouse").is_err()); + assert!(import_transaction_sql("starrocks", "starrocks").is_err()); } #[test] @@ -1694,4 +2219,136 @@ mod tests { assert!(truncated.len() <= 503); assert!(truncated.ends_with("...")); } + + fn tmp_path(suffix: &str) -> PathBuf { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + std::env::temp_dir().join(format!("dbpaw-transfer-test-{unique}-{suffix}")) + } + + fn make_row(pairs: &[(&str, Value)]) -> Value { + let mut map = serde_json::Map::new(); + for (k, v) in pairs { + map.insert(k.to_string(), v.clone()); + } + Value::Object(map) + } + + #[test] + fn extension_for_format_sql_variants_all_return_sql() { + assert_eq!(extension_for_format(&ExportFormat::SqlDml), "sql"); + assert_eq!(extension_for_format(&ExportFormat::SqlDdl), "sql"); + assert_eq!(extension_for_format(&ExportFormat::SqlFull), "sql"); + assert_eq!(extension_for_format(&ExportFormat::Csv), "csv"); + assert_eq!(extension_for_format(&ExportFormat::Json), "json"); + } + + #[test] + fn export_writer_csv_writes_header_then_rows() { + let path = tmp_path("csv_header.csv"); + let cols = vec!["id".to_string(), "name".to_string()]; + let mut writer = ExportWriter::new(path.clone(), ExportFormat::Csv).unwrap(); + writer.write_csv_header(&cols).unwrap(); + let rows = vec![make_row(&[ + ("id", Value::Number(1.into())), + ("name", Value::String("alice".to_string())), + ])]; + writer + .write_rows(&rows, &cols, None, "t", "postgres") + .unwrap(); + writer.finish().unwrap(); + let content = fs::read_to_string(&path).unwrap(); + assert!(content.starts_with("id,name\n")); + assert!(content.contains("1,alice")); + let _ = fs::remove_file(path); + } + + #[test] + fn write_csv_header_is_noop_for_sql_formats() { + let path = tmp_path("sql_noop_header.sql"); + let mut writer = ExportWriter::new(path.clone(), ExportFormat::SqlDml).unwrap(); + writer.write_csv_header(&["id".to_string()]).unwrap(); + writer.finish().unwrap(); + let content = fs::read_to_string(&path).unwrap(); + assert_eq!(content, ""); + let _ = fs::remove_file(path); + } + + #[test] + fn export_writer_sql_dml_writes_insert_statements() { + let path = tmp_path("sql_dml.sql"); + let cols = vec!["id".to_string(), "name".to_string()]; + let mut writer = ExportWriter::new(path.clone(), ExportFormat::SqlDml).unwrap(); + let rows = vec![ + make_row(&[ + ("id", Value::Number(1.into())), + ("name", Value::String("alice".to_string())), + ]), + make_row(&[("id", Value::Number(2.into())), ("name", Value::Null)]), + ]; + let count = writer + .write_rows(&rows, &cols, Some("public"), "users", "postgres") + .unwrap(); + writer.finish().unwrap(); + assert_eq!(count, 2); + let content = fs::read_to_string(&path).unwrap(); + assert!(content.contains("INSERT INTO \"public\".\"users\"")); + assert!(content.contains("VALUES (1, 'alice')")); + assert!(content.contains("VALUES (2, NULL)")); + assert!(!content.contains("CREATE TABLE")); + let _ = fs::remove_file(path); + } + + #[test] + fn export_writer_sql_ddl_writes_only_ddl() { + let path = tmp_path("sql_ddl.sql"); + let mut writer = ExportWriter::new(path.clone(), ExportFormat::SqlDdl).unwrap(); + writer + .write_ddl("CREATE TABLE users (id INTEGER);") + .unwrap(); + writer.finish().unwrap(); + let content = fs::read_to_string(&path).unwrap(); + assert!(content.contains("CREATE TABLE users (id INTEGER);")); + assert!(!content.contains("INSERT INTO")); + let _ = fs::remove_file(path); + } + + #[test] + fn export_writer_sql_full_writes_ddl_then_inserts() { + let path = tmp_path("sql_full.sql"); + let cols = vec!["id".to_string(), "val".to_string()]; + let mut writer = ExportWriter::new(path.clone(), ExportFormat::SqlFull).unwrap(); + writer + .write_ddl("CREATE TABLE t (id INT, val TEXT);") + .unwrap(); + let rows = vec![make_row(&[ + ("id", Value::Number(1.into())), + ("val", Value::String("x".to_string())), + ])]; + let count = writer + .write_rows(&rows, &cols, None, "t", "postgres") + .unwrap(); + writer.finish().unwrap(); + assert_eq!(count, 1); + let content = fs::read_to_string(&path).unwrap(); + let ddl_pos = content.find("CREATE TABLE").unwrap(); + let dml_pos = content.find("INSERT INTO").unwrap(); + assert!(ddl_pos < dml_pos, "DDL should appear before DML"); + assert!(content.contains("VALUES (1, 'x')")); + let _ = fs::remove_file(path); + } + + #[test] + fn write_ddl_trims_trailing_whitespace_and_adds_blank_line() { + let path = tmp_path("ddl_trim.sql"); + let mut writer = ExportWriter::new(path.clone(), ExportFormat::SqlDdl).unwrap(); + writer.write_ddl("CREATE TABLE t (id INT); \n\n").unwrap(); + writer.finish().unwrap(); + let content = fs::read_to_string(&path).unwrap(); + assert!(content.starts_with("CREATE TABLE t (id INT);")); + assert!(content.ends_with("\n\n")); + let _ = fs::remove_file(path); + } } diff --git a/src-tauri/src/connection_input/mod.rs b/src-tauri/src/connection_input/mod.rs index b4dacbfd..3d91ac05 100644 --- a/src-tauri/src/connection_input/mod.rs +++ b/src-tauri/src/connection_input/mod.rs @@ -54,7 +54,7 @@ pub fn normalize_connection_form(mut form: ConnectionForm) -> Result bool { + matches!(driver, "mysql" | "mariadb" | "tidb" | "starrocks") +} + /// Build a `[CONN_FAILED]` error message with a context-aware hint derived from the /// underlying error text, so users are not misled by a generic credential warning /// when the actual problem is TLS incompatibility, a network issue, etc. @@ -26,7 +30,9 @@ pub(crate) fn conn_failed_error(e: &dyn std::fmt::Display) -> String { let raw = e.to_string(); let lower = raw.to_ascii_lowercase(); - let hint = if lower.contains("dpi-1047") || lower.contains("cannot locate a 64-bit oracle client") { + let hint = if lower.contains("dpi-1047") + || lower.contains("cannot locate a 64-bit oracle client") + { "hint: Oracle Instant Client is not installed — download it from \ https://www.oracle.com/database/technologies/instant-client/downloads.html \ and add the directory containing libclntsh to your library path \ @@ -143,7 +149,7 @@ pub async fn connect(form: &ConnectionForm) -> Result, S let driver = PostgresDriver::connect(form).await?; Ok(Box::new(driver) as Box) } - "mysql" | "tidb" | "mariadb" => { + driver if is_mysql_family_driver(driver) => { let driver = MysqlDriver::connect(form).await?; Ok(Box::new(driver) as Box) } @@ -176,7 +182,7 @@ pub async fn connect(form: &ConnectionForm) -> Result, S #[cfg(test)] mod tests { - use super::{conn_failed_error, strip_trailing_statement_terminator}; + use super::{conn_failed_error, is_mysql_family_driver, strip_trailing_statement_terminator}; #[test] fn conn_failed_error_oracle_client_hint() { @@ -249,4 +255,11 @@ mod tests { fn strip_trailing_statement_terminator_keeps_sql_without_semicolon() { assert_eq!(strip_trailing_statement_terminator("SELECT 1"), "SELECT 1"); } + + #[test] + fn mysql_family_helper_includes_starrocks() { + assert!(is_mysql_family_driver("mysql")); + assert!(is_mysql_family_driver("starrocks")); + assert!(!is_mysql_family_driver("postgres")); + } } diff --git a/src-tauri/src/db/drivers/mssql.rs b/src-tauri/src/db/drivers/mssql.rs index 70f6bef5..2b3a5884 100644 --- a/src-tauri/src/db/drivers/mssql.rs +++ b/src-tauri/src/db/drivers/mssql.rs @@ -734,7 +734,10 @@ impl DatabaseDriver for MssqlDriver { } let ddl = format!( - "CREATE TABLE {}.{} (\n{}\n);", + "-- Note: This DDL is reconstructed from table metadata and may be incomplete.\n\ + -- Constraints such as foreign keys, unique constraints, check constraints,\n\ + -- and indexes are not included.\n\ + CREATE TABLE {}.{} (\n{}\n);", quote_ident(&schema)?, quote_ident(&table)?, lines.join(",\n") diff --git a/src-tauri/src/db/drivers/mysql.rs b/src-tauri/src/db/drivers/mysql.rs index 4fa891a4..96b6ad8e 100644 --- a/src-tauri/src/db/drivers/mysql.rs +++ b/src-tauri/src/db/drivers/mysql.rs @@ -4,11 +4,18 @@ use crate::models::{ SchemaOverview, TableDataResponse, TableInfo, TableMetadata, TableSchema, TableStructure, }; use async_trait::async_trait; -use sqlx::{mysql::MySqlPoolOptions, Column, Executor, Row, TypeInfo}; +use sqlx::{ + mysql::{MySqlConnectOptions, MySqlPoolOptions}, + Column, Executor, Row, TypeInfo, +}; use std::collections::{HashMap, HashSet}; use std::fs; use std::path::Path; use std::path::PathBuf; +use std::str::FromStr; + +#[cfg(test)] +use sqlx::ConnectOptions; use crate::ssh::SshTunnel; @@ -53,7 +60,16 @@ fn build_verify_ca_query_param(ca_path: &Path) -> String { ) } +fn mysql_family_default_port(driver: &str) -> u16 { + if driver.eq_ignore_ascii_case("starrocks") { + 9030 + } else { + 3306 + } +} + fn normalize_mysql_host_and_port( + raw_driver: &str, raw_host: &str, raw_port: Option, ) -> Result<(String, u16), String> { @@ -62,7 +78,7 @@ fn normalize_mysql_host_and_port( return Err("[VALIDATION_ERROR] host cannot be empty".to_string()); } - let mut port = raw_port.unwrap_or(3306); + let mut port = raw_port.unwrap_or(i64::from(mysql_family_default_port(raw_driver))); if !host.starts_with('[') && host.matches(':').count() == 1 { if let Some((host_part, port_part)) = host.rsplit_once(':') { let host_part = host_part.trim(); @@ -91,7 +107,7 @@ fn build_dsn_and_ca_path(form: &ConnectionForm) -> Result<(String, Option Result<(String, Option Result { + let mut options = + MySqlConnectOptions::from_str(dsn).map_err(|e| format!("[CONN_FAILED] {e}"))?; + + if driver.eq_ignore_ascii_case("starrocks") { + // sqlx initializes MySQL connections with: + // SET sql_mode=(SELECT CONCAT(@@sql_mode, ...)) + // plus timezone / SET NAMES session mutations tailored for MySQL. + // StarRocks rejects part of this initialization sequence, so skip the + // post-connect SET mutations entirely for the StarRocks compatibility path. + options = options + .pipes_as_concat(false) + .no_engine_substitution(false) + .timezone(None::) + .set_names(false); + } + + Ok(options) +} + fn cleanup_ca_file(path: &Path) { let _ = fs::remove_file(path); } @@ -172,7 +208,6 @@ fn is_prepared_protocol_unsupported_error(err: &str) -> bool { || lower.contains("preparedoes not support") // PolarDB-X } - impl Drop for MysqlDriver { fn drop(&mut self) { cleanup_ca_file_opt(self.ca_cert_path.as_ref()); @@ -196,10 +231,11 @@ impl MysqlDriver { } let (dsn, ca_cert_path) = build_dsn_with_ca_path(&dsn_form)?; + let connect_options = build_connect_options(&dsn, &dsn_form.driver)?; let pool = MySqlPoolOptions::new() .max_connections(5) .acquire_timeout(std::time::Duration::from_secs(3)) - .connect(&dsn) + .connect_with(connect_options) .await .map_err(|e| super::conn_failed_error(&e))?; @@ -1147,6 +1183,37 @@ mod tests { ); } + #[test] + fn test_conn_string_uses_starrocks_default_port_when_port_missing() { + let form = ConnectionForm { + driver: "starrocks".to_string(), + host: Some("localhost".to_string()), + port: None, + username: Some("root".to_string()), + password: Some("password".to_string()), + database: Some("analytics".to_string()), + ..Default::default() + }; + + let conn_str = build_dsn(&form).unwrap(); + assert_eq!( + conn_str, + "mysql://root:password@localhost:9030/analytics?ssl-mode=DISABLED" + ); + } + + #[test] + fn test_starrocks_connect_options_disable_sql_mode_mutations() { + let options = build_connect_options( + "mysql://root:password@localhost:9030/analytics?ssl-mode=DISABLED", + "starrocks", + ) + .unwrap(); + + let rendered = options.to_url_lossy().to_string(); + assert!(rendered.contains("ssl-mode=DISABLED")); + } + #[test] fn test_conn_string_encodes_credentials() { let form = ConnectionForm { diff --git a/src-tauri/src/db/drivers/oracle.rs b/src-tauri/src/db/drivers/oracle.rs index a11d461e..a7638c76 100644 --- a/src-tauri/src/db/drivers/oracle.rs +++ b/src-tauri/src/db/drivers/oracle.rs @@ -186,9 +186,8 @@ impl OracleDriver { let cfg = self.config.clone(); tokio::task::spawn_blocking(move || { let connect_string = build_connect_string(&cfg); - let conn = - oracle::Connection::connect(&cfg.username, &cfg.password, &connect_string) - .map_err(|e| conn_failed_error(&e))?; + let conn = oracle::Connection::connect(&cfg.username, &cfg.password, &connect_string) + .map_err(|e| conn_failed_error(&e))?; f(conn) }) .await @@ -414,9 +413,10 @@ impl DatabaseDriver for OracleDriver { if let (Some(name), Some(col_name)) = (idx_name, col_name) { let unique = is_unique.unwrap_or(0) == 1; let pos = position.unwrap_or(0); - let entry = idx_map - .entry(name) - .or_insert((unique, idx_type.clone(), Vec::new())); + let entry = + idx_map + .entry(name) + .or_insert((unique, idx_type.clone(), Vec::new())); entry.0 = unique; if entry.1.is_none() { entry.1 = idx_type; @@ -704,7 +704,8 @@ impl DatabaseDriver for OracleDriver { .map_err(|e| format!("[QUERY_ERROR] {e}"))?; let row_count = stmt.row_count().unwrap_or(0) as i64; // Commit so the change is visible after the connection closes. - conn.commit().map_err(|e| format!("[QUERY_ERROR] commit failed: {e}"))?; + conn.commit() + .map_err(|e| format!("[QUERY_ERROR] commit failed: {e}"))?; Ok(QueryResult { row_count, data: vec![], @@ -751,10 +752,10 @@ impl DatabaseDriver for OracleDriver { if let (Some(sn), Some(tn), Some(cn), Some(ct)) = (schema_name, table_name, col_name, col_type) { - table_map - .entry((sn, tn)) - .or_default() - .push(ColumnSchema { name: cn, r#type: ct }); + table_map.entry((sn, tn)).or_default().push(ColumnSchema { + name: cn, + r#type: ct, + }); } } let mut tables: Vec = table_map diff --git a/src-tauri/src/db/drivers/postgres.rs b/src-tauri/src/db/drivers/postgres.rs index ad351af4..7814ff21 100644 --- a/src-tauri/src/db/drivers/postgres.rs +++ b/src-tauri/src/db/drivers/postgres.rs @@ -364,6 +364,39 @@ fn skip_backtick_quote(bytes: &[u8], mut i: usize) -> usize { i } +fn parse_dollar_quote_tag(bytes: &[u8], start: usize) -> Option { + if bytes.get(start) != Some(&b'$') { + return None; + } + let mut i = start + 1; + while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') { + i += 1; + } + if bytes.get(i) == Some(&b'$') { + Some(i) + } else { + None + } +} + +fn skip_dollar_quote(bytes: &[u8], start: usize) -> usize { + let Some(tag_end) = parse_dollar_quote_tag(bytes, start) else { + return start + 1; + }; + let tag = &bytes[start..=tag_end]; + let tag_len = tag.len(); + let mut i = tag_end + 1; + + while i + tag_len <= bytes.len() { + if &bytes[i..i + tag_len] == tag { + return i + tag_len; + } + i += 1; + } + + bytes.len() +} + fn skip_line_comment(bytes: &[u8], mut i: usize) -> usize { i += 2; while i < bytes.len() && bytes[i] != b'\n' { @@ -431,6 +464,13 @@ fn split_sql_statements(sql: &str) -> Vec { i = skip_backtick_quote(bytes, i); continue; } + if b == b'$' { + let next = skip_dollar_quote(bytes, i); + if next != i + 1 { + i = next; + continue; + } + } if b == b'(' { depth += 1; i += 1; @@ -1489,6 +1529,39 @@ CREATE TABLE pg_data_type_test ( assert!(statements[1].starts_with("CREATE TABLE pg_data_type_test")); } + #[test] + fn test_split_sql_statements_keeps_postgres_dollar_quoted_function_intact() { + let sql = r#" +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER LANGUAGE PLPGSQL AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$; +"#; + let statements = split_sql_statements(sql); + assert_eq!(statements.len(), 1); + assert!(statements[0].contains("NEW.updated_at = CURRENT_TIMESTAMP;")); + assert!(statements[0].ends_with("$$")); + } + + #[test] + fn test_split_sql_statements_keeps_tagged_dollar_quoted_function_intact() { + let sql = r#" +CREATE FUNCTION demo() +RETURNS text LANGUAGE plpgsql AS $body$ +BEGIN + RETURN 'ok'; +END; +$body$; +"#; + let statements = split_sql_statements(sql); + assert_eq!(statements.len(), 1); + assert!(statements[0].contains("RETURN 'ok';")); + assert!(statements[0].ends_with("$body$")); + } + #[test] fn test_is_high_precision_pg_type() { assert!(is_high_precision_pg_type("bigint", "int8")); diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index d4fb21f2..e23ab94e 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -22,6 +22,18 @@ pub fn run() { .on_menu_event(|app, event| { if event.id() == "settings" { let _ = app.emit("open-settings", ()); + } else if event.id() == "debug_reload" { + if let Some(window) = app.get_webview_window("main") { + let _ = window.reload(); + } + } else if event.id() == "debug_toggle_devtools" { + if let Some(window) = app.get_webview_window("main") { + if window.is_devtools_open() { + window.close_devtools(); + } else { + window.open_devtools(); + } + } } }) .manage(AppState::new()) @@ -35,6 +47,7 @@ pub fn run() { if let Err(e) = (|| -> tauri::Result<()> { let app_menu = Submenu::new(&handle, "App", true)?; let edit_menu = Submenu::new(&handle, "Edit", true)?; + let developer_menu = Submenu::new(&handle, "Developer", true)?; let about = PredefinedMenuItem::about(&handle, None, None)?; let settings = MenuItem::with_id( @@ -69,6 +82,20 @@ pub fn run() { let copy = PredefinedMenuItem::copy(&handle, None)?; let paste = PredefinedMenuItem::paste(&handle, None)?; let select_all = PredefinedMenuItem::select_all(&handle, None)?; + let reload = MenuItem::with_id( + &handle, + "debug_reload", + "Reload", + true, + Some("CmdOrCtrl+R"), + )?; + let toggle_devtools = MenuItem::with_id( + &handle, + "debug_toggle_devtools", + "Toggle DevTools", + true, + Some("Alt+CmdOrCtrl+I"), + )?; edit_menu.append(&undo)?; edit_menu.append(&redo)?; @@ -78,7 +105,11 @@ pub fn run() { edit_menu.append(&paste)?; edit_menu.append(&select_all)?; - let menu = Menu::with_items(&handle, &[&app_menu, &edit_menu])?; + developer_menu.append(&reload)?; + developer_menu.append(&toggle_devtools)?; + + let menu = + Menu::with_items(&handle, &[&app_menu, &edit_menu, &developer_menu])?; app.set_menu(menu)?; Ok(()) })() { @@ -125,6 +156,8 @@ pub fn run() { commands::connection::list_databases, commands::connection::list_databases_by_id, commands::connection::create_database_by_id, + commands::connection::get_mysql_charsets_by_id, + commands::connection::get_mysql_collations_by_id, commands::storage::save_query, commands::storage::get_saved_queries, commands::storage::update_saved_query, diff --git a/src-tauri/src/ssh.rs b/src-tauri/src/ssh.rs index 75f7bcb8..459f28d2 100644 --- a/src-tauri/src/ssh.rs +++ b/src-tauri/src/ssh.rs @@ -6,6 +6,20 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; +fn default_target_port(driver: &str) -> i64 { + if crate::db::drivers::is_mysql_family_driver(driver) { + return if driver == "starrocks" { 9030 } else { 3306 }; + } + + match driver { + "mssql" => 1433, + "oracle" => 1521, + "clickhouse" => 9000, + "sqlite" => 0, + _ => 5432, // postgres and unknown drivers + } +} + pub struct SshTunnel { pub local_port: u16, _guard: Arc, @@ -45,14 +59,8 @@ pub fn start_ssh_tunnel(config: &ConnectionForm) -> Result { .and_then(|v| if v.trim().is_empty() { None } else { Some(v) }); let target_host = config.host.clone().unwrap_or("localhost".to_string()); - let default_port: i64 = match config.driver.to_ascii_lowercase().as_str() { - "mysql" => 3306, - "mssql" => 1433, - "oracle" => 1521, - "clickhouse" => 9000, - "sqlite" => 0, - _ => 5432, // postgres and unknown drivers - }; + let normalized_driver = config.driver.to_ascii_lowercase(); + let default_port = default_target_port(&normalized_driver); let target_port = config.port.unwrap_or(default_port); if target_port < 1 || target_port > 65535 { return Err("Target port must be between 1 and 65535".to_string()); @@ -300,6 +308,33 @@ mod tests { "MSSQL default port (1433) should pass validation, got: {e}" ); } + + let config_starrocks = ConnectionForm { + driver: "starrocks".to_string(), + host: Some("127.0.0.1".to_string()), + port: None, // should default to 9030 + ssh_host: Some("127.0.0.1".to_string()), + ssh_port: Some(22), + ssh_username: Some("user".to_string()), + ssh_password: Some("pass".to_string()), + ..Default::default() + }; + let result = start_ssh_tunnel(&config_starrocks); + if let Err(e) = result { + assert!( + !e.contains("Target port must be between 1 and 65535"), + "StarRocks default port (9030) should pass validation, got: {e}" + ); + } + } + + #[test] + fn test_default_target_port_by_driver() { + assert_eq!(default_target_port("mysql"), 3306); + assert_eq!(default_target_port("mariadb"), 3306); + assert_eq!(default_target_port("tidb"), 3306); + assert_eq!(default_target_port("starrocks"), 9030); + assert_eq!(default_target_port("clickhouse"), 9000); } #[test] diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 869d8ad0..90fb4142 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "DbPaw", - "version": "0.3.2", + "version": "0.3.3", "identifier": "com.father.dbpaw", "build": { "beforeDevCommand": "bun run dev", diff --git a/src-tauri/tests/common/shared.rs b/src-tauri/tests/common/shared.rs index 6e114d31..12299b36 100644 --- a/src-tauri/tests/common/shared.rs +++ b/src-tauri/tests/common/shared.rs @@ -5,7 +5,9 @@ use std::process::Command; use std::thread::sleep; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +#[allow(dead_code)] const CONNECT_RETRY_ATTEMPTS: usize = 20; +#[allow(dead_code)] const CONNECT_RETRY_DELAY_MS: u64 = 500; pub fn should_reuse_local_db() -> bool { @@ -88,6 +90,7 @@ pub fn env_i64_any(names: &[&str], default: i64) -> i64 { .unwrap_or(default) } +#[allow(dead_code)] pub async fn connect_with_retry(mut connect: F) -> T where F: FnMut() -> Fut, diff --git a/src-tauri/tests/common/starrocks_context.rs b/src-tauri/tests/common/starrocks_context.rs new file mode 100644 index 00000000..2e66a90e --- /dev/null +++ b/src-tauri/tests/common/starrocks_context.rs @@ -0,0 +1,77 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::time::Duration; +use testcontainers::clients::Cli; +use testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +#[allow(unused_imports)] +pub use shared::{connect_with_retry, should_reuse_local_db}; + +pub fn starrocks_form_from_test_context<'a>( + docker: Option<&'a Cli>, +) -> (Option>, ConnectionForm) { + if should_reuse_local_db() { + return (None, starrocks_form_from_local_env()); + } + shared::ensure_docker_available(); + + let docker = docker.expect("docker client is required when IT_REUSE_LOCAL_DB is not enabled"); + let image = GenericImage::new("starrocks/allin1-ubuntu", "3.2.10") + .with_wait_for(WaitFor::seconds(20)) + .with_exposed_port(9030); + let runnable = + RunnableImage::from(image).with_container_name(shared::unique_container_name("starrocks")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(9030); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(120)); + + let mut form = ConnectionForm { + driver: "starrocks".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("root".to_string()), + password: Some(String::new()), + ..Default::default() + }; + apply_starrocks_env_overrides(&mut form); + (Some(container), form) +} + +fn starrocks_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: "starrocks".to_string(), + host: Some(shared::env_or("STARROCKS_HOST", "localhost")), + port: Some(shared::env_i64("STARROCKS_PORT", 9030)), + username: Some(shared::env_or("STARROCKS_USER", "root")), + password: Some(shared::env_or("STARROCKS_PASSWORD", "")), + database: env::var("STARROCKS_DB").ok(), + ..Default::default() + }; + apply_starrocks_env_overrides(&mut form); + form +} + +fn apply_starrocks_env_overrides(form: &mut ConnectionForm) { + if let Ok(host) = env::var("STARROCKS_HOST") { + form.host = Some(host); + } + if let Ok(port) = env::var("STARROCKS_PORT") { + form.port = Some( + port.parse::() + .expect("STARROCKS_PORT should be a valid number"), + ); + } + if let Ok(user) = env::var("STARROCKS_USER") { + form.username = Some(user); + } + if let Ok(password) = env::var("STARROCKS_PASSWORD") { + form.password = Some(password); + } + if let Ok(database) = env::var("STARROCKS_DB") { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/mariadb_command_integration.rs b/src-tauri/tests/mariadb_command_integration.rs index bba658a9..55c959db 100644 --- a/src-tauri/tests/mariadb_command_integration.rs +++ b/src-tauri/tests/mariadb_command_integration.rs @@ -335,3 +335,93 @@ async fn test_mariadb_command_get_table_data_by_conn_invalid_pagination_returns_ cleanup_table(&form, &table).await; } + +#[tokio::test] +#[ignore] +async fn test_mariadb_show_character_set_returns_standard_charsets() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_mariadb_container, form) = + mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + wait_until_mariadb_ready(&form).await; + + let driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mariadb driver"); + + let result = driver + .execute_query("SHOW CHARACTER SET".to_string()) + .await + .expect("SHOW CHARACTER SET should succeed"); + + let charsets: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Charset") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + + assert!(!charsets.is_empty(), "charset list must not be empty"); + assert!( + charsets.iter().any(|c| c == "utf8mb4"), + "utf8mb4 must be present" + ); + assert!( + charsets.iter().any(|c| c == "latin1"), + "latin1 must be present" + ); + assert!( + charsets.iter().all(|c| !c.trim().is_empty()), + "all charset names must be non-empty" + ); + + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_show_collation_for_utf8mb4_returns_matching_collations() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_mariadb_container, form) = + mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + wait_until_mariadb_ready(&form).await; + + let driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mariadb driver"); + + let result = driver + .execute_query("SHOW COLLATION WHERE Charset = 'utf8mb4'".to_string()) + .await + .expect("SHOW COLLATION should succeed"); + + let collations: Vec = result + .data + .iter() + .filter_map(|row| { + row.get("Collation") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .collect(); + + assert!( + !collations.is_empty(), + "utf8mb4 collation list must not be empty" + ); + assert!( + collations.iter().any(|c| c == "utf8mb4_general_ci"), + "utf8mb4_general_ci must be present" + ); + for col in &collations { + assert!( + col.starts_with("utf8mb4"), + "collation '{}' does not belong to utf8mb4", + col + ); + } + + driver.close().await; +} diff --git a/src-tauri/tests/mysql_stateful_command_integration.rs b/src-tauri/tests/mysql_stateful_command_integration.rs index fcbdb796..19333688 100644 --- a/src-tauri/tests/mysql_stateful_command_integration.rs +++ b/src-tauri/tests/mysql_stateful_command_integration.rs @@ -740,6 +740,114 @@ async fn test_mysql_command_transfer_export_and_import_minimal_flow() { let _ = connection::delete_connection_direct(&state, conn_id).await; } +#[tokio::test] +#[ignore] +async fn test_mysql_command_import_sql_file_supports_delimiter_script() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "import-delimiter").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + + let proc_table = unique_name("dbpaw_import_proc_tbl"); + let audit_table = unique_name("dbpaw_import_audit_tbl"); + let proc_name = unique_name("dbpaw_import_proc"); + let trigger_name = unique_name("dbpaw_import_trg"); + let base = std::env::temp_dir().join(unique_name("dbpaw_mysql_import_it")); + fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let import_sql_path = base.join("import.sql"); + + let import_sql = format!( + r#" +CREATE TABLE `{proc_table}` (id INT PRIMARY KEY, name VARCHAR(64)); +CREATE TABLE `{audit_table}` (entry_id INT PRIMARY KEY AUTO_INCREMENT, source_id INT, action_name VARCHAR(32)); +DELIMITER $$ +CREATE PROCEDURE `{proc_name}`() +BEGIN + INSERT INTO `{proc_table}` (id, name) VALUES (1, 'from_proc'); +END$$ +CREATE TRIGGER `{trigger_name}` AFTER INSERT ON `{proc_table}` +FOR EACH ROW +BEGIN + INSERT INTO `{audit_table}` (source_id, action_name) VALUES (NEW.id, 'insert'); +END$$ +DELIMITER ; +"# + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some(schema.clone()), + import_sql_path.to_string_lossy().to_string(), + "mysql".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert_eq!( + import_result.success_statements, + import_result.total_statements + ); + assert!(import_result.error.is_none()); + + let driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mysql driver for verification"); + driver + .execute_query(format!("CALL `{}`()", proc_name)) + .await + .expect("call imported procedure should succeed"); + driver + .execute_query(format!( + "INSERT INTO `{}` (id, name) VALUES (2, 'direct')", + proc_table + )) + .await + .expect("direct insert into imported mysql table should succeed"); + + let verify = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM `{}`.`{}`", + schema, audit_table + )) + .await + .expect("verify mysql trigger should succeed"); + let count = verify.data[0]["c"] + .as_str() + .and_then(|v| v.parse::().ok()) + .expect("audit count should parse"); + assert_eq!(count, 2); + + let _ = driver + .execute_query(format!("DROP TRIGGER IF EXISTS `{}`", trigger_name)) + .await; + let _ = driver + .execute_query(format!("DROP PROCEDURE IF EXISTS `{}`", proc_name)) + .await; + let _ = driver + .execute_query(format!( + "DROP TABLE IF EXISTS `{}`.`{}`", + schema, proc_table + )) + .await; + let _ = driver + .execute_query(format!( + "DROP TABLE IF EXISTS `{}`.`{}`", + schema, audit_table + )) + .await; + driver.close().await; + + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + #[tokio::test] #[ignore] async fn test_mysql_command_ai_minimal_provider_conversation_and_chat_flow() { @@ -856,3 +964,139 @@ async fn test_mysql_command_ai_minimal_provider_conversation_and_chat_flow() { .await .expect("ai_delete_provider should succeed"); } + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_charsets_by_id_returns_standard_charsets() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "get-charsets").await; + + let charsets = connection::get_mysql_charsets_by_id_direct(&state, conn_id) + .await + .expect("get_mysql_charsets_by_id should succeed"); + + assert!(!charsets.is_empty(), "charset list must not be empty"); + assert!( + charsets.iter().any(|c| c == "utf8mb4"), + "utf8mb4 must be present" + ); + assert!(charsets.iter().any(|c| c == "utf8"), "utf8 must be present"); + assert!( + charsets.iter().any(|c| c == "latin1"), + "latin1 must be present" + ); + assert!( + charsets.windows(2).all(|w| w[0] <= w[1]), + "charsets must be sorted" + ); + assert!( + charsets.iter().all(|c| !c.trim().is_empty()), + "all charset names must be non-empty" + ); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_charsets_by_id_invalid_connection_returns_error() { + let state = init_state_with_local_db().await; + let result = connection::get_mysql_charsets_by_id_direct(&state, -999_999).await; + assert!(result.is_err()); + assert!(!result.err().unwrap_or_default().trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_collations_by_id_without_charset_returns_all() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "get-collations-all").await; + + let collations = connection::get_mysql_collations_by_id_direct(&state, conn_id, None) + .await + .expect("get_mysql_collations_by_id should succeed without charset filter"); + + assert!(!collations.is_empty(), "collation list must not be empty"); + assert!( + collations.iter().any(|c| c == "utf8mb4_general_ci"), + "utf8mb4_general_ci must be present" + ); + assert!( + collations.iter().any(|c| c == "utf8_general_ci"), + "utf8_general_ci must be present" + ); + assert!( + collations.windows(2).all(|w| w[0] <= w[1]), + "collations must be sorted" + ); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_collations_by_id_with_charset_returns_only_matching() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "get-collations-filtered").await; + + let collations = + connection::get_mysql_collations_by_id_direct(&state, conn_id, Some("utf8mb4".to_string())) + .await + .expect("get_mysql_collations_by_id should succeed with charset filter"); + + assert!( + !collations.is_empty(), + "utf8mb4 collation list must not be empty" + ); + assert!( + collations.iter().any(|c| c == "utf8mb4_general_ci"), + "utf8mb4_general_ci must be present" + ); + // All returned collations must start with the requested charset prefix + for col in &collations { + assert!( + col.starts_with("utf8mb4"), + "collation '{}' does not belong to utf8mb4", + col + ); + } + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_collations_by_id_with_invalid_charset_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = + create_mysql_connection_for_state(&state, &form, "get-collations-invalid-cs").await; + + let result = connection::get_mysql_collations_by_id_direct( + &state, + conn_id, + Some("utf8 mb4; DROP TABLE users".to_string()), + ) + .await; + + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!( + err.contains("[VALIDATION_ERROR]"), + "expected VALIDATION_ERROR, got: {}", + err + ); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} diff --git a/src-tauri/tests/oracle_command_integration.rs b/src-tauri/tests/oracle_command_integration.rs index 8ac73961..0dc57d06 100644 --- a/src-tauri/tests/oracle_command_integration.rs +++ b/src-tauri/tests/oracle_command_integration.rs @@ -1,9 +1,13 @@ #[path = "common/oracle_context.rs"] mod oracle_context; -use dbpaw_lib::commands::{connection, metadata, query}; +use dbpaw_lib::commands::{connection, metadata, query, transfer}; use dbpaw_lib::db::drivers::oracle::OracleDriver; use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::db::local::LocalDb; +use dbpaw_lib::state::AppState; +use std::fs; +use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; fn unique_table_name(prefix: &str) -> String { @@ -14,6 +18,31 @@ fn unique_table_name(prefix: &str) -> String { format!("{prefix}_{ms}") } +async fn init_state_with_local_db() -> AppState { + let state = AppState::new(); + let local_db_dir = std::env::temp_dir().join(unique_table_name("dbpaw_oracle_localdb_it")); + let db = LocalDb::init_with_app_dir(&local_db_dir) + .await + .expect("failed to initialize local db"); + let mut lock = state.local_db.lock().await; + *lock = Some(Arc::new(db)); + drop(lock); + state +} + +async fn create_oracle_connection_for_state( + state: &AppState, + base_form: &dbpaw_lib::models::ConnectionForm, + suffix: &str, +) -> i64 { + let mut form = base_form.clone(); + form.name = Some(format!("oracle-stateful-{suffix}")); + let created = connection::create_connection_direct(state, form) + .await + .expect("create_connection should succeed"); + created.id +} + async fn prepare_test_table(schema: &str, table: &str, form: &dbpaw_lib::models::ConnectionForm) { let driver = OracleDriver::connect(form) .await @@ -187,9 +216,7 @@ async fn test_oracle_command_execute_insert_affects_rows() { .expect("CREATE TABLE"); driver.close().await; - let sql = format!( - "INSERT INTO \"{schema}\".\"{table}\" (id, name) VALUES (1, 'alpha')" - ); + let sql = format!("INSERT INTO \"{schema}\".\"{table}\" (id, name) VALUES (1, 'alpha')"); let result = query::execute_by_conn_direct(form.clone(), sql) .await .expect("INSERT should succeed"); @@ -237,14 +264,12 @@ async fn test_oracle_command_get_table_data_pagination_works() { } driver.close().await; - let page1 = - query::get_table_data_by_conn(form.clone(), schema.clone(), table.clone(), 1, 2) - .await - .expect("page 1 should succeed"); - let page2 = - query::get_table_data_by_conn(form.clone(), schema.clone(), table.clone(), 2, 2) - .await - .expect("page 2 should succeed"); + let page1 = query::get_table_data_by_conn(form.clone(), schema.clone(), table.clone(), 1, 2) + .await + .expect("page 1 should succeed"); + let page2 = query::get_table_data_by_conn(form.clone(), schema.clone(), table.clone(), 2, 2) + .await + .expect("page 2 should succeed"); assert_eq!(page1.total, 3); assert_eq!(page1.limit, 2); @@ -254,6 +279,85 @@ async fn test_oracle_command_get_table_data_pagination_works() { cleanup_table(&schema, &table, &form).await; } +#[tokio::test] +#[ignore] +async fn test_oracle_command_import_sql_file_supports_create_or_replace_script() { + let form = oracle_context::oracle_form_from_test_context(); + let schema = form + .schema + .clone() + .expect("ORACLE_SCHEMA must be set") + .to_uppercase(); + let state = init_state_with_local_db().await; + let conn_id = create_oracle_connection_for_state(&state, &form, "import-plsql").await; + + let table = unique_table_name("DBPAW_IMPORT_ORA_TBL").to_uppercase(); + let proc_name = unique_table_name("DBPAW_IMPORT_ORA_PROC").to_uppercase(); + let base = std::env::temp_dir().join(unique_table_name("dbpaw_oracle_import_it")); + fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let import_sql_path = base.join("import.sql"); + + let import_sql = format!( + r#" +CREATE TABLE "{schema}"."{table}" (id NUMBER(10) PRIMARY KEY, name VARCHAR2(64)); +CREATE OR REPLACE PROCEDURE "{schema}"."{proc_name}" IS +BEGIN + INSERT INTO "{schema}"."{table}" (id, name) VALUES (1, 'from_proc'); +END; +/ +"# + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some(schema.clone()), + import_sql_path.to_string_lossy().to_string(), + "oracle".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert_eq!( + import_result.success_statements, + import_result.total_statements + ); + assert!(import_result.error.is_none()); + + let driver = OracleDriver::connect(&form) + .await + .expect("connect for oracle import verification"); + driver + .execute_query(format!(r#"BEGIN "{schema}"."{proc_name}"(); END;"#)) + .await + .expect("calling imported oracle procedure should succeed"); + let verify = driver + .execute_query(format!(r#"SELECT COUNT(*) AS C FROM "{schema}"."{table}""#)) + .await + .expect("verify oracle imported procedure should succeed"); + let count = verify.data[0]["C"] + .as_i64() + .or_else(|| { + verify.data[0]["C"] + .as_str() + .and_then(|v| v.parse::().ok()) + }) + .expect("oracle count should be numeric"); + assert_eq!(count, 1); + + let _ = driver + .execute_query(format!(r#"DROP PROCEDURE "{schema}"."{proc_name}""#)) + .await; + let _ = driver + .execute_query(format!(r#"DROP TABLE "{schema}"."{table}""#)) + .await; + driver.close().await; + + let _ = connection::delete_connection_direct(&state, conn_id).await; + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); +} + #[tokio::test] #[ignore] async fn test_oracle_command_get_table_data_invalid_pagination_returns_error() { diff --git a/src-tauri/tests/oracle_integration.rs b/src-tauri/tests/oracle_integration.rs index 02c216e4..bb84b256 100644 --- a/src-tauri/tests/oracle_integration.rs +++ b/src-tauri/tests/oracle_integration.rs @@ -96,7 +96,10 @@ async fn test_oracle_integration_flow() { "structure should have columns" ); assert!( - structure.columns.iter().any(|c| c.name == "ID" && c.primary_key), + structure + .columns + .iter() + .any(|c| c.name == "ID" && c.primary_key), "ID column should be marked as primary key" ); assert!( @@ -126,16 +129,7 @@ async fn test_oracle_integration_flow() { // get_table_data let result = driver - .get_table_data( - schema.clone(), - table.clone(), - 1, - 10, - None, - None, - None, - None, - ) + .get_table_data(schema.clone(), table.clone(), 1, 10, None, None, None, None) .await .expect("get_table_data should succeed"); assert_eq!(result.total, 1, "total should be 1"); @@ -148,9 +142,7 @@ async fn test_oracle_integration_flow() { // execute_query SELECT let qr = driver - .execute_query(format!( - "SELECT id, name FROM \"{schema}\".\"{table}\"" - )) + .execute_query(format!("SELECT id, name FROM \"{schema}\".\"{table}\"")) .await .expect("execute_query SELECT should succeed"); assert!(qr.success); @@ -250,5 +242,8 @@ async fn test_oracle_integration_connection_failure() { let result = OracleDriver::connect(&form).await; assert!(result.is_err(), "wrong password should fail"); let err = result.err().expect("should have an error"); - assert!(err.contains("[CONN_FAILED]"), "error should be tagged CONN_FAILED"); + assert!( + err.contains("[CONN_FAILED]"), + "error should be tagged CONN_FAILED" + ); } diff --git a/src-tauri/tests/postgres_integration.rs b/src-tauri/tests/postgres_integration.rs index 025581bf..88561143 100644 --- a/src-tauri/tests/postgres_integration.rs +++ b/src-tauri/tests/postgres_integration.rs @@ -872,8 +872,14 @@ async fn test_postgres_array_types_decoded_as_json_arrays() { let floats8 = r1["floats8"].as_array().expect("floats8 should be array"); assert_eq!(floats8.len(), 2); - assert!(floats8[0].as_f64().map(|v| (v - 3.14).abs() < 0.01).unwrap_or(false), - "floats8[0] should be ~3.14, got {:?}", floats8[0]); + assert!( + floats8[0] + .as_f64() + .map(|v| (v - 3.14).abs() < 0.01) + .unwrap_or(false), + "floats8[0] should be ~3.14, got {:?}", + floats8[0] + ); let texts = r1["texts"].as_array().expect("texts should be array"); assert_eq!(texts.len(), 2); @@ -894,22 +900,42 @@ async fn test_postgres_array_types_decoded_as_json_arrays() { let r2 = &result.data[1]; let ints2_null = r2["ints2"].as_array().expect("ints2 row2 should be array"); - assert_eq!(ints2_null[0], serde_json::Value::Null, "first element should be NULL"); + assert_eq!( + ints2_null[0], + serde_json::Value::Null, + "first element should be NULL" + ); assert_eq!(ints2_null[1].as_i64().unwrap_or(-1), 5); let ints4_null = r2["ints4"].as_array().expect("ints4 row2 should be array"); - assert_eq!(ints4_null[0], serde_json::Value::Null, "first int4 element should be NULL"); + assert_eq!( + ints4_null[0], + serde_json::Value::Null, + "first int4 element should be NULL" + ); let texts_null = r2["texts"].as_array().expect("texts row2 should be array"); assert_eq!(texts_null[0].as_str().unwrap_or(""), "x"); - assert_eq!(texts_null[1], serde_json::Value::Null, "middle text element should be NULL"); + assert_eq!( + texts_null[1], + serde_json::Value::Null, + "middle text element should be NULL" + ); assert_eq!(texts_null[2].as_str().unwrap_or(""), "z"); let bools_null = r2["bools"].as_array().expect("bools row2 should be array"); - assert_eq!(bools_null[0], serde_json::Value::Null, "bool element should be NULL"); + assert_eq!( + bools_null[0], + serde_json::Value::Null, + "bool element should be NULL" + ); // column-level NULL (entire array is NULL) - assert_eq!(r2["ints8"], serde_json::Value::Null, "whole ints8 column should be NULL"); + assert_eq!( + r2["ints8"], + serde_json::Value::Null, + "whole ints8 column should be NULL" + ); // ---- row 3: empty arrays ---- let r3 = &result.data[2]; diff --git a/src-tauri/tests/postgres_stateful_command_integration.rs b/src-tauri/tests/postgres_stateful_command_integration.rs index fe1cc684..51385564 100644 --- a/src-tauri/tests/postgres_stateful_command_integration.rs +++ b/src-tauri/tests/postgres_stateful_command_integration.rs @@ -734,6 +734,110 @@ async fn test_postgres_command_transfer_export_and_import_minimal_flow() { let _ = connection::delete_connection_direct(&state, conn_id).await; } +#[tokio::test] +#[ignore] +async fn test_postgres_command_import_sql_file_supports_function_trigger_script() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "import-trigger").await; + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let schema = "public".to_string(); + let table = unique_name("dbpaw_import_pg_tbl"); + let func_name = unique_name("dbpaw_import_pg_fn"); + let trigger_name = unique_name("dbpaw_import_pg_trg"); + let qualified = format!("\"{}\".\"{}\"", schema, table); + let base = std::env::temp_dir().join(unique_name("dbpaw_postgres_import_it")); + fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let import_sql_path = base.join("import.sql"); + + let import_sql = format!( + r#" +CREATE TABLE {qualified} ( + id INT PRIMARY KEY, + name TEXT, + touch_count INT DEFAULT 0 +); +CREATE OR REPLACE FUNCTION "{schema}"."{func_name}"() +RETURNS TRIGGER LANGUAGE plpgsql AS $$ +BEGIN + NEW.touch_count = COALESCE(NEW.touch_count, 0) + 1; + RETURN NEW; +END; +$$; +CREATE TRIGGER "{trigger_name}" +BEFORE INSERT OR UPDATE ON {qualified} +FOR EACH ROW +EXECUTE FUNCTION "{schema}"."{func_name}"(); +"# + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some(database.clone()), + import_sql_path.to_string_lossy().to_string(), + "postgres".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert_eq!( + import_result.success_statements, + import_result.total_statements + ); + assert!(import_result.error.is_none()); + + let driver = PostgresDriver::connect(&form) + .await + .expect("failed to connect postgres driver for verification"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'alpha')", + qualified + )) + .await + .expect("insert into imported postgres table should succeed"); + driver + .execute_query(format!( + "UPDATE {} SET name = 'beta' WHERE id = 1", + qualified + )) + .await + .expect("update imported postgres table should succeed"); + let verify = driver + .execute_query(format!( + "SELECT touch_count AS c FROM {} WHERE id = 1", + qualified + )) + .await + .expect("verify postgres trigger should succeed"); + let count = verify.data[0]["c"] + .as_str() + .and_then(|v| v.parse::().ok()) + .expect("touch_count should parse"); + assert_eq!(count, 2); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + let _ = driver + .execute_query(format!( + "DROP FUNCTION IF EXISTS \"{}\".\"{}\"()", + schema, func_name + )) + .await; + driver.close().await; + + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + #[tokio::test] #[ignore] async fn test_postgres_command_ai_minimal_provider_conversation_and_chat_flow() { diff --git a/src-tauri/tests/sqlite_command_integration.rs b/src-tauri/tests/sqlite_command_integration.rs index 8f031a47..9b9d5a29 100644 --- a/src-tauri/tests/sqlite_command_integration.rs +++ b/src-tauri/tests/sqlite_command_integration.rs @@ -1,9 +1,13 @@ -use dbpaw_lib::commands::{connection, metadata, query}; +use dbpaw_lib::commands::{connection, metadata, query, transfer}; use dbpaw_lib::db::drivers::sqlite::SqliteDriver; use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::db::local::LocalDb; use dbpaw_lib::models::ConnectionForm; +use dbpaw_lib::state::AppState; use std::env; +use std::fs; use std::path::PathBuf; +use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; @@ -24,6 +28,31 @@ fn unique_table_name(prefix: &str) -> String { format!("{}_{}", prefix, millis) } +async fn init_state_with_local_db() -> AppState { + let state = AppState::new(); + let local_db_dir = std::env::temp_dir().join(unique_table_name("dbpaw_sqlite_localdb_it")); + let db = LocalDb::init_with_app_dir(&local_db_dir) + .await + .expect("failed to initialize local db"); + let mut lock = state.local_db.lock().await; + *lock = Some(Arc::new(db)); + drop(lock); + state +} + +async fn create_sqlite_connection_for_state( + state: &AppState, + base_form: &ConnectionForm, + suffix: &str, +) -> i64 { + let mut form = base_form.clone(); + form.name = Some(format!("sqlite-stateful-{suffix}")); + let created = connection::create_connection_direct(state, form) + .await + .expect("create_connection should succeed"); + created.id +} + async fn prepare_query_test_table(form: &ConnectionForm, table: &str) { let driver = SqliteDriver::connect(form) .await @@ -258,6 +287,92 @@ async fn test_sqlite_command_execute_by_conn_insert_affects_rows() { let _ = std::fs::remove_file(db_path); } +#[tokio::test] +#[ignore] +async fn test_sqlite_command_import_sql_file_supports_trigger_script() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str.clone()), + ..Default::default() + }; + + let state = init_state_with_local_db().await; + let conn_id = create_sqlite_connection_for_state(&state, &form, "import-trigger").await; + let base = std::env::temp_dir().join(unique_table_name("dbpaw_sqlite_import_it")); + fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let import_sql_path = base.join("import.sql"); + + let source_table = unique_table_name("dbpaw_sqlite_src"); + let audit_table = unique_table_name("dbpaw_sqlite_audit"); + let trigger_name = unique_table_name("dbpaw_sqlite_trg"); + let import_sql = format!( + r#" +CREATE TABLE {source_table} (id INTEGER PRIMARY KEY, name TEXT); +CREATE TABLE {audit_table} (entry_id INTEGER PRIMARY KEY AUTOINCREMENT, source_id INTEGER, action TEXT); +CREATE TRIGGER {trigger_name} +AFTER INSERT ON {source_table} +BEGIN + INSERT INTO {audit_table}(source_id, action) VALUES (NEW.id, 'insert'); +END; +"# + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some("main".to_string()), + import_sql_path.to_string_lossy().to_string(), + "sqlite".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert_eq!( + import_result.success_statements, + import_result.total_statements + ); + assert!(import_result.error.is_none()); + + let driver = SqliteDriver::connect(&form) + .await + .expect("failed to connect sqlite driver for verification"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'alpha')", + source_table + )) + .await + .expect("insert into imported sqlite table should succeed"); + let verify = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", audit_table)) + .await + .expect("verify sqlite trigger should succeed"); + let count = verify.data[0]["c"] + .as_i64() + .or_else(|| { + verify.data[0]["c"] + .as_str() + .and_then(|v| v.parse::().ok()) + }) + .expect("audit count should be numeric"); + assert_eq!(count, 1); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", source_table)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", audit_table)) + .await; + driver.close().await; + + let _ = connection::delete_connection_direct(&state, conn_id).await; + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); + let _ = fs::remove_file(db_path); +} + #[tokio::test] #[ignore] async fn test_sqlite_command_get_table_data_by_conn_pagination_works() { diff --git a/src-tauri/tests/starrocks_command_integration.rs b/src-tauri/tests/starrocks_command_integration.rs new file mode 100644 index 00000000..2b0a9215 --- /dev/null +++ b/src-tauri/tests/starrocks_command_integration.rs @@ -0,0 +1,103 @@ +#[path = "common/starrocks_context.rs"] +mod starrocks_context; + +use dbpaw_lib::commands::connection::{self, CreateDatabasePayload}; +use dbpaw_lib::db::drivers::mysql::MysqlDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::db::local::LocalDb; +use dbpaw_lib::models::ConnectionForm; +use dbpaw_lib::state::AppState; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; +use tokio::time::{sleep, Duration}; + +fn unique_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + format!("{}_{}", prefix, millis) +} + +async fn wait_until_starrocks_ready(form: &ConnectionForm) { + let mut last_error = String::new(); + for _ in 0..90 { + match connection::test_connection_ephemeral(form.clone()).await { + Ok(_) => return, + Err(err) => { + last_error = err; + sleep(Duration::from_secs(1)).await; + } + } + } + panic!("starrocks is not ready for command tests: {last_error}"); +} + +async fn init_state_with_local_db() -> AppState { + let state = AppState::new(); + let local_db_dir = std::env::temp_dir().join(unique_name("dbpaw_starrocks_stateful_it")); + let db = LocalDb::init_with_app_dir(&local_db_dir) + .await + .expect("failed to initialize local db"); + let mut lock = state.local_db.lock().await; + *lock = Some(Arc::new(db)); + drop(lock); + state +} + +async fn create_starrocks_connection_for_state( + state: &AppState, + base_form: &ConnectionForm, +) -> i64 { + let mut form = base_form.clone(); + form.name = Some(unique_name("starrocks-command")); + let created = connection::create_connection_direct(state, form) + .await + .expect("create_connection should succeed"); + created.id +} + +async fn drop_database_if_exists(form: &ConnectionForm, db_name: &str) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect starrocks driver for cleanup"); + let _ = driver + .execute_query(format!("DROP DATABASE IF EXISTS `{}`", db_name)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_starrocks_command_create_database_by_id_success() { + let docker = (!starrocks_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = starrocks_context::starrocks_form_from_test_context(docker.as_ref()); + wait_until_starrocks_ready(&form).await; + + let state = init_state_with_local_db().await; + let conn_id = create_starrocks_connection_for_state(&state, &form).await; + + let db_name = unique_name("dbpaw_starrocks_cmd_db"); + let payload = CreateDatabasePayload { + name: db_name.clone(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + + connection::create_database_by_id_direct(&state, conn_id, payload) + .await + .expect("create_database_by_id should succeed"); + + let dbs = connection::list_databases_by_id_direct(&state, conn_id) + .await + .expect("list_databases_by_id should succeed"); + assert!(dbs.iter().any(|d| d == &db_name)); + + drop_database_if_exists(&form, &db_name).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} diff --git a/src-tauri/tests/starrocks_integration.rs b/src-tauri/tests/starrocks_integration.rs new file mode 100644 index 00000000..e01dd64a --- /dev/null +++ b/src-tauri/tests/starrocks_integration.rs @@ -0,0 +1,120 @@ +#[path = "common/starrocks_context.rs"] +mod starrocks_context; + +use dbpaw_lib::db::drivers::mysql::MysqlDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; + +fn unique_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + format!("{}_{}", prefix, millis) +} + +#[tokio::test] +#[ignore] +async fn test_starrocks_integration_flow() { + let docker = (!starrocks_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = starrocks_context::starrocks_form_from_test_context(docker.as_ref()); + let driver: MysqlDriver = + starrocks_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + driver + .test_connection() + .await + .expect("test_connection failed"); + + let dbs = driver + .list_databases() + .await + .expect("list_databases failed"); + assert!(!dbs.is_empty(), "list_databases returned empty"); + + let db_name = unique_name("dbpaw_starrocks_it"); + let table_name = "events"; + let qualified = format!("`{}`.`{}`", db_name, table_name); + + driver + .execute_query(format!("CREATE DATABASE IF NOT EXISTS `{}`", db_name)) + .await + .expect("create database failed"); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + + driver + .execute_query(format!("CREATE TABLE {} (id INT, name STRING)", qualified)) + .await + .expect("create table failed"); + + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'hello')", + qualified + )) + .await + .expect("insert failed"); + + let tables = driver + .list_tables(Some(db_name.clone())) + .await + .expect("list_tables failed"); + assert!( + tables.iter().any(|t| t.name == table_name), + "list_tables should include {}", + table_name + ); + + let metadata = driver + .get_table_metadata(db_name.clone(), table_name.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata.columns.iter().any(|c| c.name == "name"), + "metadata should include name column" + ); + + let ddl = driver + .get_table_ddl(db_name.clone(), table_name.to_string()) + .await + .expect("get_table_ddl failed"); + assert!( + ddl.to_uppercase().contains("CREATE TABLE"), + "DDL should contain CREATE TABLE" + ); + + let result = driver + .execute_query(format!("SELECT id, name FROM {} WHERE id = 1", qualified)) + .await + .expect("select failed"); + assert_eq!(result.row_count, 1); + assert_eq!( + result.data[0]["name"], + serde_json::Value::String("hello".to_string()) + ); + + let table_data = driver + .get_table_data( + db_name.clone(), + table_name.to_string(), + 1, + 100, + None, + None, + None, + None, + ) + .await + .expect("get_table_data failed"); + assert_eq!(table_data.total, 1); + assert_eq!(table_data.data.len(), 1); + + let _ = driver + .execute_query(format!("DROP DATABASE IF EXISTS `{}`", db_name)) + .await; + driver.close().await; +} diff --git a/src/App.tsx b/src/App.tsx index f3b30334..f2f5bc00 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -17,10 +17,12 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Sidebar } from "@/components/business/Sidebar/Sidebar"; import { SaveQueryDialog } from "@/components/business/Editor/SaveQueryDialog"; import { TableView } from "@/components/business/DataGrid/TableView"; +import { ErrorBoundary } from "@/components/ErrorBoundary"; import { TableMetadataView } from "@/components/business/Metadata/TableMetadataView"; import { SqlExecutionLogsDropdown } from "@/components/business/SqlLogs/SqlExecutionLogsDialog"; import { FileCode, Table, X, Settings, Sparkles } from "lucide-react"; import { Button } from "@/components/ui/button"; +import { isMysqlFamilyDriver } from "@/lib/driver-registry"; import { ContextMenu, ContextMenuContent, @@ -186,10 +188,7 @@ export default function App() { schemaOverride?: string, ) => { const isDatabaseScoped = - driver === "mysql" || - driver === "tidb" || - driver === "mariadb" || - driver === "clickhouse"; + (driver && isMysqlFamilyDriver(driver as any)) || driver === "clickhouse"; const normalizedSchemaOverride = (schemaOverride || "").trim(); return { schema: isDatabaseScoped @@ -235,6 +234,7 @@ export default function App() { const [isUnsavedConfirmOpen, setIsUnsavedConfirmOpen] = useState(false); const [isCloseSaveDialogOpen, setIsCloseSaveDialogOpen] = useState(false); const [sidebarLayout, setSidebarLayout] = useState("tabs"); + const [showColumnComments, setShowColumnComments] = useState(false); const closeSaveCompletedRef = useRef(false); const unsavedConfirmActionRef = useRef<"save" | "discard" | null>(null); const schemaOverviewRequestKeysRef = useRef>(new Map()); @@ -269,6 +269,7 @@ export default function App() { setSidebarLayout(layout === "tree" ? "tree" : "tabs"); }, ); + void getSetting("showColumnComments", false).then(setShowColumnComments); }, []); const sensors = useSensors( @@ -824,7 +825,7 @@ export default function App() { table: string; driver: string; }, - format: "csv" | "json" | "sql", + format: "csv" | "json" | "sql_dml" | "sql_ddl" | "sql_full", filePath: string, ) => { try { @@ -1616,125 +1617,131 @@ export default function App() { value={tab.id} className="h-full m-0" > - {tab.type === "editor" ? ( - - } - > - handleExecuteQuery(tab.id, sql)} - onCancel={() => - tab.connectionId && tab.activeQueryId - ? api.query.cancel( - String(tab.connectionId), - tab.activeQueryId, - ) - : Promise.resolve(false) + + {tab.type === "editor" ? ( + } - isExecuting={!!tab.activeQueryId} - queryResults={tab.queryResults} - value={tab.sqlContent} - onChange={(sql) => handleSqlChange(tab.id, sql)} - onDatabaseChange={(database) => - void handleEditorDatabaseChange(tab.id, database) + > + + handleExecuteQuery(tab.id, sql) + } + onCancel={() => + tab.connectionId && tab.activeQueryId + ? api.query.cancel( + String(tab.connectionId), + tab.activeQueryId, + ) + : Promise.resolve(false) + } + isExecuting={!!tab.activeQueryId} + queryResults={tab.queryResults} + value={tab.sqlContent} + onChange={(sql) => handleSqlChange(tab.id, sql)} + onDatabaseChange={(database) => + void handleEditorDatabaseChange( + tab.id, + database, + ) + } + connectionId={tab.connectionId} + driver={tab.driver} + schemaOverview={tab.schemaOverview} + savedQueryId={tab.savedQueryId} + initialName={ + isDefaultQueryTitle(tab.title) ? "" : tab.title + } + initialDescription={tab.savedQueryDescription} + onSaveSuccess={(savedQuery) => { + setQueriesLastUpdated(Date.now()); + setTabs((prev) => + prev.map((t) => { + if (t.id === tab.id) { + return { + ...t, + savedQueryId: savedQuery.id, + title: savedQuery.name, + savedQueryDescription: + savedQuery.description || undefined, + sqlContent: savedQuery.query, + lastSavedSql: savedQuery.query, + isDirty: false, + }; + } + return t; + }), + ); + }} + /> + + ) : tab.type === "table" ? ( + handlePageChange(tab.id, p)} + onPageSizeChange={(size) => + handlePageSizeChange(tab.id, size) } - connectionId={tab.connectionId} - driver={tab.driver} - schemaOverview={tab.schemaOverview} - savedQueryId={tab.savedQueryId} - initialName={ - isDefaultQueryTitle(tab.title) ? "" : tab.title + sortColumn={tab.sortColumn} + sortDirection={tab.sortDirection} + onSortChange={(col, dir) => + handleSortChange(tab.id, col, dir) + } + filter={tab.filter} + orderBy={tab.orderBy} + onFilterChange={(f, ob) => + handleFilterChange(tab.id, f, ob) + } + onOpenDDL={handleOpenTableDDL} + onDataRefresh={(params) => + handleTableRefresh(tab.id, params) } - initialDescription={tab.savedQueryDescription} - onSaveSuccess={(savedQuery) => { - setQueriesLastUpdated(Date.now()); - setTabs((prev) => - prev.map((t) => { - if (t.id === tab.id) { - return { - ...t, - savedQueryId: savedQuery.id, - title: savedQuery.name, - savedQueryDescription: - savedQuery.description || undefined, - sqlContent: savedQuery.query, - lastSavedSql: savedQuery.query, - isDirty: false, - }; + onCreateQuery={handleCreateQuery} + tableContext={ + tab.connectionId && + tab.database && + tab.tableName && + tab.driver + ? { + connectionId: tab.connectionId, + database: tab.database, + schema: + isMysqlFamilyDriver(tab.driver as any) || + tab.driver === "clickhouse" + ? tab.database + : tab.driver === "mssql" + ? "dbo" + : tab.driver === "duckdb" + ? "main" + : "public", + table: tab.tableName, + driver: tab.driver, } - return t; - }), - ); - }} + : undefined + } + showColumnComments={showColumnComments} + /> + ) : tab.connectionId && + tab.database && + tab.schema && + tab.tableName ? ( + - - ) : tab.type === "table" ? ( - handlePageChange(tab.id, p)} - onPageSizeChange={(size) => - handlePageSizeChange(tab.id, size) - } - sortColumn={tab.sortColumn} - sortDirection={tab.sortDirection} - onSortChange={(col, dir) => - handleSortChange(tab.id, col, dir) - } - filter={tab.filter} - orderBy={tab.orderBy} - onFilterChange={(f, ob) => - handleFilterChange(tab.id, f, ob) - } - onOpenDDL={handleOpenTableDDL} - onDataRefresh={(params) => - handleTableRefresh(tab.id, params) - } - onCreateQuery={handleCreateQuery} - tableContext={ - tab.connectionId && - tab.database && - tab.tableName && - tab.driver - ? { - connectionId: tab.connectionId, - database: tab.database, - schema: - tab.driver === "mysql" || - tab.driver === "tidb" || - tab.driver === "mariadb" || - tab.driver === "clickhouse" - ? tab.database - : tab.driver === "mssql" - ? "dbo" - : tab.driver === "duckdb" - ? "main" - : "public", - table: tab.tableName, - driver: tab.driver, - } - : undefined - } - /> - ) : tab.connectionId && - tab.database && - tab.schema && - tab.tableName ? ( - - ) : null} + ) : null} + )) )} @@ -1817,6 +1824,8 @@ export default function App() { onOpenChange={setOpenSettings} sidebarLayout={sidebarLayout} onSidebarLayoutChange={setSidebarLayout} + showColumnComments={showColumnComments} + onShowColumnCommentsChange={setShowColumnComments} /> )} diff --git a/src/components/ErrorBoundary.tsx b/src/components/ErrorBoundary.tsx new file mode 100644 index 00000000..960a3c11 --- /dev/null +++ b/src/components/ErrorBoundary.tsx @@ -0,0 +1,47 @@ +import { Component } from "react"; +import type { ReactNode, ErrorInfo } from "react"; + +interface Props { + children: ReactNode; + /** Custom fallback rendered instead of the default error UI. */ + fallback?: ReactNode; +} + +interface State { + error: Error | null; +} + +export class ErrorBoundary extends Component { + state: State = { error: null }; + + static getDerivedStateFromError(error: Error): State { + return { error }; + } + + componentDidCatch(error: Error, info: ErrorInfo) { + console.error( + "[ErrorBoundary] Uncaught render error:", + error, + info.componentStack, + ); + } + + render() { + const { error } = this.state; + if (!error) return this.props.children; + if (this.props.fallback) return this.props.fallback; + + return ( +
+
+

+ Something went wrong +

+
+            {error.message}
+          
+
+
+ ); + } +} diff --git a/src/components/business/DataGrid/ComplexValueViewer.tsx b/src/components/business/DataGrid/ComplexValueViewer.tsx index 708daf60..49456256 100644 --- a/src/components/business/DataGrid/ComplexValueViewer.tsx +++ b/src/components/business/DataGrid/ComplexValueViewer.tsx @@ -25,7 +25,8 @@ function TreeNode({ depth?: number; }) { const [expanded, setExpanded] = useState(depth < 2); - const isComplex = value !== null && value !== undefined && typeof value === "object"; + const isComplex = + value !== null && value !== undefined && typeof value === "object"; const isArr = Array.isArray(value); if (!isComplex) { @@ -38,7 +39,9 @@ function TreeNode({ className="flex items-baseline py-[2px] text-xs font-mono leading-5" style={{ paddingLeft: depth * 14 + 18 }} > - {label} + + {label} + : 0 && - arr.every((item) => item !== null && typeof item === "object" && !Array.isArray(item)); + arr.every( + (item) => + item !== null && typeof item === "object" && !Array.isArray(item), + ); const keys = allObjects ? Array.from(new Set(arr.flatMap((item) => Object.keys(item as object)))) @@ -110,21 +116,31 @@ function TableView({ value }: { value: unknown }) { {keys ? ( keys.map((k) => ( - + {k} )) ) : ( <> - # - value + + # + + + value + )} {arr.map((row, i) => ( - + {keys ? ( keys.map((k) => { const v = (row as Record)[k]; @@ -140,7 +156,9 @@ function TableView({ value }: { value: unknown }) { ) : ( <> {i} - + {cellText(row)} @@ -157,15 +175,26 @@ function TableView({ value }: { value: unknown }) { - - + + {Object.entries(value as Record).map(([k, v]) => ( - - - + + @@ -211,7 +240,9 @@ export function ComplexValueViewer({ {/* Header */}
- {columnName} + + {columnName} + {typeLabel} diff --git a/src/components/business/DataGrid/TableView.tsx b/src/components/business/DataGrid/TableView.tsx index cddbf226..1e05a301 100644 --- a/src/components/business/DataGrid/TableView.tsx +++ b/src/components/business/DataGrid/TableView.tsx @@ -76,6 +76,7 @@ import { calculateAutoColumnWidths, canMutateClickHouseTable, collectSearchMatches, + createSingleAndDoubleClickHandler, escapeSQL, cellValueToString, formatCellValue, @@ -147,6 +148,7 @@ interface TableViewProps { driver: string; }; isLoading?: boolean; + showColumnComments?: boolean; } export function TableView({ @@ -170,6 +172,7 @@ export function TableView({ onCreateQuery, tableContext, isLoading, + showColumnComments = false, }: TableViewProps) { const { t } = useTranslation(); const PAGE_SIZE_OPTIONS = ["10", "50", "100", "200", "500", "1000"] as const; @@ -178,6 +181,11 @@ export function TableView({ const [pageInput, setPageInput] = useState(String(page)); const [pageSizeInput, setPageSizeInput] = useState(String(pageSize)); const [columnWidths, setColumnWidths] = useState>({}); + const columnWidthsRef = useRef>({}); + columnWidthsRef.current = columnWidths; + const headerClickStateRef = useRef< + Record | null }> + >({}); // Reset column widths when columns definition changes (e.g. switching tables) const prevColumnsRef = useRef(""); @@ -189,19 +197,20 @@ export function TableView({ } }, [columns]); - // Auto-calculate column widths based on content + // Auto-calculate column widths based on content. + // Read columnWidths via ref to avoid re-triggering the effect on every width update. useEffect(() => { const newWidths = calculateAutoColumnWidths({ data, columns, - columnWidths, + columnWidths: columnWidthsRef.current, }); const hasChanges = Object.keys(newWidths).length > 0; if (hasChanges) { setColumnWidths((prev) => ({ ...prev, ...newWidths })); } - }, [data, columns, columnWidths]); + }, [data, columns]); useEffect(() => { setWhereInput(controlledFilter || ""); @@ -568,7 +577,9 @@ export function TableView({ // Check if there's a pending change for this cell const key = `${rowIndex}_${col}`; const pending = pendingChanges.get(key); - const value = pending ? pending.newValue : cellValueToString(currentValue); + const value = pending + ? pending.newValue + : cellValueToString(currentValue); setEditingCell({ row: rowIndex, col }); setEditValue(value); setSelectedCell({ row: rowIndex, col }); @@ -646,6 +657,26 @@ export function TableView({ }); }, []); + const handleHeaderCopy = useCallback( + (column: string) => { + void navigator.clipboard + .writeText(column) + .then(() => { + toast.success( + t("tableView.toast.columnNameCopied", { + column, + }), + ); + }) + .catch((error) => { + toast.error("Failed to copy", { + description: error instanceof Error ? error.message : String(error), + }); + }); + }, + [t], + ); + const selectSingleRow = useCallback((rowIndex: number) => { const nextSelectedRows = new Set([rowIndex]); selectedRowsRef.current = nextSelectedRows; @@ -1355,6 +1386,18 @@ export function TableView({ document.body.style.cursor = "col-resize"; }; + useEffect(() => { + const clickStates = headerClickStateRef.current; + return () => { + Object.values(clickStates).forEach((state) => { + if (state.timerId) { + clearTimeout(state.timerId); + state.timerId = null; + } + }); + }; + }, []); + useEffect(() => { return () => { document.removeEventListener("mousemove", handleMouseMove); @@ -1672,22 +1715,37 @@ export function TableView({
{tableContext && onCreateQuery && ( - + <> + + + + )} {(canInsert || canUpdateDelete) && ( <> @@ -1695,20 +1753,19 @@ export function TableView({ )} {canUpdateDelete && ( )} @@ -1789,7 +1845,9 @@ export function TableView({ JSON void handleExport("current_page", "sql")} + onClick={() => + void handleExport("current_page", "sql_dml") + } > SQL @@ -1811,7 +1869,7 @@ export function TableView({ JSON void handleExport("filtered", "sql")} + onClick={() => void handleExport("filtered", "sql_dml")} > SQL @@ -1833,7 +1891,9 @@ export function TableView({ JSON void handleExport("full_table", "sql")} + onClick={() => + void handleExport("full_table", "sql_dml") + } > SQL @@ -1841,18 +1901,6 @@ export function TableView({ - - {tableContext && ( - - )}
@@ -1932,6 +1980,17 @@ export function TableView({ const direction = isSorted ? activeSortDirection : undefined; const comment = columnComments[column]?.trim(); const headerTooltip = comment || column; + const headerActionLabel = t("tableView.header.actionHint", { + column, + }); + const headerClickState = + headerClickStateRef.current[column] ?? + (headerClickStateRef.current[column] = { timerId: null }); + const headerInteraction = createSingleAndDoubleClickHandler( + headerClickState, + () => handleHeaderCopy(column), + () => handleSortClick(column), + ); return ( {currentData.map((row, rowIndex) => { + if (!row || typeof row !== "object") return null; const isEditing = (col: string) => editingCell?.row === rowIndex && editingCell?.col === col; const isSelected = (col: string) => @@ -2087,7 +2157,11 @@ export function TableView({ {displayValue !== null && displayValue !== undefined ? ( {formatCellValue(displayValue)} @@ -2103,10 +2177,22 @@ export function TableView({ onMouseDown={(e) => e.stopPropagation()} onClick={(e) => { e.stopPropagation(); - setComplexViewer({ value: displayValue, columnName: column }); + setComplexViewer({ + value: displayValue, + columnName: column, + }); }} > - + @@ -2309,7 +2395,9 @@ export function TableView({ value={complexViewer.value} columnName={complexViewer.columnName} open={true} - onOpenChange={(open) => { if (!open) setComplexViewer(null); }} + onOpenChange={(open) => { + if (!open) setComplexViewer(null); + }} /> )} diff --git a/src/components/business/DataGrid/tableView/utils.ts b/src/components/business/DataGrid/tableView/utils.ts index 6e781f6b..3ed607f1 100644 --- a/src/components/business/DataGrid/tableView/utils.ts +++ b/src/components/business/DataGrid/tableView/utils.ts @@ -1,3 +1,5 @@ +import { isMysqlFamilyDriver } from "@/lib/driver-registry"; + export interface SearchMatch { row: number; col: string; @@ -12,6 +14,42 @@ export interface InsertColumnMeta { primaryKey?: boolean; } +export interface HeaderInteractionState { + timerId: ReturnType | null; +} + +export function createSingleAndDoubleClickHandler( + state: HeaderInteractionState, + onSingleClick: () => void, + onDoubleClick: () => void, + delayMs = 250, +) { + return { + handleClick() { + if (state.timerId) { + clearTimeout(state.timerId); + } + state.timerId = setTimeout(() => { + state.timerId = null; + onSingleClick(); + }, delayMs); + }, + handleDoubleClick() { + if (state.timerId) { + clearTimeout(state.timerId); + state.timerId = null; + } + onDoubleClick(); + }, + cancelPendingClick() { + if (state.timerId) { + clearTimeout(state.timerId); + state.timerId = null; + } + }, + }; +} + export function isInsertColumnRequired( column: Pick, ): boolean { @@ -138,9 +176,7 @@ export function escapeSQL(value: string): string { export function quoteIdent(driver: string | undefined, name: string): string { if ( - driver === "mysql" || - driver === "tidb" || - driver === "mariadb" || + (driver && isMysqlFamilyDriver(driver as any)) || driver === "clickhouse" ) { return `\`${name}\``; @@ -246,7 +282,7 @@ export function getQualifiedTableName( schema: string, table: string, ): string { - if (driver === "mysql" || driver === "tidb" || driver === "mariadb") { + if (isMysqlFamilyDriver(driver as any)) { return quoteIdent(driver, table); } diff --git a/src/components/business/DataGrid/tableView/utils.unit.test.ts b/src/components/business/DataGrid/tableView/utils.unit.test.ts index d2184a45..86449a3e 100644 --- a/src/components/business/DataGrid/tableView/utils.unit.test.ts +++ b/src/components/business/DataGrid/tableView/utils.unit.test.ts @@ -5,6 +5,7 @@ import { calculateAutoColumnWidths, canMutateClickHouseTable, collectSearchMatches, + createSingleAndDoubleClickHandler, escapeSQL, formatCellValue, formatInsertSQLValue, @@ -30,6 +31,7 @@ describe("formatSQLValue", () => { expect(formatSQLValue("false", true, "execution", "mysql")).toBe("FALSE"); expect(formatSQLValue("true", true, "execution", "tidb")).toBe("TRUE"); expect(formatSQLValue("false", true, "execution", "mariadb")).toBe("FALSE"); + expect(formatSQLValue("true", true, "execution", "starrocks")).toBe("TRUE"); }); test("throws for invalid boolean in execution mode", () => { @@ -131,6 +133,12 @@ describe("getQualifiedTableName", () => { ); }); + test("uses unqualified table with backticks for starrocks", () => { + expect(getQualifiedTableName("starrocks", "analytics", "events")).toBe( + "`events`", + ); + }); + test("does not qualify sqlite main/public schema", () => { expect(getQualifiedTableName("sqlite", "main", "users")).toBe('"users"'); expect(getQualifiedTableName("sqlite", "public", "users")).toBe('"users"'); @@ -208,6 +216,7 @@ describe("quoteIdent", () => { expect(quoteIdent("mysql", "my_table")).toBe("`my_table`"); expect(quoteIdent("tidb", "my_table")).toBe("`my_table`"); expect(quoteIdent("mariadb", "my_table")).toBe("`my_table`"); + expect(quoteIdent("starrocks", "my_table")).toBe("`my_table`"); expect(quoteIdent("clickhouse", "my_table")).toBe("`my_table`"); }); @@ -263,6 +272,60 @@ describe("sortRows", () => { }); }); +describe("createSingleAndDoubleClickHandler", () => { + test("runs single-click action after delay", async () => { + const calls: string[] = []; + const state = { timerId: null }; + const handler = createSingleAndDoubleClickHandler( + state, + () => calls.push("copy"), + () => calls.push("sort"), + 10, + ); + + handler.handleClick(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(calls).toEqual(["copy"]); + expect(state.timerId).toBeNull(); + }); + + test("double click cancels pending single-click action", () => { + const calls: string[] = []; + const state = { timerId: null }; + const handler = createSingleAndDoubleClickHandler( + state, + () => calls.push("copy"), + () => calls.push("sort"), + 20, + ); + + handler.handleClick(); + handler.handleDoubleClick(); + + expect(calls).toEqual(["sort"]); + expect(state.timerId).toBeNull(); + }); + + test("cancelPendingClick clears pending single-click action", async () => { + const calls: string[] = []; + const state = { timerId: null }; + const handler = createSingleAndDoubleClickHandler( + state, + () => calls.push("copy"), + () => calls.push("sort"), + 10, + ); + + handler.handleClick(); + handler.cancelPendingClick(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(calls).toEqual([]); + expect(state.timerId).toBeNull(); + }); +}); + describe("collectSearchMatches", () => { const data = [ { id: 1, name: "Alice" }, @@ -292,15 +355,28 @@ describe("collectSearchMatches", () => { test("skips null and undefined cell values", () => { const withNulls = [{ id: null, name: undefined }]; - const matches = collectSearchMatches(withNulls, ["id", "name"], "null", identity); + const matches = collectSearchMatches( + withNulls, + ["id", "name"], + "null", + identity, + ); expect(matches).toEqual([]); }); }); describe("calculateAutoColumnWidths", () => { test("returns empty object for empty data or columns", () => { - expect(calculateAutoColumnWidths({ data: [], columns: ["a"], columnWidths: {} })).toEqual({}); - expect(calculateAutoColumnWidths({ data: [{ a: 1 }], columns: [], columnWidths: {} })).toEqual({}); + expect( + calculateAutoColumnWidths({ data: [], columns: ["a"], columnWidths: {} }), + ).toEqual({}); + expect( + calculateAutoColumnWidths({ + data: [{ a: 1 }], + columns: [], + columnWidths: {}, + }), + ).toEqual({}); }); test("skips columns with a pre-set width", () => { @@ -415,7 +491,12 @@ describe("formatCellValue", () => { }); test("object with many keys → shows first 2 keys and remainder count", () => { - const result = formatCellValue({ id: 1, name: "x", role: "admin", score: 99 }); + const result = formatCellValue({ + id: 1, + name: "x", + role: "admin", + score: 99, + }); expect(result).toMatch(/^\{id, name, \.\.\. \+2\}$/); }); @@ -438,17 +519,19 @@ describe("formatCellValue: integration with collectSearchMatches", () => { { id: 2, meta: { role: "user", tags: [] } }, ]; const identity = (_row: number, _col: string, val: any) => val; - const matches = collectSearchMatches(data, ["id", "meta"], "admin", identity); + const matches = collectSearchMatches( + data, + ["id", "meta"], + "admin", + identity, + ); expect(matches.length).toBe(1); expect(matches[0].row).toBe(0); expect(matches[0].col).toBe("meta"); }); test("array fields are searchable by content", () => { - const data = [ - { tags: ["read", "write"] }, - { tags: ["read"] }, - ]; + const data = [{ tags: ["read", "write"] }, { tags: ["read"] }]; const identity = (_row: number, _col: string, val: any) => val; const matches = collectSearchMatches(data, ["tags"], "write", identity); expect(matches.length).toBe(1); @@ -480,7 +563,9 @@ describe("formatCellValue: PostgreSQL array column output", () => { }); test("text array displays as compact JSON string array", () => { - expect(formatCellValue(["postgres", "arrays"])).toBe('["postgres","arrays"]'); + expect(formatCellValue(["postgres", "arrays"])).toBe( + '["postgres","arrays"]', + ); }); test("bool array displays as compact JSON", () => { @@ -492,7 +577,10 @@ describe("formatCellValue: PostgreSQL array column output", () => { }); test("jsonb array (array of objects) displays as full JSON", () => { - const val = [{ source: "web", valid: true }, { source: "app", valid: false }]; + const val = [ + { source: "web", valid: true }, + { source: "app", valid: false }, + ]; expect(formatCellValue(val)).toBe(JSON.stringify(val)); }); @@ -537,7 +625,12 @@ describe("collectSearchMatches: PostgreSQL array columns are searchable", () => const identity = (_row: number, _col: string, val: any) => val; test("finds match inside text array content", () => { - const matches = collectSearchMatches(data, ["id", "tags"], "jsonb", identity); + const matches = collectSearchMatches( + data, + ["id", "tags"], + "jsonb", + identity, + ); expect(matches.length).toBe(1); expect(matches[0].row).toBe(0); expect(matches[0].col).toBe("tags"); diff --git a/src/components/business/Editor/SqlEditor.tsx b/src/components/business/Editor/SqlEditor.tsx index d7c06a2b..aaaa9202 100644 --- a/src/components/business/Editor/SqlEditor.tsx +++ b/src/components/business/Editor/SqlEditor.tsx @@ -372,6 +372,7 @@ export function SqlEditor({ mysql: "mysql", tidb: "mysql", mariadb: "mysql", + starrocks: "mysql", sqlite: "sqlite", duckdb: "sqlite", clickhouse: "sql", @@ -517,6 +518,7 @@ export function SqlEditor({ case "mysql": case "tidb": case "mariadb": + case "starrocks": return MySQL; case "sqlite": case "duckdb": @@ -872,7 +874,7 @@ export function SqlEditor({ JSON void handleExportResult("sql")} + onClick={() => void handleExportResult("sql_dml")} > SQL diff --git a/src/components/business/Sidebar/ConnectionList.tsx b/src/components/business/Sidebar/ConnectionList.tsx index c0756bfc..64755cc3 100644 --- a/src/components/business/Sidebar/ConnectionList.tsx +++ b/src/components/business/Sidebar/ConnectionList.tsx @@ -166,7 +166,6 @@ const defaultForm: ConnectionForm = { sshUsername: "", }; - const defaultCreateDatabaseForm: CreateDatabaseForm = { name: "", ifNotExists: true, @@ -178,14 +177,50 @@ const defaultCreateDatabaseForm: CreateDatabaseForm = { }; const createDbNoneOption = "__none__"; -const mysqlCharsetOptions = ["utf8mb4", "utf8", "latin1"]; -const mysqlCollationOptions = [ - "utf8mb4_general_ci", - "utf8mb4_unicode_ci", - "utf8_general_ci", - "latin1_swedish_ci", +const postgresEncodingOptions = [ + "UTF8", + "SQL_ASCII", + "BIG5", + "EUC_CN", + "EUC_JP", + "EUC_JIS_2004", + "EUC_KR", + "EUC_TW", + "GB18030", + "GBK", + "ISO_8859_5", + "ISO_8859_6", + "ISO_8859_7", + "ISO_8859_8", + "JOHAB", + "KOI8R", + "KOI8U", + "LATIN1", + "LATIN2", + "LATIN3", + "LATIN4", + "LATIN5", + "LATIN6", + "LATIN7", + "LATIN8", + "LATIN9", + "LATIN10", + "MULE_INTERNAL", + "SHIFT_JIS_2004", + "SJIS", + "UHC", + "WIN866", + "WIN874", + "WIN1250", + "WIN1251", + "WIN1252", + "WIN1253", + "WIN1254", + "WIN1255", + "WIN1256", + "WIN1257", + "WIN1258", ]; -const postgresEncodingOptions = ["UTF8", "LATIN1", "SQL_ASCII"]; const postgresLocaleOptions = [ "en_US.UTF-8", "C", @@ -195,9 +230,60 @@ const postgresLocaleOptions = [ ]; const mssqlCollationOptions = [ "SQL_Latin1_General_CP1_CI_AS", + "SQL_Latin1_General_CP1_CS_AS", + "SQL_Latin1_General_CP1_CI_AI", + "SQL_Latin1_General_CP1_CS_AI", + "Latin1_General_CI_AS", + "Latin1_General_CS_AS", + "Latin1_General_BIN", + "Latin1_General_BIN2", + "Latin1_General_100_CI_AS", + "Latin1_General_100_CS_AS", + "Latin1_General_100_CI_AI", + "Latin1_General_100_BIN2", "Latin1_General_100_CI_AS_SC", + "Latin1_General_100_CS_AS_SC", + "Latin1_General_100_CI_AI_SC", + "Latin1_General_100_BIN2_UTF8", + "Latin1_General_100_CI_AS_SC_UTF8", + "Latin1_General_100_CI_AI_SC_UTF8", + "SQL_Latin1_General_CP850_CI_AS", + "Modern_Spanish_CI_AS", + "Modern_Spanish_100_CI_AS", + "French_CI_AS", + "French_100_CI_AS", + "German_PhoneBook_CI_AS", + "German_PhoneBook_100_CI_AS", + "Turkish_CI_AS", + "Turkish_100_CI_AS", + "Cyrillic_General_CI_AS", + "Cyrillic_General_100_CI_AS", "Chinese_PRC_CI_AS", + "Chinese_PRC_CS_AS", + "Chinese_PRC_100_CI_AS", + "Chinese_PRC_100_CS_AS", + "Chinese_PRC_100_BIN2", + "Chinese_PRC_100_CI_AS_SC", + "Chinese_PRC_100_CI_AS_SC_UTF8", + "Chinese_Simplified_Pinyin_100_CI_AS", + "Chinese_Simplified_Pinyin_100_CS_AS", + "Chinese_Traditional_Stroke_Order_100_CI_AS", "Japanese_CI_AS", + "Japanese_CS_AS", + "Japanese_BIN2", + "Japanese_XJIS_100_CI_AS", + "Japanese_XJIS_100_CS_AS", + "Japanese_XJIS_100_BIN2", + "Japanese_XJIS_140_CI_AS", + "Japanese_XJIS_140_CI_AS_KS_WS", + "Japanese_Bushu_Kakusu_100_CI_AS", + "Japanese_Bushu_Kakusu_140_CI_AS", + "Korean_Wansung_CI_AS", + "Korean_Wansung_100_CI_AS", + "Korean_Wansung_140_CI_AS", + "Korean_Unicode_CI_AS", + "Korean_Unicode_100_CI_AS", + "Korean_Unicode_140_CI_AS", ]; interface ConnectionListProps { onTableSelect?: ( @@ -222,7 +308,7 @@ interface ConnectionListProps { table: string; driver: string; }, - format: "csv" | "json" | "sql", + format: "csv" | "json" | "sql_dml" | "sql_ddl" | "sql_full", filePath: string, ) => void; activeTableTarget?: { @@ -324,6 +410,9 @@ export function ConnectionList({ const [createDbForm, setCreateDbForm] = useState( defaultCreateDatabaseForm, ); + const [mysqlCharsets, setMysqlCharsets] = useState([]); + const [mysqlCollations, setMysqlCollations] = useState([]); + const [loadingMysqlOptions, setLoadingMysqlOptions] = useState(false); const [testMsg, setTestMsg] = useState<{ ok: boolean; text: string; @@ -361,13 +450,48 @@ export function ConnectionList({ [connections, createDbConnectionId], ); const createDbTargetDriver = createDbTargetConnection?.type; - const isMySqlFamilyCreateDb = - createDbTargetDriver === "mysql" || - createDbTargetDriver === "mariadb" || - createDbTargetDriver === "tidb"; + const isMySqlFamilyCreateDb = createDbTargetDriver + ? isMysqlFamilyDriver(createDbTargetDriver as any) + : false; const isPostgresCreateDb = createDbTargetDriver === "postgres"; const isMssqlCreateDb = createDbTargetDriver === "mssql"; + useEffect(() => { + if ( + !isCreateDbDialogOpen || + !isMySqlFamilyCreateDb || + !createDbConnectionId + ) + return; + setLoadingMysqlOptions(true); + api.connections + .getMysqlCharsets(Number(createDbConnectionId)) + .then(setMysqlCharsets) + .catch(() => setMysqlCharsets(["utf8mb4", "utf8", "latin1"])) + .finally(() => setLoadingMysqlOptions(false)); + }, [isCreateDbDialogOpen, isMySqlFamilyCreateDb, createDbConnectionId]); + + useEffect(() => { + if ( + !isCreateDbDialogOpen || + !isMySqlFamilyCreateDb || + !createDbConnectionId + ) + return; + api.connections + .getMysqlCollations( + Number(createDbConnectionId), + createDbForm.charset || undefined, + ) + .then(setMysqlCollations) + .catch(() => setMysqlCollations([])); + }, [ + isCreateDbDialogOpen, + isMySqlFamilyCreateDb, + createDbConnectionId, + createDbForm.charset, + ]); + const getConnectionStatusLabel = (connection: Connection) => { if (connection.connectState === "success") { return t("connection.status.connected"); @@ -1613,7 +1737,7 @@ export function ConnectionList({ connection: Connection, database: DatabaseInfo, table: TableInfo, - format: "csv" | "json" | "sql", + format: "csv" | "json" | "sql_dml" | "sql_ddl" | "sql_full", ) => { if (!onExportTable) return; if (!isTauri()) { @@ -2524,12 +2648,38 @@ export function ConnectionList({ connection, database, table, - "sql", + "sql_dml", + ) + } + > + + {t("connection.menu.exportSqlDml")} + + + void handleTableExport( + connection, + database, + table, + "sql_ddl", + ) + } + > + + {t("connection.menu.exportSqlDdl")} + + + void handleTableExport( + connection, + database, + table, + "sql_full", ) } > - {t("connection.menu.exportSql")} + {t("connection.menu.exportSqlFull")} @@ -2757,12 +2907,38 @@ export function ConnectionList({ connection, database, table, - "sql", + "sql_dml", + ) + } + > + + {t("connection.menu.exportSqlDml")} + + + void handleTableExport( + connection, + database, + table, + "sql_ddl", + ) + } + > + + {t("connection.menu.exportSqlDdl")} + + + void handleTableExport( + connection, + database, + table, + "sql_full", ) } > - {t("connection.menu.exportSql")} + {t("connection.menu.exportSqlFull")} @@ -3034,6 +3210,8 @@ export function ConnectionList({ setCreateDbConnectionId(null); setShowCreateDbAdvanced(false); setCreateDbForm(defaultCreateDatabaseForm); + setMysqlCharsets([]); + setMysqlCollations([]); } }} > @@ -3093,25 +3271,31 @@ export function ConnectionList({ px + +
+
+ +

+ {t("settings.dataGrid.showColumnCommentsDescription")} +

+
+ +
diff --git a/src/lib/connection-form/rules.ts b/src/lib/connection-form/rules.ts index 52769c32..f9023ffd 100644 --- a/src/lib/connection-form/rules.ts +++ b/src/lib/connection-form/rules.ts @@ -1,8 +1,5 @@ import type { ConnectionForm, Driver } from "@/services/api"; -import { - isMysqlFamilyDriver, - isFileBasedDriver, -} from "@/lib/driver-registry"; +import { isMysqlFamilyDriver, isFileBasedDriver } from "@/lib/driver-registry"; export { isMysqlFamilyDriver, isFileBasedDriver }; diff --git a/src/lib/connection-form/rules.unit.test.ts b/src/lib/connection-form/rules.unit.test.ts index ee3ee846..21daf857 100644 --- a/src/lib/connection-form/rules.unit.test.ts +++ b/src/lib/connection-form/rules.unit.test.ts @@ -15,6 +15,7 @@ describe("isMysqlFamilyDriver", () => { expect(isMysqlFamilyDriver("mysql")).toBe(true); expect(isMysqlFamilyDriver("mariadb")).toBe(true); expect(isMysqlFamilyDriver("tidb")).toBe(true); + expect(isMysqlFamilyDriver("starrocks")).toBe(true); }); test("rejects non-mysql drivers", () => { @@ -38,12 +39,14 @@ describe("isFileBasedDriver", () => { describe("allowsHostWithPort / requiresPasswordOnCreate", () => { test("only mysql family allows host:port notation", () => { expect(allowsHostWithPort("mysql")).toBe(true); + expect(allowsHostWithPort("starrocks")).toBe(true); expect(allowsHostWithPort("postgres")).toBe(false); }); test("non-mysql drivers require password on create", () => { expect(requiresPasswordOnCreate("postgres")).toBe(true); expect(requiresPasswordOnCreate("mysql")).toBe(false); + expect(requiresPasswordOnCreate("starrocks")).toBe(false); }); }); @@ -144,14 +147,14 @@ describe("normalizeConnectionFormInput", () => { test("uses embedded host port even when a default port is already set", () => { const normalized = normalizeConnectionFormInput({ - driver: "mysql", - host: " db:3307 ", - port: 3306, + driver: "starrocks", + host: " db:9031 ", + port: 9030, password: "", } as any); expect(normalized.host).toBe("db"); - expect(normalized.port).toBe(3307); + expect(normalized.port).toBe(9031); }); test("does not split host:port for non-mysql drivers", () => { diff --git a/src/lib/driver-registry.tsx b/src/lib/driver-registry.tsx index 71044fd5..24e25277 100644 --- a/src/lib/driver-registry.tsx +++ b/src/lib/driver-registry.tsx @@ -18,6 +18,7 @@ const DRIVER_IDS = [ "mysql", "mariadb", "tidb", + "starrocks", "sqlite", "duckdb", "clickhouse", @@ -102,6 +103,18 @@ export const DRIVER_REGISTRY: DriverConfig[] = [ importCapability: "supported", icon: () => renderSimpleIcon(siMysql), }, + { + id: "starrocks", + label: "StarRocks", + defaultPort: 9030, + isFileBased: false, + isMysqlFamily: true, + supportsSSLCA: true, + supportsSchemaBrowsing: false, + supportsCreateDatabase: true, + importCapability: "unsupported", + icon: () => , + }, { id: "sqlite", label: "SQLite", @@ -190,7 +203,9 @@ export const getConnectionIcon = ( ): ReactNode => { const config = DRIVER_REGISTRY.find((d) => d.id === driver); if (config) return config.icon(); - const normalized = String(driver || "").trim().toLowerCase(); + const normalized = String(driver || "") + .trim() + .toLowerCase(); if (normalized === "postgresql" || normalized === "pgsql") return getConnectionIcon("postgres"); if (normalized === "sqlite3") return getConnectionIcon("sqlite"); diff --git a/src/lib/driver-registry.unit.test.ts b/src/lib/driver-registry.unit.test.ts index 8b13331a..a8ef772d 100644 --- a/src/lib/driver-registry.unit.test.ts +++ b/src/lib/driver-registry.unit.test.ts @@ -15,18 +15,19 @@ import { // ─── Registry completeness ──────────────────────────────────────────────────── describe("DRIVER_REGISTRY", () => { - test("contains all 9 supported drivers", () => { + test("contains all 10 supported drivers", () => { const ids = DRIVER_REGISTRY.map((d) => d.id); expect(ids).toContain("postgres"); expect(ids).toContain("mysql"); expect(ids).toContain("mariadb"); expect(ids).toContain("tidb"); + expect(ids).toContain("starrocks"); expect(ids).toContain("sqlite"); expect(ids).toContain("duckdb"); expect(ids).toContain("clickhouse"); expect(ids).toContain("mssql"); expect(ids).toContain("oracle"); - expect(DRIVER_REGISTRY).toHaveLength(9); + expect(DRIVER_REGISTRY).toHaveLength(10); }); test("has no duplicate IDs", () => { @@ -99,6 +100,7 @@ describe("getDriverConfig", () => { test("returns the correct config for each driver", () => { expect(getDriverConfig("postgres").label).toBe("PostgreSQL"); expect(getDriverConfig("mysql").label).toBe("MySQL"); + expect(getDriverConfig("starrocks").label).toBe("StarRocks"); expect(getDriverConfig("mssql").label).toBe("SQL Server"); expect(getDriverConfig("clickhouse").label).toBe("ClickHouse"); expect(getDriverConfig("duckdb").label).toBe("DuckDB"); @@ -113,6 +115,7 @@ describe("getDefaultPort", () => { expect(getDefaultPort("mysql")).toBe(3306); expect(getDefaultPort("mariadb")).toBe(3306); expect(getDefaultPort("tidb")).toBe(4000); + expect(getDefaultPort("starrocks")).toBe(9030); expect(getDefaultPort("clickhouse")).toBe(8123); expect(getDefaultPort("mssql")).toBe(1433); }); @@ -137,6 +140,7 @@ describe("isFileBasedDriver", () => { "mysql", "mariadb", "tidb", + "starrocks", "clickhouse", "mssql", ]; @@ -153,6 +157,7 @@ describe("isMysqlFamilyDriver", () => { expect(isMysqlFamilyDriver("mysql")).toBe(true); expect(isMysqlFamilyDriver("mariadb")).toBe(true); expect(isMysqlFamilyDriver("tidb")).toBe(true); + expect(isMysqlFamilyDriver("starrocks")).toBe(true); }); test("returns false for non-MySQL drivers", () => { @@ -177,6 +182,7 @@ describe("supportsSSLCA", () => { expect(supportsSSLCA("mysql")).toBe(true); expect(supportsSSLCA("mariadb")).toBe(true); expect(supportsSSLCA("tidb")).toBe(true); + expect(supportsSSLCA("starrocks")).toBe(true); }); test("returns false for drivers without SSL CA support", () => { @@ -195,6 +201,7 @@ describe("supportsCreateDatabase", () => { expect(supportsCreateDatabase("mysql")).toBe(true); expect(supportsCreateDatabase("mariadb")).toBe(true); expect(supportsCreateDatabase("tidb")).toBe(true); + expect(supportsCreateDatabase("starrocks")).toBe(true); expect(supportsCreateDatabase("clickhouse")).toBe(true); expect(supportsCreateDatabase("mssql")).toBe(true); }); @@ -218,6 +225,7 @@ describe("supportsSchemaBrowsing", () => { "mysql", "mariadb", "tidb", + "starrocks", "sqlite", "duckdb", "clickhouse", @@ -241,6 +249,10 @@ describe("importCapability", () => { ); }); + test("starrocks import is unsupported", () => { + expect(getDriverConfig("starrocks").importCapability).toBe("unsupported"); + }); + test("all other drivers are supported", () => { const supported: Driver[] = [ "postgres", diff --git a/src/lib/i18n/locales/en.ts b/src/lib/i18n/locales/en.ts index 1308fe3b..86eac42a 100644 --- a/src/lib/i18n/locales/en.ts +++ b/src/lib/i18n/locales/en.ts @@ -101,6 +101,12 @@ export const en = { fontSizeDescription: "Adjust global text size across the app (Range: {{min}}-{{max}}px)", }, + dataGrid: { + title: "Data Grid", + showColumnComments: "Show Column Comments", + showColumnCommentsDescription: + "Display column comments in small text below the column name in table headers", + }, layout: { title: "Layout", modeTitle: "Sidebar Layout", @@ -292,7 +298,9 @@ export const en = { importSqlReadOnly: "Import SQL (Read-only, unsupported)", exportCsv: "Export as CSV", exportJson: "Export as JSON", - exportSql: "Export as SQL", + exportSqlDml: "Export DML (INSERT)", + exportSqlDdl: "Export DDL (CREATE)", + exportSqlFull: "Export DML & DDL", }, importDialog: { title: "Import SQL", @@ -427,6 +435,14 @@ export const en = { }, untitled: "Untitled", }, + tableView: { + header: { + actionHint: "Click to copy column name, double-click to sort {{column}}", + }, + toast: { + columnNameCopied: "Copied column name: {{column}}", + }, + }, } as const; type DeepStringify = { diff --git a/src/lib/i18n/locales/ja.ts b/src/lib/i18n/locales/ja.ts index 47968fab..4ce27966 100644 --- a/src/lib/i18n/locales/ja.ts +++ b/src/lib/i18n/locales/ja.ts @@ -103,6 +103,12 @@ export const ja: Translations = { fontSizeDescription: "アプリ全体の文字サイズを調整します(範囲:{{min}}-{{max}}px)", }, + dataGrid: { + title: "データグリッド", + showColumnComments: "カラムコメントを表示", + showColumnCommentsDescription: + "テーブルヘッダーの列名の下に小さなテキストでカラムコメントを表示します", + }, layout: { title: "レイアウト", modeTitle: "サイドバー配置", @@ -295,7 +301,9 @@ export const ja: Translations = { importSqlReadOnly: "SQL をインポート(読み取り専用で未対応)", exportCsv: "CSV としてエクスポート", exportJson: "JSON としてエクスポート", - exportSql: "SQL としてエクスポート", + exportSqlDml: "DML をエクスポート (INSERT)", + exportSqlDdl: "DDL をエクスポート (CREATE)", + exportSqlFull: "DML & DDL をエクスポート", }, importDialog: { title: "SQL をインポート", @@ -388,6 +396,14 @@ export const ja: Translations = { send: "送信", sendMessage: "メッセージを送信", }, + tableView: { + header: { + actionHint: "{{column}} をクリックで列名コピー、ダブルクリックで並び替え", + }, + toast: { + columnNameCopied: "列名をコピーしました: {{column}}", + }, + }, tableSelector: { emptyLabel: "テーブルスキーマを選択(データなし)", selectedLabel: "スキーマ: {{count}} 件選択中", diff --git a/src/lib/i18n/locales/zh.ts b/src/lib/i18n/locales/zh.ts index 9b7f6271..d4c21871 100644 --- a/src/lib/i18n/locales/zh.ts +++ b/src/lib/i18n/locales/zh.ts @@ -101,6 +101,12 @@ export const zh: Translations = { fontSizeTitle: "字体大小", fontSizeDescription: "调整应用全局文字大小(范围:{{min}}-{{max}}px)", }, + dataGrid: { + title: "数据表格", + showColumnComments: "显示字段注释", + showColumnCommentsDescription: + "在表头列名下方以小字体显示字段的 Comment 注释", + }, layout: { title: "布局", modeTitle: "侧边栏布局", @@ -287,7 +293,9 @@ export const zh: Translations = { importSqlReadOnly: "导入 SQL(只读,不支持)", exportCsv: "导出为 CSV", exportJson: "导出为 JSON", - exportSql: "导出为 SQL", + exportSqlDml: "导出 DML (INSERT)", + exportSqlDdl: "导出 DDL (CREATE)", + exportSqlFull: "导出 DML & DDL", }, importDialog: { title: "导入 SQL", @@ -373,6 +381,14 @@ export const zh: Translations = { send: "发送", sendMessage: "发送消息", }, + tableView: { + header: { + actionHint: "单击复制字段名,双击按 {{column}} 排序", + }, + toast: { + columnNameCopied: "已复制字段名:{{column}}", + }, + }, tableSelector: { emptyLabel: "选择表结构(不含数据)", selectedLabel: "已选结构:{{count}} 项", diff --git a/src/main.tsx b/src/main.tsx index b9556bf3..af07b995 100644 --- a/src/main.tsx +++ b/src/main.tsx @@ -5,25 +5,28 @@ import { ThemeProvider } from "./components/theme-provider"; import { Toaster } from "./components/ui/sonner"; import "./lib/i18n"; import { initI18nFromStore } from "./lib/i18n"; +import { ErrorBoundary } from "./components/ErrorBoundary"; const renderApp = async () => { await initI18nFromStore(); if (import.meta.env.PROD) { document.addEventListener("contextmenu", (event) => { const target = event.target as HTMLElement | null; - const allowNative = target?.closest( - 'input, textarea, [contenteditable="true"]', - ); + const allowNative = + event.altKey || + target?.closest('input, textarea, [contenteditable="true"]'); if (!allowNative) { event.preventDefault(); } }); } createRoot(document.getElementById("root")!).render( - - - - , + + + + + + , ); }; diff --git a/src/services/api.ts b/src/services/api.ts index bcdad2d2..b91be7c6 100644 --- a/src/services/api.ts +++ b/src/services/api.ts @@ -299,7 +299,12 @@ export interface AIChatResponse { assistantMessageId: number; } -export type TransferFormat = "csv" | "json" | "sql"; +export type TransferFormat = + | "csv" + | "json" + | "sql_dml" + | "sql_ddl" + | "sql_full"; export type ExportScope = | "current_page" | "filtered" @@ -467,6 +472,10 @@ export const api = { delete: (id: number) => invoke("delete_connection", { id }), createDatabase: (id: number, payload: CreateDatabasePayload) => invoke("create_database_by_id", { id, payload }), + getMysqlCharsets: (id: number) => + invoke("get_mysql_charsets_by_id", { id }), + getMysqlCollations: (id: number, charset?: string) => + invoke("get_mysql_collations_by_id", { id, charset }), testEphemeral: (form: ConnectionForm) => invoke("test_connection_ephemeral", { form }), listSqliteIssues: () => diff --git a/src/services/api.unit.test.ts b/src/services/api.unit.test.ts index 258fd4dd..7b236364 100644 --- a/src/services/api.unit.test.ts +++ b/src/services/api.unit.test.ts @@ -67,6 +67,7 @@ describe("normalizeImportDriver", () => { "mysql", "mariadb", "tidb", + "starrocks", "sqlite", "duckdb", "mssql", @@ -93,6 +94,10 @@ describe("getImportDriverCapability", () => { ); }); + test("starrocks import is unsupported", () => { + expect(getImportDriverCapability("starrocks")).toBe("unsupported"); + }); + test("all writable drivers are supported", () => { const supported = [ "postgres", @@ -168,13 +173,23 @@ describe("invoke: Tauri environment", () => { tauriInvokeImpl = async (cmd, args) => { capturedCmd = cmd; capturedArgs = args; - return { data: [], rowCount: 0, columns: [], timeTakenMs: 0, success: true }; + return { + data: [], + rowCount: 0, + columns: [], + timeTakenMs: 0, + success: true, + }; }; await api.query.execute(42, "SELECT 1", "mydb", "sql_editor"); expect(capturedCmd).toBe("execute_query"); - expect(capturedArgs).toMatchObject({ id: 42, query: "SELECT 1", database: "mydb" }); + expect(capturedArgs).toMatchObject({ + id: 42, + query: "SELECT 1", + database: "mydb", + }); }); test("tauriInvoke error propagates to caller", async () => { @@ -247,7 +262,10 @@ describe("api command mapping", () => { ["list_sql_execution_logs", () => api.sqlLogs.list()], ["list_tables", () => api.metadata.listTables(1)], ["get_table_ddl", () => api.metadata.getTableDDL(1, "db", "public", "t")], - ["get_table_metadata", () => api.metadata.getTableMetadata(1, "db", "public", "t")], + [ + "get_table_metadata", + () => api.metadata.getTableMetadata(1, "db", "public", "t"), + ], ["get_connections", () => api.connections.list()], ["create_connection", () => api.connections.create({ driver: "postgres" })], ["delete_connection", () => api.connections.delete(1)], @@ -257,6 +275,12 @@ describe("api command mapping", () => { ["ai_delete_provider", () => api.ai.providers.delete(1)], ["ai_list_conversations", () => api.ai.conversations.list()], ["cancel_query", () => api.query.cancel("uuid-abc", "qid-1")], + ["get_mysql_charsets_by_id", () => api.connections.getMysqlCharsets(1)], + ["get_mysql_collations_by_id", () => api.connections.getMysqlCollations(1)], + [ + "get_mysql_collations_by_id", + () => api.connections.getMysqlCollations(1, "utf8mb4"), + ], ]; for (const [expectedCmd, callFn] of commands) { diff --git a/src/services/mocks.service.test.ts b/src/services/mocks.service.test.ts index e81b6274..9e862cbf 100644 --- a/src/services/mocks.service.test.ts +++ b/src/services/mocks.service.test.ts @@ -1,5 +1,9 @@ import { describe, expect, test } from "bun:test"; -import { invokeMock } from "./mocks"; +import { + invokeMock, + mockGetMysqlCharsets, + mockGetMysqlCollations, +} from "./mocks"; describe("invokeMock service layer", () => { test("returns table list for metadata command", async () => { @@ -28,3 +32,98 @@ describe("invokeMock service layer", () => { ); }); }); + +describe("mockGetMysqlCharsets", () => { + test("returns a non-empty list", async () => { + const charsets = await mockGetMysqlCharsets(1); + expect(charsets.length).toBeGreaterThan(0); + }); + + test("contains the three most common charsets", async () => { + const charsets = await mockGetMysqlCharsets(1); + expect(charsets).toContain("utf8mb4"); + expect(charsets).toContain("utf8"); + expect(charsets).toContain("latin1"); + }); + + test("contains CJK charsets", async () => { + const charsets = await mockGetMysqlCharsets(1); + expect(charsets).toContain("gbk"); + expect(charsets).toContain("gb18030"); + expect(charsets).toContain("euckr"); + }); + + test("all entries are non-empty strings", async () => { + const charsets = await mockGetMysqlCharsets(1); + for (const cs of charsets) { + expect(typeof cs).toBe("string"); + expect(cs.trim().length).toBeGreaterThan(0); + } + }); +}); + +describe("mockGetMysqlCollations", () => { + test("returns all collations when no charset given", async () => { + const collations = await mockGetMysqlCollations(1); + expect(collations.length).toBeGreaterThan(0); + }); + + test("returns collations for utf8mb4", async () => { + const collations = await mockGetMysqlCollations(1, "utf8mb4"); + expect(collations.length).toBeGreaterThan(0); + expect(collations).toContain("utf8mb4_general_ci"); + expect(collations).toContain("utf8mb4_unicode_ci"); + for (const col of collations) { + expect(col.startsWith("utf8mb4")).toBe(true); + } + }); + + test("returns collations for utf8", async () => { + const collations = await mockGetMysqlCollations(1, "utf8"); + expect(collations).toContain("utf8_general_ci"); + }); + + test("falls back to all collations for unknown charset", async () => { + const all = await mockGetMysqlCollations(1); + const unknown = await mockGetMysqlCollations(1, "euckr"); + expect(unknown.length).toBe(all.length); + }); + + test("all entries are non-empty strings", async () => { + const collations = await mockGetMysqlCollations(1); + for (const col of collations) { + expect(typeof col).toBe("string"); + expect(col.trim().length).toBeGreaterThan(0); + } + }); +}); + +describe("invokeMock charset/collation commands", () => { + test("get_mysql_charsets_by_id returns charset array", async () => { + const charsets = await invokeMock("get_mysql_charsets_by_id", { + id: 1, + }); + expect(Array.isArray(charsets)).toBe(true); + expect(charsets).toContain("utf8mb4"); + }); + + test("get_mysql_collations_by_id without charset returns all collations", async () => { + const collations = await invokeMock( + "get_mysql_collations_by_id", + { id: 1 }, + ); + expect(Array.isArray(collations)).toBe(true); + expect(collations.length).toBeGreaterThan(0); + }); + + test("get_mysql_collations_by_id with charset filters results", async () => { + const collations = await invokeMock( + "get_mysql_collations_by_id", + { id: 1, charset: "utf8mb4" }, + ); + expect(collations).toContain("utf8mb4_general_ci"); + for (const col of collations) { + expect(col.startsWith("utf8mb4")).toBe(true); + } + }); +}); diff --git a/src/services/mocks.ts b/src/services/mocks.ts index 781bec53..d0c406a6 100644 --- a/src/services/mocks.ts +++ b/src/services/mocks.ts @@ -213,7 +213,12 @@ export const mockTableData = { created_at: "2024-01-15 10:30:00", updated_at: "2024-01-15 10:30:00", // object with 4 keys → abbreviated as {role, department, ... +2} - metadata: { role: "admin", department: "engineering", level: 5, active: true }, + metadata: { + role: "admin", + department: "engineering", + level: 5, + active: true, + }, // array with 3 items → [3 items] tags: ["vip", "beta-tester", "early-adopter"], settings: null, @@ -230,7 +235,11 @@ export const mockTableData = { // array with 1 item → inline JSON tags: ["newsletter"], // nested object → tree view shows expand/collapse - settings: { theme: "dark", lang: "zh", notifications: { email: true, sms: false } }, + settings: { + theme: "dark", + lang: "zh", + notifications: { email: true, sms: false }, + }, }, { id: 3, @@ -253,7 +262,11 @@ export const mockTableData = { created_at: "2024-01-18 09:15:00", updated_at: "2024-01-18 09:15:00", // object containing a nested array - metadata: { role: "moderator", permissions: ["read", "write", "delete"], score: 88 }, + metadata: { + role: "moderator", + permissions: ["read", "write", "delete"], + score: 88, + }, tags: ["moderator", "trusted"], settings: null, }, @@ -265,10 +278,18 @@ export const mockTableData = { created_at: "2024-01-19 16:50:00", updated_at: "2024-01-19 16:50:00", // array of objects → table view renders as multi-column table - metadata: [{ key: "plan", value: "pro" }, { key: "trial", value: false }], + metadata: [ + { key: "plan", value: "pro" }, + { key: "trial", value: false }, + ], tags: ["pro"], // object with 4 keys → tree/table view - settings: { theme: "system", lang: "ja", timezone: "Asia/Tokyo", fontSize: 14 }, + settings: { + theme: "system", + lang: "ja", + timezone: "Asia/Tokyo", + fontSize: 14, + }, }, { id: 6, @@ -451,7 +472,10 @@ export const mockArrayTypeData: QueryResult = { scores: [95, 87, 72], flags: [true, false, true], readings: [3.14, 2.72, 1.41], - metadata_list: [{ source: "web", valid: true }, { source: "app", valid: false }], + metadata_list: [ + { source: "web", valid: true }, + { source: "app", valid: false }, + ], }, { id: 2, @@ -761,9 +785,7 @@ export async function mockExecuteQuery( // Return different data based on query type if (lower.includes("select")) { // Dedicated array-type dataset: SELECT * FROM pg_arrays - const isArrayQuery = - lower.includes("pg_arrays") || - lower.includes("array"); + const isArrayQuery = lower.includes("pg_arrays") || lower.includes("array"); // Dedicated complex-type dataset: SELECT * FROM json_test const isComplexQuery = !isArrayQuery && @@ -772,7 +794,11 @@ export async function mockExecuteQuery( lower.includes("jsonb") || lower.includes("complex")); const result = { - ...(isArrayQuery ? mockArrayTypeData : isComplexQuery ? mockComplexTypeData : mockQueryResult), + ...(isArrayQuery + ? mockArrayTypeData + : isComplexQuery + ? mockComplexTypeData + : mockQueryResult), timeTakenMs: Math.floor(Math.random() * 100) + 20, }; appendSqlExecutionLog({ @@ -896,7 +922,12 @@ const mockArrayTestTableMetadata: TableMetadata = { { name: "scores", type: "int4[]", nullable: true, primaryKey: false }, { name: "flags", type: "bool[]", nullable: true, primaryKey: false }, { name: "readings", type: "float8[]", nullable: true, primaryKey: false }, - { name: "metadata_list", type: "jsonb[]", nullable: true, primaryKey: false }, + { + name: "metadata_list", + type: "jsonb[]", + nullable: true, + primaryKey: false, + }, ], indexes: [], foreignKeys: [], @@ -942,6 +973,83 @@ export async function mockListDatabasesById(_id: number): Promise { return mockDatabases; } +/** + * Mock get MySQL charsets + */ +export async function mockGetMysqlCharsets(_id: number): Promise { + await new Promise((resolve) => setTimeout(resolve, 50)); + return [ + "armscii8", + "ascii", + "big5", + "binary", + "cp1250", + "cp1251", + "cp1256", + "cp1257", + "cp850", + "cp852", + "cp866", + "cp932", + "dec8", + "eucjpms", + "euckr", + "gb18030", + "gb2312", + "gbk", + "geostd8", + "greek", + "hebrew", + "hp8", + "keybcs2", + "koi8r", + "koi8u", + "latin1", + "latin2", + "latin5", + "latin7", + "macce", + "macroman", + "sjis", + "swe7", + "tis620", + "ucs2", + "ujis", + "utf16", + "utf16le", + "utf32", + "utf8", + "utf8mb4", + ]; +} + +/** + * Mock get MySQL collations + */ +export async function mockGetMysqlCollations( + _id: number, + charset?: string, +): Promise { + await new Promise((resolve) => setTimeout(resolve, 50)); + const all: Record = { + utf8mb4: [ + "utf8mb4_0900_ai_ci", + "utf8mb4_0900_as_ci", + "utf8mb4_0900_as_cs", + "utf8mb4_bin", + "utf8mb4_general_ci", + "utf8mb4_unicode_ci", + "utf8mb4_unicode_520_ci", + ], + utf8: ["utf8_bin", "utf8_general_ci", "utf8_unicode_ci"], + latin1: ["latin1_bin", "latin1_general_ci", "latin1_swedish_ci"], + ascii: ["ascii_bin", "ascii_general_ci"], + binary: ["binary"], + }; + if (charset && all[charset]) return all[charset]; + return Object.values(all).flat().sort(); +} + /** * Mock get schema overview */ @@ -1357,6 +1465,12 @@ export async function invokeMock(cmd: string, args?: any): Promise { case "create_database_by_id": return mockCreateDatabaseById(args.id, args.payload) as Promise; + case "get_mysql_charsets_by_id": + return mockGetMysqlCharsets(args.id) as Promise; + + case "get_mysql_collations_by_id": + return mockGetMysqlCollations(args.id, args.charset) as Promise; + case "test_connection_ephemeral": return mockTestConnectionEphemeral(args.form) as Promise;
keyvalue + key + + value +
{k} +
+ {k} + {cellText(v)}