diff --git a/docs/docs/build/connectors/olap/starrocks.md b/docs/docs/build/connectors/olap/starrocks.md index ee615af73d5..caf9a1115d4 100644 --- a/docs/docs/build/connectors/olap/starrocks.md +++ b/docs/docs/build/connectors/olap/starrocks.md @@ -7,13 +7,21 @@ sidebar_position: 5 [StarRocks](https://www.starrocks.io/) is an open-source, high-performance analytical database designed for real-time, multi-dimensional analytics on large-scale data. It supports both primary key and aggregate data models, making it suitable for a variety of analytical workloads including real-time dashboards, ad-hoc queries, and complex analytical tasks. -Rill supports connecting to an existing StarRocks cluster via a "live connector" and using it as an OLAP engine built against [external tables](/build/connectors/olap#external-olap-tables) to power Rill dashboards. +:::note Supported Versions + +Rill supports connecting to StarRocks 4.0 or newer versions. + +::: + +:::info + +Rill supports connecting to an existing StarRocks cluster via a read-only OLAP connector and using it to power Rill dashboards with [external tables](/build/connectors/olap#external-olap-tables). ::: ## Connect to StarRocks -When using StarRocks for local development, you can connect via connection parameters or by using the DSN. +When using StarRocks for local development, you can connect via connection parameters or by using a DSN. After selecting "Add Data", select StarRocks and fill in your connection parameters. This will automatically create the `starrocks.yaml` file in your `connectors` directory and populate the `.env` file with `connector.starrocks.password`. @@ -34,7 +42,7 @@ ssl: false ### Connection String (DSN) -Rill can also connect to StarRocks using a DSN connection string. StarRocks uses MySQL protocol, so the connection string follows the MySQL DSN format: +Rill can also connect to StarRocks using a DSN connection string. StarRocks uses MySQL protocol, so the connection string must follow the MySQL DSN format: ```yaml type: connector @@ -43,28 +51,46 @@ driver: starrocks dsn: "{{ .env.connector.starrocks.dsn }}" ``` -The DSN format is: +#### Using default_catalog + +For `default_catalog`, you can specify database directly in the DSN path (MySQL-style): ``` -starrocks://user:password@host:port/database +user:password@tcp(host:9030)/my_database?parseTime=true ``` -Or using MySQL-style format: -``` -user:password@tcp(host:port)/database?parseTime=true +#### Using external catalogs with DSN + +For external catalogs (Iceberg, Hive, etc.), set `catalog` and `database` as separate properties (do not include database in DSN): +```yaml +type: connector +driver: starrocks + +dsn: "user:password@tcp(host:9030)/?parseTime=true" +catalog: iceberg_catalog +database: my_database ``` +If `catalog` is not specified, it defaults to `default_catalog`. + +:::warning DSN Format + +Only MySQL-style DSN format is supported. The `starrocks://` URL scheme is **not** supported. When using DSN, do not set `host`, `port`, `username`, `password` separately — these must be included in the DSN string. + +::: + ## Configuration Properties | Property | Description | Default | |----------|-------------|---------| -| `host` | StarRocks FE (Frontend) server hostname | Required | +| `host` | StarRocks FE (Frontend) server hostname | Required (if no DSN) | | `port` | MySQL protocol port of StarRocks FE | `9030` | -| `username` | Username for authentication | Required | +| `username` | Username for authentication | `root` | | `password` | Password for authentication | - | | `catalog` | StarRocks catalog name (for external catalogs like Iceberg, Hive) | `default_catalog` | | `database` | StarRocks database name | - | | `ssl` | Enable SSL/TLS encryption | `false` | -| `dsn` | Full connection string (alternative to individual parameters) | - | +| `dsn` | MySQL-format connection string (alternative to individual parameters) | - | +| `log_queries` | Enable logging of all SQL queries (useful for debugging) | `false` | ## External Catalogs @@ -95,6 +121,26 @@ StarRocks uses a three-level hierarchy: Catalog > Database > Table. In Rill's AP | `databaseSchema` | Database | `my_database` | | `table` | Table | `my_table` | +## Creating Metrics Views + +When creating metrics views against StarRocks tables, use the `table` property with `database_schema` to reference your data: + +```yaml +type: metrics_view +display_name: My Dashboard +table: my_table +database_schema: my_database +timeseries: timestamp + +dimensions: + - name: category + column: category + +measures: + - name: total_count + expression: COUNT(*) +``` + ## Troubleshooting ### Connection Issues @@ -106,10 +152,14 @@ If you encounter connection issues: 3. Ensure network connectivity to the StarRocks FE node 4. For SSL connections, verify SSL is enabled on the StarRocks server +### Timezone Handling + +All timestamp values are returned in UTC. The driver parses DATETIME values from StarRocks as UTC time. ## Known Limitations -- **Model execution**: Model creation and execution is not yet supported. This feature is under development. +- **Read-only connector**: StarRocks is a read-only OLAP connector. Model creation and execution is not supported. +- **Direct table reference**: Use the `table` property in metrics views instead of `model` to reference StarRocks tables directly. :::info Need help connecting to StarRocks? diff --git a/runtime/drivers/olap.go b/runtime/drivers/olap.go index d91e3503f1b..a255055806e 100644 --- a/runtime/drivers/olap.go +++ b/runtime/drivers/olap.go @@ -493,7 +493,7 @@ func (d Dialect) OrderByExpression(name string, desc bool) string { if desc { res += " DESC" } - if d == DialectDuckDB { + if d == DialectDuckDB || d == DialectStarRocks { res += " NULLS LAST" } return res @@ -619,9 +619,12 @@ func (d Dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain r } return fmt.Sprintf("CAST(date_trunc('%s', %s, 'MILLISECONDS', '%s') AS TIMESTAMP)", specifier, expr, tz), nil case DialectStarRocks: - // StarRocks supports date_trunc similar to DuckDB but does not support timezone parameter - // NOTE: Timezone and time shift parameters are validated in runtime/metricsview/executor/executor_validate.go - return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil + // StarRocks supports date_trunc and CONVERT_TZ for timezone handling + if tz == "" { + return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil + } + // Convert to target timezone, truncate, then convert back to UTC + return fmt.Sprintf("CONVERT_TZ(date_trunc('%s', CONVERT_TZ(%s, 'UTC', '%s')), '%s', 'UTC')", specifier, expr, tz, tz), nil default: return "", fmt.Errorf("unsupported dialect %q", d) } diff --git a/runtime/drivers/starrocks/olap.go b/runtime/drivers/starrocks/olap.go index 7e7248d53bf..996b54d9f87 100644 --- a/runtime/drivers/starrocks/olap.go +++ b/runtime/drivers/starrocks/olap.go @@ -2,6 +2,7 @@ package starrocks import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -86,8 +87,20 @@ func (c *connection) Query(ctx context.Context, stmt *drivers.Statement) (*drive return nil, err } + cts, err := rows.ColumnTypes() + if err != nil { + rows.Close() + return nil, err + } + + starrocksRows := &starrocksRows{ + Rows: rows, + scanDest: prepareScanDest(schema), + colTypes: cts, + } + return &drivers.Result{ - Rows: rows, + Rows: starrocksRows, Schema: schema, }, nil } @@ -150,8 +163,10 @@ func (c *connection) databaseTypeToRuntimeType(dbType string) (*runtimev1.Type, } switch dbType { - case "BOOLEAN", "BOOL", "TINYINT": + case "BOOLEAN", "BOOL": return &runtimev1.Type{Code: runtimev1.Type_CODE_BOOL}, nil + case "TINYINT": + return &runtimev1.Type{Code: runtimev1.Type_CODE_INT8}, nil case "SMALLINT": return &runtimev1.Type{Code: runtimev1.Type_CODE_INT16}, nil case "INT", "INTEGER": @@ -164,8 +179,8 @@ func (c *connection) databaseTypeToRuntimeType(dbType string) (*runtimev1.Type, return &runtimev1.Type{Code: runtimev1.Type_CODE_FLOAT32}, nil case "DOUBLE": return &runtimev1.Type{Code: runtimev1.Type_CODE_FLOAT64}, nil - case "DECIMAL", "DECIMALV2", "DECIMAL32", "DECIMAL64", "DECIMAL128": - return &runtimev1.Type{Code: runtimev1.Type_CODE_DECIMAL}, nil + case "DECIMAL": + return &runtimev1.Type{Code: runtimev1.Type_CODE_STRING}, nil case "CHAR", "VARCHAR", "STRING", "TEXT": return &runtimev1.Type{Code: runtimev1.Type_CODE_STRING}, nil case "DATE": @@ -182,8 +197,122 @@ func (c *connection) databaseTypeToRuntimeType(dbType string) (*runtimev1.Type, return &runtimev1.Type{Code: runtimev1.Type_CODE_STRUCT}, nil case "BINARY", "VARBINARY", "BLOB": // Note: StarRocks doesn't have BLOB type, but MySQL driver may report VARBINARY as BLOB - return &runtimev1.Type{Code: runtimev1.Type_CODE_BYTES}, nil + // Use CODE_STRING like MySQL driver for consistency + return &runtimev1.Type{Code: runtimev1.Type_CODE_STRING}, nil default: return nil, errUnsupportedType } } + +// starrocksRows wraps sqlx.Rows to provide MapScan method. +// This is required because if the correct type is not provided to Scan +// mysql driver just returns byte arrays. +type starrocksRows struct { + *sqlx.Rows + scanDest []any + colTypes []*sql.ColumnType +} + +func (r *starrocksRows) MapScan(dest map[string]any) error { + err := r.Rows.Scan(r.scanDest...) + if err != nil { + return err + } + for i, ct := range r.colTypes { + fieldName := ct.Name() + valPtr := r.scanDest[i] + // Safety guard: prepareScanDest always allocates, but check anyway + if valPtr == nil { + dest[fieldName] = nil + continue + } + switch valPtr := valPtr.(type) { + case *sql.NullBool: + if valPtr.Valid { + dest[fieldName] = valPtr.Bool + } else { + dest[fieldName] = nil + } + case *sql.NullInt16: + if valPtr.Valid { + dest[fieldName] = valPtr.Int16 + } else { + dest[fieldName] = nil + } + case *sql.NullInt32: + if valPtr.Valid { + dest[fieldName] = valPtr.Int32 + } else { + dest[fieldName] = nil + } + case *sql.NullInt64: + if valPtr.Valid { + dest[fieldName] = valPtr.Int64 + } else { + dest[fieldName] = nil + } + case *sql.NullFloat64: + if valPtr.Valid { + dest[fieldName] = valPtr.Float64 + } else { + dest[fieldName] = nil + } + case *sql.NullString: + if valPtr.Valid { + dest[fieldName] = valPtr.String + } else { + dest[fieldName] = nil + } + case *sql.NullTime: + if valPtr.Valid { + dest[fieldName] = valPtr.Time + } else { + dest[fieldName] = nil + } + default: + // Handle ARRAY, MAP, STRUCT, BYTES and other complex types + // These are scanned into *any in prepareScanDest + if ptr, ok := valPtr.(*any); ok { + dest[fieldName] = *ptr + } else { + // Fallback: store the pointer's underlying value directly + dest[fieldName] = valPtr + } + } + } + return nil +} + +func prepareScanDest(schema *runtimev1.StructType) []any { + scanList := make([]any, len(schema.Fields)) + for i, field := range schema.Fields { + var dest any + switch field.Type.Code { + case runtimev1.Type_CODE_BOOL: + dest = &sql.NullBool{} + case runtimev1.Type_CODE_INT8: + dest = &sql.NullInt16{} + case runtimev1.Type_CODE_INT16: + dest = &sql.NullInt16{} + case runtimev1.Type_CODE_INT32: + dest = &sql.NullInt32{} + case runtimev1.Type_CODE_INT64, runtimev1.Type_CODE_INT128: + dest = &sql.NullInt64{} + case runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64: + dest = &sql.NullFloat64{} + case runtimev1.Type_CODE_STRING: + dest = &sql.NullString{} + case runtimev1.Type_CODE_DATE, runtimev1.Type_CODE_TIME: + dest = &sql.NullString{} + case runtimev1.Type_CODE_TIMESTAMP: + // MySQL driver returns DATETIME as time.Time when parseTime=true in DSN + dest = &sql.NullTime{} + case runtimev1.Type_CODE_JSON: + dest = &sql.NullString{} + default: + dest = new(any) + } + scanList[i] = dest + } + return scanList +} diff --git a/runtime/drivers/starrocks/olap_test.go b/runtime/drivers/starrocks/olap_test.go new file mode 100644 index 00000000000..9cf07fe83a7 --- /dev/null +++ b/runtime/drivers/starrocks/olap_test.go @@ -0,0 +1,1384 @@ +package starrocks + +import ( + "context" + "fmt" + "testing" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/drivers/starrocks/teststarrocks" + "github.com/rilldata/rill/runtime/pkg/activity" + "github.com/rilldata/rill/runtime/storage" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestStarRocksOLAP(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + dsn := teststarrocks.StartWithData(t) + + conn, err := driver{}.Open("default", map[string]any{ + "dsn": dsn, + }, storage.MustNew(t.TempDir(), nil), activity.NewNoopClient(), zap.NewNop()) + require.NoError(t, err) + defer conn.Close() + + olap, ok := conn.AsOLAP("default") + require.True(t, ok) + + // Basic type tests + t.Run("VarcharNotBinary", func(t *testing.T) { + testVarcharNotBinary(t, olap) + }) + + t.Run("NullHandling", func(t *testing.T) { + testNullHandling(t, olap) + }) + + t.Run("NumericTypes", func(t *testing.T) { + testNumericTypes(t, olap) + }) + + // All types tests + t.Run("AllBasicTypes", func(t *testing.T) { + testAllBasicTypes(t, olap) + }) + + t.Run("DateTimeTypes", func(t *testing.T) { + testDateTimeTypes(t, olap) + }) + + t.Run("StringTypes", func(t *testing.T) { + testStringTypes(t, olap) + }) + + t.Run("BinaryTypes", func(t *testing.T) { + testBinaryTypes(t, olap) + }) + + t.Run("AggregateTypes", func(t *testing.T) { + testAggregateTypes(t, olap) + }) + + t.Run("UnicodeStrings", func(t *testing.T) { + testUnicodeStrings(t, olap) + }) + + t.Run("JSONType", func(t *testing.T) { + testJSONType(t, olap) + }) + + // API tests + t.Run("DryRun", func(t *testing.T) { + testDryRun(t, olap) + }) + + t.Run("Exec", func(t *testing.T) { + testExec(t, olap) + }) + + t.Run("QuerySchema", func(t *testing.T) { + testQuerySchema(t, olap) + }) + + // Result set tests + t.Run("EmptyResultSet", func(t *testing.T) { + testEmptyResultSet(t, olap) + }) + + t.Run("MultipleRows", func(t *testing.T) { + testMultipleRows(t, olap) + }) + + // Boundary and special cases + t.Run("BoundaryValues", func(t *testing.T) { + testBoundaryValues(t, olap) + }) + + t.Run("NullHandlingDetailed", func(t *testing.T) { + testNullHandlingDetailed(t, olap) + }) + + t.Run("NegativeValues", func(t *testing.T) { + testNegativeValues(t, olap) + }) + + t.Run("SpecialCharacters", func(t *testing.T) { + testSpecialCharacters(t, olap) + }) + + // Complex types + t.Run("ComplexTypes", func(t *testing.T) { + testComplexTypes(t, olap) + }) + + // Error cases + t.Run("ErrorCases", func(t *testing.T) { + testErrorCases(t, olap) + }) + + // Other tests + t.Run("ParameterBinding", func(t *testing.T) { + testParameterBinding(t, olap) + }) + + t.Run("SchemaValidation", func(t *testing.T) { + testSchemaValidation(t, olap) + }) + + t.Run("AggregateFunctions", func(t *testing.T) { + testAggregateFunctions(t, olap) + }) + + // Output all types and values + t.Run("AllTypesOutput", func(t *testing.T) { + testAllTypesOutput(t, olap) + }) + + // High-precision DECIMAL test (DECIMAL32, DECIMAL64, DECIMAL128) + t.Run("DecimalPrecision", func(t *testing.T) { + testDecimalPrecision(t, olap) + }) +} + +func testVarcharNotBinary(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT 'hello' AS str_col, 'world' AS str_col2", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // VARCHAR should be string, not []byte + strVal, ok := row["str_col"].(string) + require.True(t, ok, "expected string type, got %T", row["str_col"]) + require.Equal(t, "hello", strVal) +} + +func testNullHandling(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT NULL AS null_col, 'value' AS str_col", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + require.Nil(t, row["null_col"]) + require.Equal(t, "value", row["str_col"]) +} + +func testNumericTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT 42 AS int_col, 3.14 AS float_col, TRUE AS bool_col", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Check types are correct (not []byte) + // StarRocks returns small integers as TINYINT/SMALLINT via MySQL protocol + intVal := row["int_col"] + _, isByte := intVal.([]byte) + require.False(t, isByte, "int_col should not be []byte, got %T", intVal) + + // Accept any integer type (int16, int32, int64) + switch intVal.(type) { + case int16, int32, int64: + // OK - valid integer type + default: + t.Errorf("expected int type, got %T", intVal) + } +} + +func testAllBasicTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Verify each type is correctly converted + require.Equal(t, int32(1), row["id"]) + + // Boolean - MySQL protocol returns BOOLEAN as TINYINT + // So it might be bool or int16 depending on driver behavior + boolVal := row["bool_col"] + _, isBoolByte := boolVal.([]byte) + require.False(t, isBoolByte, "bool_col should not be []byte, got %T", boolVal) + switch v := boolVal.(type) { + case bool: + require.True(t, v) + case int16: + require.Equal(t, int16(1), v) // 1 = true + default: + t.Errorf("expected bool or int16 for BOOLEAN, got %T", boolVal) + } + + // Integer types + _, ok := row["tinyint_col"].(int16) // TINYINT maps to int16 (via NullInt16) + require.True(t, ok, "expected int16 for tinyint, got %T", row["tinyint_col"]) + + _, ok = row["smallint_col"].(int16) + require.True(t, ok, "expected int16 type, got %T", row["smallint_col"]) + + _, ok = row["int_col"].(int32) + require.True(t, ok, "expected int32 type, got %T", row["int_col"]) + + _, ok = row["bigint_col"].(int64) + require.True(t, ok, "expected int64 type, got %T", row["bigint_col"]) + + // Float types + _, ok = row["float_col"].(float64) + require.True(t, ok, "expected float64 type, got %T", row["float_col"]) + + _, ok = row["double_col"].(float64) + require.True(t, ok, "expected float64 type, got %T", row["double_col"]) + + // Decimal type - stored as string to preserve precision (same as MySQL driver) + decimalVal, ok := row["decimal_col"].(string) + require.True(t, ok, "expected string type for DECIMAL, got %T", row["decimal_col"]) + require.Equal(t, "12345.6789", decimalVal, "decimal value mismatch") + + // String types - should NOT be []byte + charVal, ok := row["char_col"].(string) + require.True(t, ok, "expected string type for CHAR, got %T", row["char_col"]) + require.Contains(t, charVal, "char_val") + + varcharVal, ok := row["varchar_col"].(string) + require.True(t, ok, "expected string type for VARCHAR, got %T", row["varchar_col"]) + require.Equal(t, "varchar_value", varcharVal) + + stringVal, ok := row["string_col"].(string) + require.True(t, ok, "expected string type for STRING, got %T", row["string_col"]) + require.Equal(t, "string_value", stringVal) +} + +func testDateTimeTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT date_col, datetime_col FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // DATE - should be string (StarRocks returns as string via MySQL protocol) + dateVal, ok := row["date_col"].(string) + require.True(t, ok, "expected string type for DATE, got %T", row["date_col"]) + require.Contains(t, dateVal, "2024-01-15") + + // DATETIME - should be string (MySQL driver returns as string without parseTime=true) + datetimeVal, ok := row["datetime_col"].(string) + require.True(t, ok, "expected string type for DATETIME, got %T", row["datetime_col"]) + require.Contains(t, datetimeVal, "2024-01-15") + require.Contains(t, datetimeVal, "10:30:00") +} + +func testStringTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT char_col, varchar_col, string_col FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // All string types should be Go string, not []byte + for _, col := range []string{"char_col", "varchar_col", "string_col"} { + val := row[col] + _, isByte := val.([]byte) + require.False(t, isByte, "%s should not be []byte, got %T", col, val) + + _, isString := val.(string) + require.True(t, isString, "%s should be string, got %T", col, val) + } +} + +func testBinaryTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, binary_col, blob_col FROM test_db.binary_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Binary types might be []byte or base64 string depending on driver + // Just verify they're not nil for non-null values + require.NotNil(t, row["binary_col"]) + require.NotNil(t, row["blob_col"]) +} + +func testAggregateTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + // Query aggregate table with HLL and BITMAP + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, dt, hll_cardinality(hll_col) as hll_count, bitmap_count(bitmap_col) as bitmap_count, count_col FROM test_db.aggregate_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Verify aggregate results + require.Equal(t, int32(1), row["id"]) + require.NotNil(t, row["hll_count"]) + require.NotNil(t, row["bitmap_count"]) +} + +func testUnicodeStrings(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT ascii_col, unicode_col, emoji_col, korean_col, chinese_col, japanese_col FROM test_db.string_encoding_test WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Verify Unicode strings are correctly handled + asciiVal, ok := row["ascii_col"].(string) + require.True(t, ok, "expected string type, got %T", row["ascii_col"]) + require.Equal(t, "Hello World", asciiVal) + + unicodeVal, ok := row["unicode_col"].(string) + require.True(t, ok, "expected string type, got %T", row["unicode_col"]) + require.Equal(t, "Héllo Wörld", unicodeVal) + + emojiVal, ok := row["emoji_col"].(string) + require.True(t, ok, "expected string type, got %T", row["emoji_col"]) + require.Equal(t, "😀🎉🚀", emojiVal) + + koreanVal, ok := row["korean_col"].(string) + require.True(t, ok, "expected string type, got %T", row["korean_col"]) + require.Equal(t, "안녕하세요", koreanVal) + + chineseVal, ok := row["chinese_col"].(string) + require.True(t, ok, "expected string type, got %T", row["chinese_col"]) + require.Equal(t, "你好世界", chineseVal) + + japaneseVal, ok := row["japanese_col"].(string) + require.True(t, ok, "expected string type, got %T", row["japanese_col"]) + require.Equal(t, "こんにちは", japaneseVal) +} + +func testJSONType(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT json_col FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // JSON should be returned as string + jsonVal, ok := row["json_col"].(string) + require.True(t, ok, "expected string type for JSON, got %T", row["json_col"]) + require.Contains(t, jsonVal, "key") + require.Contains(t, jsonVal, "value") +} + +// ============================================================ +// DryRun tests +// ============================================================ +func testDryRun(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + // Query DryRun - valid query + t.Run("QueryDryRunValid", func(t *testing.T) { + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types", + DryRun: true, + }) + require.NoError(t, err) + require.Nil(t, res, "DryRun should return nil result") + }) + + // Query DryRun - invalid query + t.Run("QueryDryRunInvalid", func(t *testing.T) { + _, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM nonexistent_table", + DryRun: true, + }) + require.Error(t, err) + }) + + // Exec DryRun + t.Run("ExecDryRun", func(t *testing.T) { + err := olap.Exec(ctx, &drivers.Statement{ + Query: "SELECT 1", + DryRun: true, + }) + require.NoError(t, err) + }) +} + +// ============================================================ +// Exec tests +// ============================================================ +func testExec(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + // Basic exec + t.Run("BasicExec", func(t *testing.T) { + err := olap.Exec(ctx, &drivers.Statement{ + Query: "SELECT 1", + }) + require.NoError(t, err) + }) + + // Error case + t.Run("ExecError", func(t *testing.T) { + err := olap.Exec(ctx, &drivers.Statement{ + Query: "INVALID SQL SYNTAX", + }) + require.Error(t, err) + }) +} + +// ============================================================ +// QuerySchema tests +// ============================================================ +func testQuerySchema(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + schema, err := olap.QuerySchema(ctx, "SELECT id, varchar_col, int_col FROM test_db.all_types", nil) + require.NoError(t, err) + require.NotNil(t, schema) + require.Len(t, schema.Fields, 3) + + // Verify schema types + require.Equal(t, "id", schema.Fields[0].Name) + require.Equal(t, runtimev1.Type_CODE_INT32, schema.Fields[0].Type.Code) + + require.Equal(t, "varchar_col", schema.Fields[1].Name) + require.Equal(t, runtimev1.Type_CODE_STRING, schema.Fields[1].Type.Code) + + require.Equal(t, "int_col", schema.Fields[2].Name) + require.Equal(t, runtimev1.Type_CODE_INT32, schema.Fields[2].Type.Code) +} + +// ============================================================ +// Empty result set test +// ============================================================ +func testEmptyResultSet(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = -999", + }) + require.NoError(t, err) + defer res.Close() + + // Schema should exist + require.NotNil(t, res.Schema) + require.Greater(t, len(res.Schema.Fields), 0) + + // No rows + require.False(t, res.Next()) + require.NoError(t, res.Err()) +} + +// ============================================================ +// Multiple rows test +// ============================================================ +func testMultipleRows(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, varchar_col FROM test_db.all_types ORDER BY id", + }) + require.NoError(t, err) + defer res.Close() + + var rows []map[string]any + for res.Next() { + row := make(map[string]any) + err := res.MapScan(row) + require.NoError(t, err) + rows = append(rows, row) + } + require.NoError(t, res.Err()) + + require.Len(t, rows, 3, "expected 3 rows") + require.Equal(t, int32(1), rows[0]["id"]) + require.Equal(t, int32(2), rows[1]["id"]) + require.Equal(t, int32(3), rows[2]["id"]) +} + +// ============================================================ +// Boundary values test +// ============================================================ +func testBoundaryValues(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.boundary_values WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // TINYINT boundary + require.Equal(t, int16(-128), row["tinyint_min"]) + require.Equal(t, int16(127), row["tinyint_max"]) + + // SMALLINT boundary + require.Equal(t, int16(-32768), row["smallint_min"]) + require.Equal(t, int16(32767), row["smallint_max"]) + + // INT boundary + require.Equal(t, int32(-2147483648), row["int_min"]) + require.Equal(t, int32(2147483647), row["int_max"]) + + // BIGINT boundary + require.Equal(t, int64(-9223372036854775808), row["bigint_min"]) + require.Equal(t, int64(9223372036854775807), row["bigint_max"]) + + // Empty string + require.Equal(t, "", row["empty_string"]) + + // Whitespace string + require.Equal(t, " ", row["whitespace_string"]) +} + +// ============================================================ +// Detailed NULL handling test +// ============================================================ +func testNullHandlingDetailed(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + // Row with all columns NULL (id=3) + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = 3", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Only id is non-null + require.Equal(t, int32(3), row["id"]) + + // All other columns should be NULL + nullColumns := []string{ + "bool_col", "tinyint_col", "smallint_col", "int_col", "bigint_col", + "float_col", "double_col", "decimal_col", + "char_col", "varchar_col", "string_col", "date_col", "datetime_col", "json_col", + } + for _, col := range nullColumns { + require.Nil(t, row[col], "expected %s to be nil", col) + } +} + +// ============================================================ +// Negative values test +// ============================================================ +func testNegativeValues(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = 2", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Negative integers + require.Equal(t, int16(-128), row["tinyint_col"]) + require.Equal(t, int16(-32768), row["smallint_col"]) + require.Equal(t, int32(-2147483648), row["int_col"]) + require.Equal(t, int64(-9223372036854775808), row["bigint_col"]) + + // Negative floats + floatVal, _ := row["float_col"].(float64) + require.Less(t, floatVal, float64(0)) + + doubleVal, _ := row["double_col"].(float64) + require.Less(t, doubleVal, float64(0)) + + // DECIMAL is returned as string to preserve precision + decimalVal, ok := row["decimal_col"].(string) + require.True(t, ok, "expected string type for DECIMAL, got %T", row["decimal_col"]) + require.True(t, len(decimalVal) > 0 && decimalVal[0] == '-', "expected negative decimal value") + + // false boolean + boolVal := row["bool_col"] + switch v := boolVal.(type) { + case bool: + require.False(t, v) + case int16: + require.Equal(t, int16(0), v) + } +} + +// ============================================================ +// Special characters test +// ============================================================ +func testSpecialCharacters(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.special_chars WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // Quotes + quoteVal, ok := row["quote_col"].(string) + require.True(t, ok) + require.Contains(t, quoteVal, "'") + require.Contains(t, quoteVal, "\"") + + // Emoji + emojiVal, ok := row["emoji_col"].(string) + require.True(t, ok) + require.Contains(t, emojiVal, "😀") + + // SQL injection string (should be stored as-is) + sqlVal, ok := row["sql_injection_col"].(string) + require.True(t, ok) + require.Contains(t, sqlVal, "DROP TABLE") +} + +// ============================================================ +// Complex types test (ARRAY, MAP, STRUCT) +// ============================================================ +func testComplexTypes(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, array_col, map_col, struct_col FROM test_db.complex_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + // ARRAY - handled by default case (*any) + require.NotNil(t, row["array_col"]) + t.Logf("array_col type: %T, value: %v", row["array_col"], row["array_col"]) + + // MAP - handled by default case + require.NotNil(t, row["map_col"]) + t.Logf("map_col type: %T, value: %v", row["map_col"], row["map_col"]) + + // STRUCT - handled by default case + require.NotNil(t, row["struct_col"]) + t.Logf("struct_col type: %T, value: %v", row["struct_col"], row["struct_col"]) +} + +// ============================================================ +// Error cases test +// ============================================================ +func testErrorCases(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + t.Run("NonexistentTable", func(t *testing.T) { + _, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM nonexistent_db.nonexistent_table", + }) + require.Error(t, err) + }) + + t.Run("SyntaxError", func(t *testing.T) { + _, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELEC * FROM test_db.all_types", + }) + require.Error(t, err) + }) + + t.Run("InvalidColumn", func(t *testing.T) { + _, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT nonexistent_column FROM test_db.all_types", + }) + require.Error(t, err) + }) +} + +// ============================================================ +// Parameter binding test +// ============================================================ +func testParameterBinding(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, varchar_col FROM test_db.all_types WHERE id = ?", + Args: []any{1}, + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + require.Equal(t, int32(1), row["id"]) + require.Equal(t, "varchar_value", row["varchar_col"]) +} + +// ============================================================ +// Schema validation test +// ============================================================ +func testSchemaValidation(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, decimal_col, varchar_col, date_col, datetime_col, json_col FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + defer res.Close() + + schema := res.Schema + require.NotNil(t, schema) + + // Debug: Log actual schema types returned by StarRocks + t.Log("=== Actual Schema Types from StarRocks ===") + for _, field := range schema.Fields { + t.Logf("Column: %-15s Type.Code: %v", field.Name, field.Type.Code) + } + + // Also log the raw DatabaseTypeName from the driver + t.Log("=== Raw DatabaseTypeName from MySQL Driver ===") + if starrocksRes, ok := res.Rows.(*starrocksRows); ok { + for _, ct := range starrocksRes.colTypes { + t.Logf("Column: %-15s DatabaseTypeName: %s", ct.Name(), ct.DatabaseTypeName()) + } + } + + expectedTypes := map[string]runtimev1.Type_Code{ + "id": runtimev1.Type_CODE_INT32, + "bool_col": runtimev1.Type_CODE_BOOL, // Note: MySQL protocol may report as TINYINT + "tinyint_col": runtimev1.Type_CODE_INT8, + "smallint_col": runtimev1.Type_CODE_INT16, + "int_col": runtimev1.Type_CODE_INT32, + "bigint_col": runtimev1.Type_CODE_INT64, + "float_col": runtimev1.Type_CODE_FLOAT32, + "double_col": runtimev1.Type_CODE_FLOAT64, + "decimal_col": runtimev1.Type_CODE_STRING, // DECIMAL returns as string to preserve precision + "varchar_col": runtimev1.Type_CODE_STRING, + "date_col": runtimev1.Type_CODE_DATE, + "datetime_col": runtimev1.Type_CODE_TIMESTAMP, + "json_col": runtimev1.Type_CODE_JSON, + } + + for _, field := range schema.Fields { + expectedCode, exists := expectedTypes[field.Name] + if exists { + // bool_col may be reported as TINYINT (CODE_INT8) by MySQL protocol + if field.Name == "bool_col" { + if field.Type.Code != runtimev1.Type_CODE_BOOL && field.Type.Code != runtimev1.Type_CODE_INT8 { + t.Errorf("expected BOOL or INT8 for bool_col, got %v", field.Type.Code) + } + continue + } + // json_col may be reported as STRING by MySQL protocol + if field.Name == "json_col" { + if field.Type.Code != runtimev1.Type_CODE_JSON && field.Type.Code != runtimev1.Type_CODE_STRING { + t.Errorf("expected JSON or STRING for json_col, got %v", field.Type.Code) + } + continue + } + require.Equal(t, expectedCode, field.Type.Code, "type mismatch for %s", field.Name) + } + } +} + +// ============================================================ +// Aggregate functions test +// ============================================================ +func testAggregateFunctions(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + t.Run("COUNT", func(t *testing.T) { + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT COUNT(*) as cnt FROM test_db.all_types", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + cnt, ok := row["cnt"].(int64) + require.True(t, ok, "expected int64 for COUNT, got %T", row["cnt"]) + require.Equal(t, int64(3), cnt) + }) + + t.Run("SUM", func(t *testing.T) { + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT SUM(tinyint_col) as sum_val FROM test_db.all_types", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + require.NotNil(t, row["sum_val"]) + }) + + t.Run("AVG", func(t *testing.T) { + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT AVG(int_col) as avg_val FROM test_db.all_types WHERE int_col IS NOT NULL", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + _, ok := row["avg_val"].(float64) + require.True(t, ok, "expected float64 for AVG, got %T", row["avg_val"]) + }) + + t.Run("MIN_MAX", func(t *testing.T) { + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT MIN(int_col) as min_val, MAX(int_col) as max_val FROM test_db.all_types", + }) + require.NoError(t, err) + defer res.Close() + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + + minVal, ok := row["min_val"].(int32) + require.True(t, ok) + require.Equal(t, int32(-2147483648), minVal) + + maxVal, ok := row["max_val"].(int32) + require.True(t, ok) + require.Equal(t, int32(2147483647), maxVal) + }) +} + +// ============================================================ +// All Types Output - prints all types and actual values +// ============================================================ +func testAllTypesOutput(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + t.Log("================================================================================") + t.Log(" StarRocks Type → Go Return Value Mapping") + t.Log("================================================================================") + + // 1. Basic types from all_types table + t.Log("") + t.Log("=== 1. Basic Types (test_db.all_types, id=1) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = 1", + }) + require.NoError(t, err) + + // Get schema + schema := res.Schema + schemaMap := make(map[string]string) + for _, field := range schema.Fields { + schemaMap[field.Name] = field.Type.Code.String() + } + + require.True(t, res.Next()) + row := make(map[string]any) + err = res.MapScan(row) + require.NoError(t, err) + res.Close() + + // Define column order and StarRocks types + basicColumns := []struct { + name string + srType string + }{ + {"id", "INT"}, + {"bool_col", "BOOLEAN"}, + {"tinyint_col", "TINYINT"}, + {"smallint_col", "SMALLINT"}, + {"int_col", "INT"}, + {"bigint_col", "BIGINT"}, + {"largeint_col", "LARGEINT"}, + {"float_col", "FLOAT"}, + {"double_col", "DOUBLE"}, + {"decimal_col", "DECIMAL(18,4)"}, + {"char_col", "CHAR(10)"}, + {"varchar_col", "VARCHAR(255)"}, + {"string_col", "STRING"}, + {"date_col", "DATE"}, + {"datetime_col", "DATETIME"}, + {"json_col", "JSON"}, + } + + for _, col := range basicColumns { + val := row[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 2. Complex types + t.Log("") + t.Log("=== 2. Complex Types (test_db.complex_types, id=1) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res2, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, array_col, map_col, struct_col FROM test_db.complex_types WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res2.Next()) + row2 := make(map[string]any) + err = res2.MapScan(row2) + require.NoError(t, err) + res2.Close() + + complexColumns := []struct { + name string + srType string + }{ + {"array_col", "ARRAY"}, + {"map_col", "MAP"}, + {"struct_col", "STRUCT<...>"}, + } + + for _, col := range complexColumns { + val := row2[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 3. Binary types + t.Log("") + t.Log("=== 3. Binary Types (test_db.binary_types, id=1) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res3, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, binary_col, blob_col FROM test_db.binary_types WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res3.Next()) + row3 := make(map[string]any) + err = res3.MapScan(row3) + require.NoError(t, err) + res3.Close() + + binaryColumns := []struct { + name string + srType string + }{ + {"binary_col", "VARBINARY(255)"}, + {"blob_col", "VARBINARY(65535)"}, + } + + for _, col := range binaryColumns { + val := row3[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 4. Aggregate types (with functions) + t.Log("") + t.Log("=== 4. Aggregate Types (test_db.aggregate_types) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res4, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT id, dt, hll_cardinality(hll_col) as hll_result, bitmap_count(bitmap_col) as bitmap_result, count_col FROM test_db.aggregate_types WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res4.Next()) + row4 := make(map[string]any) + err = res4.MapScan(row4) + require.NoError(t, err) + res4.Close() + + aggColumns := []struct { + name string + srType string + }{ + {"id", "INT"}, + {"dt", "DATE"}, + {"hll_result", "HLL→BIGINT"}, + {"bitmap_result", "BITMAP→BIGINT"}, + {"count_col", "BIGINT SUM"}, + } + + for _, col := range aggColumns { + val := row4[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 5. NULL values + t.Log("") + t.Log("=== 5. NULL Values (test_db.all_types, id=3) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res5, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.all_types WHERE id = 3", + }) + require.NoError(t, err) + require.True(t, res5.Next()) + row5 := make(map[string]any) + err = res5.MapScan(row5) + require.NoError(t, err) + res5.Close() + + nullColumns := []struct { + name string + srType string + }{ + {"bool_col", "BOOLEAN"}, + {"int_col", "INT"}, + {"varchar_col", "VARCHAR"}, + {"date_col", "DATE"}, + {"json_col", "JSON"}, + } + + for _, col := range nullColumns { + val := row5[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 6. Boundary values + t.Log("") + t.Log("=== 6. Boundary Values (test_db.boundary_values) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res6, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.boundary_values WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res6.Next()) + row6 := make(map[string]any) + err = res6.MapScan(row6) + require.NoError(t, err) + res6.Close() + + boundaryColumns := []struct { + name string + srType string + }{ + {"tinyint_min", "TINYINT"}, + {"tinyint_max", "TINYINT"}, + {"smallint_min", "SMALLINT"}, + {"smallint_max", "SMALLINT"}, + {"int_min", "INT"}, + {"int_max", "INT"}, + {"bigint_min", "BIGINT"}, + {"bigint_max", "BIGINT"}, + {"empty_string", "VARCHAR"}, + {"whitespace_string", "VARCHAR"}, + } + + for _, col := range boundaryColumns { + val := row6[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 7. Unicode strings + t.Log("") + t.Log("=== 7. Unicode/Encoding (test_db.string_encoding_test) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res7, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.string_encoding_test WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res7.Next()) + row7 := make(map[string]any) + err = res7.MapScan(row7) + require.NoError(t, err) + res7.Close() + + unicodeColumns := []struct { + name string + srType string + }{ + {"ascii_col", "VARCHAR (ASCII)"}, + {"unicode_col", "VARCHAR (Unicode)"}, + {"emoji_col", "VARCHAR (Emoji)"}, + {"korean_col", "VARCHAR (Korean)"}, + {"chinese_col", "VARCHAR (Chinese)"}, + {"japanese_col", "VARCHAR (Japanese)"}, + } + + for _, col := range unicodeColumns { + val := row7[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // 8. Special characters + t.Log("") + t.Log("=== 8. Special Characters (test_db.special_chars) ===") + t.Log("--------------------------------------------------------------------------------") + t.Logf("%-20s | %-20s | %-15s | %s", "Column", "StarRocks Type", "Go Type", "Value") + t.Log("--------------------------------------------------------------------------------") + + res8, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.special_chars WHERE id = 1", + }) + require.NoError(t, err) + require.True(t, res8.Next()) + row8 := make(map[string]any) + err = res8.MapScan(row8) + require.NoError(t, err) + res8.Close() + + specialColumns := []struct { + name string + srType string + }{ + {"newline_col", "VARCHAR"}, + {"tab_col", "VARCHAR"}, + {"quote_col", "VARCHAR"}, + {"emoji_col", "VARCHAR"}, + {"sql_injection_col", "VARCHAR"}, + } + + for _, col := range specialColumns { + val := row8[col.name] + goType := fmt.Sprintf("%T", val) + valStr := formatValue(val) + t.Logf("%-20s | %-20s | %-15s | %s", col.name, col.srType, goType, valStr) + } + + // Summary table + t.Log("") + t.Log("================================================================================") + t.Log(" Type Mapping Summary") + t.Log("================================================================================") + t.Logf("%-20s | %-20s | %-15s", "StarRocks Type", "Schema Code", "Go Return Type") + t.Log("--------------------------------------------------------------------------------") + + summaryTypes := []struct { + srType string + schemaCode string + goType string + }{ + {"BOOLEAN", "CODE_INT8", "int16"}, + {"TINYINT", "CODE_INT8", "int16"}, + {"SMALLINT", "CODE_INT16", "int16"}, + {"INT", "CODE_INT32", "int32"}, + {"BIGINT", "CODE_INT64", "int64"}, + {"LARGEINT", "CODE_INT128", "string (>64bit auto)"}, + {"FLOAT", "CODE_FLOAT32", "float64"}, + {"DOUBLE", "CODE_FLOAT64", "float64"}, + {"DECIMAL", "CODE_STRING", "string (precision)"}, + {"CHAR/VARCHAR/STRING", "CODE_STRING", "string"}, + {"DATE", "CODE_DATE", "string"}, + {"DATETIME", "CODE_TIMESTAMP", "string"}, + {"JSON", "CODE_STRING", "string"}, + {"ARRAY", "CODE_ARRAY", "string (JSON)"}, + {"MAP", "CODE_MAP", "string (JSON)"}, + {"STRUCT<...>", "CODE_STRUCT", "string (JSON)"}, + {"VARBINARY", "CODE_STRING", "string"}, + {"HLL", "N/A", "use hll_cardinality()"}, + {"BITMAP", "N/A", "use bitmap_count()"}, + {"NULL", "N/A", ""}, + } + + for _, s := range summaryTypes { + t.Logf("%-20s | %-20s | %-15s", s.srType, s.schemaCode, s.goType) + } + + t.Log("================================================================================") +} + +// formatValue formats a value for display, truncating long strings +func formatValue(val any) string { + if val == nil { + return "" + } + s := fmt.Sprintf("%v", val) + if len(s) > 50 { + return s[:47] + "..." + } + return s +} + +// ============================================================ +// High-precision DECIMAL test (DECIMAL32, DECIMAL64, DECIMAL128) +// ============================================================ +func testDecimalPrecision(t *testing.T, olap drivers.OLAPStore) { + ctx := context.Background() + + t.Log("=== Testing High-Precision DECIMAL Types ===") + t.Log("StarRocks internal types based on precision:") + t.Log(" DECIMAL(1-9, S) → DECIMAL32") + t.Log(" DECIMAL(10-18, S) → DECIMAL64") + t.Log(" DECIMAL(19-38, S) → DECIMAL128") + t.Log("") + + res, err := olap.Query(ctx, &drivers.Statement{ + Query: "SELECT * FROM test_db.decimal_precision_test ORDER BY id", + }) + require.NoError(t, err) + defer res.Close() + + // Log the DatabaseTypeName for each column + if starrocksRes, ok := res.Rows.(*starrocksRows); ok { + t.Log("=== Raw DatabaseTypeName from MySQL Driver ===") + for _, ct := range starrocksRes.colTypes { + t.Logf("Column: %-15s DatabaseTypeName: %s", ct.Name(), ct.DatabaseTypeName()) + } + } + + t.Log("") + t.Log("=== Values (returned as string to preserve precision) ===") + t.Logf("%-5s | %-20s | %-25s | %s", "ID", "DECIMAL32 (9,4)", "DECIMAL64 (18,6)", "DECIMAL128 (38,10)") + t.Log("------+----------------------+---------------------------+------------------------------------------") + + rowNum := 0 + for res.Next() { + row := make(map[string]any) + err := res.MapScan(row) + require.NoError(t, err) + rowNum++ + + id := row["id"] + d32 := row["decimal32_col"] + d64 := row["decimal64_col"] + d128 := row["decimal128_col"] + + // All DECIMAL types should return as string + d32Str, ok := d32.(string) + require.True(t, ok, "DECIMAL32 should be string, got %T", d32) + + d64Str, ok := d64.(string) + require.True(t, ok, "DECIMAL64 should be string, got %T", d64) + + d128Str, ok := d128.(string) + require.True(t, ok, "DECIMAL128 should be string, got %T", d128) + + t.Logf("%-5v | %-20s | %-25s | %s", id, d32Str, d64Str, d128Str) + + // Verify high-precision DECIMAL128 preserves all digits + if rowNum == 1 { + // 12345678901234567890123456.7890123456 - 26 digits before decimal, 10 after + require.Contains(t, d128Str, "12345678901234567890123456") + require.Contains(t, d128Str, "7890123456") + } + } + + t.Log("") + t.Log("=== Precision Preservation Test ===") + t.Log("DECIMAL128 value: 12345678901234567890123456.7890123456") + t.Log("If this were float64, precision would be lost:") + t.Log(" float64 max precision: ~15-17 significant digits") + t.Log(" DECIMAL128 has 36 significant digits → string preserves all") +} diff --git a/runtime/drivers/starrocks/starrocks_test.go b/runtime/drivers/starrocks/starrocks_test.go index 3e842680549..8781fff2095 100644 --- a/runtime/drivers/starrocks/starrocks_test.go +++ b/runtime/drivers/starrocks/starrocks_test.go @@ -136,17 +136,51 @@ func TestDatabaseTypeToRuntimeType(t *testing.T) { expected string expectErr bool }{ + // Boolean types {"BOOLEAN", "CODE_BOOL", false}, + {"BOOL", "CODE_BOOL", false}, + + // Integer types + {"TINYINT", "CODE_INT8", false}, + {"SMALLINT", "CODE_INT16", false}, {"INT", "CODE_INT32", false}, + {"INTEGER", "CODE_INT32", false}, {"BIGINT", "CODE_INT64", false}, + {"LARGEINT", "CODE_INT128", false}, + + // Floating point types + {"FLOAT", "CODE_FLOAT32", false}, {"DOUBLE", "CODE_FLOAT64", false}, + {"DECIMAL(10,2)", "CODE_STRING", false}, // DECIMAL returns string for precision + + // String types + {"CHAR(10)", "CODE_STRING", false}, {"VARCHAR(255)", "CODE_STRING", false}, - {"DATETIME", "CODE_TIMESTAMP", false}, + {"STRING", "CODE_STRING", false}, + {"TEXT", "CODE_STRING", false}, + + // Binary types (same as MySQL - returns CODE_STRING) + {"BINARY", "CODE_STRING", false}, + {"VARBINARY", "CODE_STRING", false}, + {"BLOB", "CODE_STRING", false}, + + // Date/Time types {"DATE", "CODE_DATE", false}, + {"DATETIME", "CODE_TIMESTAMP", false}, + {"TIMESTAMP", "CODE_TIMESTAMP", false}, + + // Semi-structured types {"JSON", "CODE_JSON", false}, - {"DECIMAL(10,2)", "CODE_DECIMAL", false}, + {"JSONB", "CODE_JSON", false}, {"ARRAY", "CODE_ARRAY", false}, - {"UNKNOWN_TYPE", "", true}, // unsupported type returns error + {"MAP", "CODE_MAP", false}, + {"STRUCT", "CODE_STRUCT", false}, + + // Unsupported types (aggregate-only types) + {"HLL", "", true}, + {"BITMAP", "", true}, + {"PERCENTILE", "", true}, + {"UNKNOWN_TYPE", "", true}, } for _, tt := range tests { diff --git a/runtime/drivers/starrocks/teststarrocks/testdata/init.sql b/runtime/drivers/starrocks/teststarrocks/testdata/init.sql new file mode 100644 index 00000000000..15d41a8e897 --- /dev/null +++ b/runtime/drivers/starrocks/teststarrocks/testdata/init.sql @@ -0,0 +1,212 @@ +-- StarRocks Test Database Initialization +-- Version: 4.0.3 +-- This script creates test tables with all supported StarRocks data types + +CREATE DATABASE IF NOT EXISTS test_db; +USE test_db; + +-- Table 1: All basic data types (Duplicate Key Model) +CREATE TABLE IF NOT EXISTS all_types ( + -- Primary key + id INT NOT NULL, + + -- Numeric types + bool_col BOOLEAN, + tinyint_col TINYINT, + smallint_col SMALLINT, + int_col INT, + bigint_col BIGINT, + largeint_col LARGEINT, + float_col FLOAT, + double_col DOUBLE, + decimal_col DECIMAL(18, 4), + + -- String types + char_col CHAR(10), + varchar_col VARCHAR(255), + string_col STRING, + + -- Date/Time types + date_col DATE, + datetime_col DATETIME, + + -- Semi-structured types + json_col JSON, + array_col ARRAY, + map_col MAP, + struct_col STRUCT +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +-- Table 2: Aggregate types (Aggregate Key Model) +-- HLL, BITMAP, PERCENTILE are only available in aggregate tables +CREATE TABLE IF NOT EXISTS aggregate_types ( + -- Key column + id INT NOT NULL, + dt DATE NOT NULL, + + -- Aggregate columns with special types + hll_col HLL HLL_UNION, + bitmap_col BITMAP BITMAP_UNION, + count_col BIGINT SUM DEFAULT "0" +) +AGGREGATE KEY(id, dt) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +-- Table 3: Binary type test (StarRocks 4.0+) +CREATE TABLE IF NOT EXISTS binary_types ( + id INT NOT NULL, + binary_col VARBINARY(255), + blob_col VARBINARY(65535) +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +-- Insert test data into all_types +INSERT INTO all_types VALUES +(1, true, 127, 32767, 2147483647, 9223372036854775807, 170141183460469231731687303715884105727, + 3.14, 3.141592653589793, 12345.6789, + 'char_val', 'varchar_value', 'string_value', + '2024-01-15', '2024-01-15 10:30:00', + '{"key": "value", "num": 123}', + [1, 2, 3, 4, 5], + map{"key1": 1, "key2": 2}, + named_struct("name", "John", "age", 30)), +(2, false, -128, -32768, -2147483648, -9223372036854775808, -170141183460469231731687303715884105728, + -3.14, -3.141592653589793, -12345.6789, + 'char_2', 'varchar_2', 'string_2', + '2024-06-20', '2024-06-20 15:45:30', + '{"nested": {"array": [1,2,3]}}', + [10, 20, 30], + map{"a": 100, "b": 200}, + named_struct("name", "Jane", "age", 25)), +(3, NULL, NULL, NULL, NULL, NULL, NULL, + NULL, NULL, NULL, + NULL, NULL, NULL, + NULL, NULL, + NULL, + NULL, + NULL, + NULL); + +-- Insert test data into aggregate_types +INSERT INTO aggregate_types VALUES +(1, '2024-01-01', hll_hash('user1'), to_bitmap(100), 10), +(1, '2024-01-01', hll_hash('user2'), to_bitmap(101), 20), +(2, '2024-01-02', hll_hash('user3'), to_bitmap(200), 30); + +-- Insert test data into binary_types +INSERT INTO binary_types VALUES +(1, x'48454C4C4F', x'576F726C64'), +(2, x'0102030405', x'AABBCCDDEE'), +(3, NULL, NULL); + +-- Table 4: String encoding test (UTF-8, special characters) +CREATE TABLE IF NOT EXISTS string_encoding_test ( + id INT NOT NULL, + ascii_col VARCHAR(255), + unicode_col VARCHAR(255), + emoji_col VARCHAR(255), + korean_col VARCHAR(255), + chinese_col VARCHAR(255), + japanese_col VARCHAR(255) +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO string_encoding_test VALUES +(1, 'Hello World', 'Héllo Wörld', '😀🎉🚀', '안녕하세요', '你好世界', 'こんにちは'), +(2, 'Test 123', 'Tëst 456', '👍👎', '테스트', '测试', 'テスト'); + +-- Table 5: Complex types (ARRAY, MAP, STRUCT) +CREATE TABLE IF NOT EXISTS complex_types ( + id INT NOT NULL, + array_col ARRAY, + map_col MAP, + struct_col STRUCT +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO complex_types VALUES +(1, [1, 2, 3], map{'a': 1, 'b': 2}, row('John', 30)), +(2, [4, 5], map{'c': 3}, row('Jane', 25)), +(3, NULL, NULL, NULL); + +-- Table 6: Boundary values test +CREATE TABLE IF NOT EXISTS boundary_values ( + id INT NOT NULL, + tinyint_min TINYINT, + tinyint_max TINYINT, + smallint_min SMALLINT, + smallint_max SMALLINT, + int_min INT, + int_max INT, + bigint_min BIGINT, + bigint_max BIGINT, + empty_string VARCHAR(255), + whitespace_string VARCHAR(255) +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO boundary_values VALUES +(1, -128, 127, -32768, 32767, -2147483648, 2147483647, -9223372036854775808, 9223372036854775807, '', ' '); + +-- Table 7: Special characters test +CREATE TABLE IF NOT EXISTS special_chars ( + id INT NOT NULL, + newline_col VARCHAR(255), + tab_col VARCHAR(255), + quote_col VARCHAR(255), + emoji_col VARCHAR(255), + sql_injection_col VARCHAR(255) +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO special_chars VALUES +(1, 'line1\nline2', 'col1\tcol2', 'it''s a "test"', '😀🎉', 'SELECT * FROM users; DROP TABLE--'); + +-- Table 8: High-precision DECIMAL test (DECIMAL32, DECIMAL64, DECIMAL128) +-- StarRocks uses different internal types based on precision: +-- - DECIMAL(1-9, S) → DECIMAL32 +-- - DECIMAL(10-18, S) → DECIMAL64 +-- - DECIMAL(19-38, S) → DECIMAL128 +CREATE TABLE IF NOT EXISTS decimal_precision_test ( + id INT NOT NULL, + decimal32_col DECIMAL(9, 4), + decimal64_col DECIMAL(18, 6), + decimal128_col DECIMAL(38, 10) +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO decimal_precision_test VALUES +(1, 12345.6789, 123456789012.345678, 12345678901234567890123456.7890123456), +(2, -99999.9999, -999999999999.999999, -99999999999999999999999999.9999999999), +(3, 0.0001, 0.000001, 0.0000000001); + +-- Table 9: Ad Bids table for metricsview tests +-- This table mirrors the structure used in other OLAP driver tests (ClickHouse, DuckDB) +CREATE TABLE IF NOT EXISTS ad_bids ( + id INT NOT NULL, + timestamp DATETIME NOT NULL, + publisher VARCHAR(255), + domain VARCHAR(255), + bid_price DOUBLE +) +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 1 +PROPERTIES ("replication_num" = "1"); + +-- Ad bids data is loaded from AdBids.csv.gz via LOAD DATA LOCAL INFILE in teststarrocks.go diff --git a/runtime/drivers/starrocks/teststarrocks/teststarrocks.go b/runtime/drivers/starrocks/teststarrocks/teststarrocks.go new file mode 100644 index 00000000000..a1b2529cd24 --- /dev/null +++ b/runtime/drivers/starrocks/teststarrocks/teststarrocks.go @@ -0,0 +1,281 @@ +package teststarrocks + +import ( + "compress/gzip" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/docker/go-connections/nat" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcwait "github.com/testcontainers/testcontainers-go/wait" + + _ "github.com/go-sql-driver/mysql" // MySQL driver for database/sql +) + +const ( + // StarRocksVersion is the StarRocks version used for testing + StarRocksVersion = "4.0.3" + // StarRocksImage is the Docker image for StarRocks all-in-one container + StarRocksImage = "starrocks/allin1-ubuntu:" + StarRocksVersion +) + +// TestingT satisfies both *testing.T and *testing.B. +type TestingT interface { + Name() string + TempDir() string + FailNow() + Errorf(format string, args ...interface{}) + Cleanup(f func()) +} + +// StarRocksInfo contains connection info for a StarRocks container +type StarRocksInfo struct { + DSN string // MySQL protocol DSN (port 9030) + FEHTTPAddr string // FE HTTP address for Stream Load (port 8030) + BEHTTPAddr string // BE HTTP address for Stream Load redirect (port 8040) +} + +// Start starts a StarRocks all-in-one container for testing. +// It returns connection info for the container. +// The container is automatically terminated when the test ends. +func Start(t TestingT) StarRocksInfo { + ctx := context.Background() + + req := testcontainers.ContainerRequest{ + Image: StarRocksImage, + ExposedPorts: []string{"9030/tcp", "8030/tcp", "8040/tcp"}, + WaitingFor: tcwait.ForAll( + tcwait.ForListeningPort("9030/tcp"), + tcwait.ForListeningPort("8030/tcp"), + tcwait.ForListeningPort("8040/tcp"), + tcwait.ForLog("Enjoy the journey to StarRocks"), + ).WithDeadline(5 * time.Minute), + } + + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + require.NoError(t, err) + + t.Cleanup(func() { + err := container.Terminate(ctx) + require.NoError(t, err) + }) + + host, err := container.Host(ctx) + require.NoError(t, err) + + mysqlPort, err := container.MappedPort(ctx, nat.Port("9030/tcp")) + require.NoError(t, err) + + feHTTPPort, err := container.MappedPort(ctx, nat.Port("8030/tcp")) + require.NoError(t, err) + + beHTTPPort, err := container.MappedPort(ctx, nat.Port("8040/tcp")) + require.NoError(t, err) + + return StarRocksInfo{ + DSN: fmt.Sprintf("root:@tcp(%s:%s)/?parseTime=true&loc=UTC", host, mysqlPort.Port()), + FEHTTPAddr: fmt.Sprintf("%s:%s", host, feHTTPPort.Port()), + BEHTTPAddr: fmt.Sprintf("%s:%s", host, beHTTPPort.Port()), + } +} + +// StartWithData starts a StarRocks container and initializes it with test tables. +// Returns DSN for connecting to the container. +func StartWithData(t TestingT) string { + info := Start(t) + + // Wait for StarRocks to be fully ready + waitForStarRocks(t, info.DSN) + + // Initialize test database and tables from init.sql + initTestData(t, info.DSN) + + // Load ad_bids data from CSV via Stream Load + loadAdBidsData(t, info.FEHTTPAddr, info.BEHTTPAddr) + + return info.DSN +} + +// waitForStarRocks waits for StarRocks to be ready to accept queries +func waitForStarRocks(t TestingT, dsn string) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + // Wait until we can execute a simple query + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + require.Fail(t, "timeout waiting for StarRocks to be ready") + return + case <-ticker.C: + _, err := db.ExecContext(ctx, "SELECT 1") + if err == nil { + return + } + } + } +} + +// initTestData initializes test database and tables from init.sql +func initTestData(t TestingT, dsn string) { + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + // Read init.sql from testdata + _, currentFile, _, _ := runtime.Caller(0) + initSQLPath := filepath.Join(filepath.Dir(currentFile), "testdata", "init.sql") + + content, err := os.ReadFile(initSQLPath) + require.NoError(t, err, "failed to read init.sql") + + // Parse and execute SQL statements + statements := parseSQLStatements(string(content)) + for _, stmt := range statements { + _, err := db.Exec(stmt) + if err != nil { + // DDL with "IF NOT EXISTS" may have benign failures + // DML (INSERT) failures are more serious but may happen on re-run + isDDL := strings.HasPrefix(strings.ToUpper(stmt), "CREATE") || + strings.HasPrefix(strings.ToUpper(stmt), "USE") + if isDDL { + // DDL errors are usually benign (already exists, etc.) + continue + } + // Log DML errors but continue - data may already exist + stmtPreview := stmt + if len(stmtPreview) > 100 { + stmtPreview = stmtPreview[:100] + "..." + } + t.Errorf("Warning executing statement: %v\nStatement: %s", err, stmtPreview) + } + } +} + +// parseSQLStatements parses SQL file content into individual statements. +// Handles comments and multi-line statements. +func parseSQLStatements(content string) []string { + var statements []string + var current strings.Builder + + lines := strings.Split(content, "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "--") { + continue + } + + current.WriteString(line) + current.WriteString("\n") + + // Check if statement ends with semicolon + if strings.HasSuffix(trimmed, ";") { + stmt := strings.TrimSpace(current.String()) + // Remove trailing semicolon for execution + stmt = strings.TrimSuffix(stmt, ";") + stmt = strings.TrimSpace(stmt) + if stmt != "" { + statements = append(statements, stmt) + } + current.Reset() + } + } + + // Handle any remaining statement without semicolon + if remaining := strings.TrimSpace(current.String()); remaining != "" { + statements = append(statements, remaining) + } + + return statements +} + +// streamLoadResponse represents the JSON response from StarRocks Stream Load +type streamLoadResponse struct { + Status string `json:"Status"` + Msg string `json:"Message"` +} + +// loadAdBidsData loads ad_bids data from CSV file using StarRocks Stream Load API +func loadAdBidsData(t TestingT, feHTTPAddr, beHTTPAddr string) { + // Find the AdBids.csv.gz file + _, currentFile, _, _ := runtime.Caller(0) + // Go up from teststarrocks -> starrocks -> drivers -> runtime -> testruntime/testdata/ad_bids/data + csvGzPath := filepath.Join(filepath.Dir(currentFile), "..", "..", "..", "testruntime", "testdata", "ad_bids", "data", "AdBids.csv.gz") + + // Open and decompress the gzip file + gzFile, err := os.Open(csvGzPath) + require.NoError(t, err, "failed to open AdBids.csv.gz") + defer gzFile.Close() + + gzReader, err := gzip.NewReader(gzFile) + require.NoError(t, err, "failed to create gzip reader") + defer gzReader.Close() + + // Read decompressed CSV data into memory + csvData, err := io.ReadAll(gzReader) + require.NoError(t, err, "failed to read CSV data") + + // Create HTTP request for Stream Load + url := fmt.Sprintf("http://%s/api/test_db/ad_bids/_stream_load", feHTTPAddr) + req, err := http.NewRequest(http.MethodPut, url, strings.NewReader(string(csvData))) + require.NoError(t, err, "failed to create Stream Load request") + + // Set required headers for Stream Load + req.Header.Set("Expect", "100-continue") + req.Header.Set("column_separator", ",") + req.Header.Set("skip_header", "1") // Skip CSV header row + // Use NULLIF to convert empty strings to NULL for publisher and domain columns (matches DuckDB behavior) + req.Header.Set("columns", "id, timestamp, tmp_publisher, tmp_domain, bid_price, publisher=NULLIF(tmp_publisher, ''), domain=NULLIF(tmp_domain, '')") + req.SetBasicAuth("root", "") + + // Create HTTP client with custom redirect policy + // StarRocks FE redirects to BE for Stream Load, but the redirect URL contains + // the internal container address. We need to rewrite it to the mapped host port. + client := &http.Client{ + Timeout: 2 * time.Minute, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Rewrite the redirect URL to use the correct BE address + req.URL.Host = beHTTPAddr + // Preserve auth header on redirect (like curl --location-trusted) + if len(via) > 0 { + req.SetBasicAuth("root", "") + } + return nil + }, + } + resp, err := client.Do(req) + require.NoError(t, err, "failed to execute Stream Load request") + defer resp.Body.Close() + + // Parse response + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "failed to read Stream Load response") + + var result streamLoadResponse + err = json.Unmarshal(body, &result) + require.NoError(t, err, "failed to parse Stream Load response: %s", string(body)) + + require.Equal(t, "Success", result.Status, "Stream Load failed: %s", result.Msg) +} diff --git a/runtime/metricsview/ast.go b/runtime/metricsview/ast.go index 759e12ac48f..28cec6e18bb 100644 --- a/runtime/metricsview/ast.go +++ b/runtime/metricsview/ast.go @@ -504,9 +504,16 @@ func (a *AST) ResolveMeasure(qm Measure, visible bool) (*runtimev1.MetricsViewSp return nil, err } + // StarRocks returns DECIMAL for division, which gets mapped to string. + // Cast to DOUBLE for consistent numeric handling across all dialects. + expr := fmt.Sprintf("%s/%#f", a.Dialect.EscapeIdentifier(m.Name), *qm.Compute.PercentOfTotal.Total) + if a.Dialect == drivers.DialectStarRocks { + expr = fmt.Sprintf("CAST(%s AS DOUBLE)", expr) + } + return &runtimev1.MetricsViewSpec_Measure{ Name: qm.Name, - Expression: fmt.Sprintf("%s/%#f", a.Dialect.EscapeIdentifier(m.Name), *qm.Compute.PercentOfTotal.Total), + Expression: expr, Type: runtimev1.MetricsViewSpec_MEASURE_TYPE_DERIVED, ReferencedMeasures: []string{qm.Compute.PercentOfTotal.Measure}, DisplayName: fmt.Sprintf("%s (Σ%%)", m.DisplayName), diff --git a/runtime/queries/column_time_range.go b/runtime/queries/column_time_range.go index dc5caa54485..9e9a0d4bed0 100644 --- a/runtime/queries/column_time_range.go +++ b/runtime/queries/column_time_range.go @@ -121,6 +121,15 @@ func (q *ColumnTimeRange) resolveDuckDBAndClickhouse(ctx context.Context, olap d } func (q *ColumnTimeRange) resolveStarRocks(ctx context.Context, olap drivers.OLAPStore, priority int) error { + // If schema is not provided, look up the table to get the correct schema + if q.DatabaseSchema == "" { + table, err := olap.InformationSchema().Lookup(ctx, q.Database, "", q.TableName) + if err != nil { + return fmt.Errorf("failed to lookup table %q: %w", q.TableName, err) + } + q.DatabaseSchema = table.DatabaseSchema + } + rangeSQL := fmt.Sprintf( "SELECT min(%[1]s) as \"min\", max(%[1]s) as \"max\" FROM %[2]s", olap.Dialect().EscapeIdentifier(q.ColumnName), diff --git a/runtime/queries/metricsview_aggregation_test.go b/runtime/queries/metricsview_aggregation_test.go index 42a11751746..00459312c89 100644 --- a/runtime/queries/metricsview_aggregation_test.go +++ b/runtime/queries/metricsview_aggregation_test.go @@ -65,6 +65,46 @@ func TestMetricViewAggregationAgainstClickHouse(t *testing.T) { }) } +func TestMetricViewAggregationAgainstStarRocks(t *testing.T) { + testmode.Expensive(t) + + rt, instanceID := testruntime.NewInstanceWithStarRocksProject(t) + t.Run("testMetricsViewsAggregation", func(t *testing.T) { testMetricsViewsAggregation(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_no_limit", func(t *testing.T) { testMetricsViewsAggregation_no_limit(t, rt, instanceID) }) + t.Run("testMetricsViewAggregation_measure_filters", func(t *testing.T) { testMetricsViewAggregation_measure_filters(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_timezone", func(t *testing.T) { testMetricsViewsAggregation_timezone(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_filter", func(t *testing.T) { testMetricsViewsAggregation_filter(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_filter_with_timestamp", func(t *testing.T) { testMetricsViewsAggregation_filter_with_timestamp(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_filter_2dims", func(t *testing.T) { testMetricsViewsAggregation_filter_2dims(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_having_gt", func(t *testing.T) { testMetricsViewsAggregation_having_gt(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_having_same_name", func(t *testing.T) { testMetricsViewsAggregation_having_same_name(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_having", func(t *testing.T) { testMetricsViewsAggregation_having(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_where", func(t *testing.T) { testMetricsViewsAggregation_where(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_whereAndSQLBoth", func(t *testing.T) { testMetricsViewsAggregation_whereAndSQLBoth(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_filter_having_measure", func(t *testing.T) { testMetricsViewsAggregation_filter_having_measure(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_filter_with_where_and_having_measure", func(t *testing.T) { + testMetricsViewsAggregation_filter_with_where_and_having_measure(t, rt, instanceID) + }) + t.Run("testMetricsViewsAggregation_2time_aggregations", func(t *testing.T) { testMetricsViewsAggregation_2time_aggregations(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_comparison_no_time_dim", func(t *testing.T) { testMetricsViewsAggregation_comparison_no_time_dim(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_comparison_no_dims", func(t *testing.T) { testMetricsViewsAggregation_comparison_no_dims(t, rt, instanceID) }) + t.Run("TestMetricsViewsAggregation_comparison_measure_filter_with_a_single_derivative_measure", func(t *testing.T) { + testMetricsViewsAggregation_comparison_measure_filter_with_a_single_derivative_measure(t, rt, instanceID) + }) + t.Run("testMetricsViewsAggregation_comparison_measure_filter_no_duplicates", func(t *testing.T) { + testMetricsViewsAggregation_comparison_measure_filter_no_duplicates(t, rt, instanceID) + }) + t.Run("testMetricsViewsAggregation_comparison_measure_filter_with_totals", func(t *testing.T) { + testMetricsViewsAggregation_comparison_measure_filter_with_totals(t, rt, instanceID) + }) + t.Run("testMetricsViewsAggregation_comparison_with_offset", func(t *testing.T) { testMetricsViewsAggregation_comparison_with_offset(t, rt, instanceID) }) + t.Run("testMetricsViewAggregation_percent_of_totals", func(t *testing.T) { testMetricsViewAggregation_percent_of_totals(t, rt, instanceID) }) + t.Run("testMetricsViewAggregation_percent_of_totals_with_limit", func(t *testing.T) { testMetricsViewAggregation_percent_of_totals_with_limit(t, rt, instanceID) }) + t.Run("testMetricsViewsAggregation_comparison_with_offset_and_limit_and_delta", func(t *testing.T) { + testMetricsViewsAggregation_comparison_with_offset_and_limit_and_delta(t, rt, instanceID) + }) +} + func TestMetricViewAggregationAgainstDuckDB(t *testing.T) { rt, instanceID := testruntime.NewInstanceForProject(t, "ad_bids") t.Run("testMetricsViewsAggregation", func(t *testing.T) { testMetricsViewsAggregation(t, rt, instanceID) }) diff --git a/runtime/queries/metricsview_comparison_toplist_test.go b/runtime/queries/metricsview_comparison_toplist_test.go index 23346c64cb2..d4c2bfee62e 100644 --- a/runtime/queries/metricsview_comparison_toplist_test.go +++ b/runtime/queries/metricsview_comparison_toplist_test.go @@ -59,6 +59,154 @@ func TestMetricsViewsComparisonAgainstClickHouse(t *testing.T) { t.Run("TestServer_MetricsViewTimeseries_export_csv", func(t *testing.T) { TestServer_MetricsViewTimeseries_export_csv(t) }) } +func TestMetricsViewsComparisonAgainstStarRocks(t *testing.T) { + testmode.Expensive(t) + + rt, instanceID := testruntime.NewInstanceWithStarRocksProject(t) + t.Run("testMetricsViewsComparison_dim_order", func(t *testing.T) { + testMetricsViewsComparison_dim_order(t, rt, instanceID) + }) + t.Run("testMetricsViewsComparison_measure_order", func(t *testing.T) { + testMetricsViewsComparison_measure_order(t, rt, instanceID) + }) + t.Run("testMetricsViewsComparison_measure_filters", func(t *testing.T) { + testMetricsViewsComparison_measure_filters(t, rt, instanceID) + }) +} + +func testMetricsViewsComparison_dim_order(t *testing.T, rt *runtime.Runtime, instanceID string) { + ctr := &queries.ColumnTimeRange{ + TableName: "ad_bids", + ColumnName: "timestamp", + } + err := ctr.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + diff := ctr.Result.Max.AsTime().Sub(ctr.Result.Min.AsTime()) + maxTime := ctr.Result.Min.AsTime().Add(diff / 2) + + q := &queries.MetricsViewComparison{ + MetricsViewName: "ad_bids_metrics", + DimensionName: "dom", + Measures: []*runtimev1.MetricsViewAggregationMeasure{ + { + Name: "measure_1", + }, + }, + TimeRange: &runtimev1.TimeRange{ + Start: ctr.Result.Min, + End: timestamppb.New(maxTime), + }, + Sort: []*runtimev1.MetricsViewComparisonSort{ + { + Name: "dom", + SortType: runtimev1.MetricsViewComparisonMeasureType_METRICS_VIEW_COMPARISON_MEASURE_TYPE_BASE_VALUE, + Desc: false, + }, + }, + Limit: 10, + SecurityClaims: testClaims(), + } + + err = q.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + require.NotEmpty(t, q.Result) +} + +func testMetricsViewsComparison_measure_order(t *testing.T, rt *runtime.Runtime, instanceID string) { + ctr := &queries.ColumnTimeRange{ + TableName: "ad_bids", + ColumnName: "timestamp", + } + err := ctr.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + diff := ctr.Result.Max.AsTime().Sub(ctr.Result.Min.AsTime()) + maxTime := ctr.Result.Min.AsTime().Add(diff / 2) + + q := &queries.MetricsViewComparison{ + MetricsViewName: "ad_bids_metrics", + DimensionName: "dom", + Measures: []*runtimev1.MetricsViewAggregationMeasure{ + { + Name: "measure_1", + }, + }, + TimeRange: &runtimev1.TimeRange{ + Start: ctr.Result.Min, + End: timestamppb.New(maxTime), + }, + Sort: []*runtimev1.MetricsViewComparisonSort{ + { + Name: "measure_1", + SortType: runtimev1.MetricsViewComparisonMeasureType_METRICS_VIEW_COMPARISON_MEASURE_TYPE_BASE_VALUE, + Desc: true, + }, + }, + Limit: 10, + SecurityClaims: testClaims(), + } + + err = q.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + require.NotEmpty(t, q.Result) +} + +func testMetricsViewsComparison_measure_filters(t *testing.T, rt *runtime.Runtime, instanceID string) { + ctr := &queries.ColumnTimeRange{ + TableName: "ad_bids", + ColumnName: "timestamp", + } + err := ctr.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + diff := ctr.Result.Max.AsTime().Sub(ctr.Result.Min.AsTime()) + maxTime := ctr.Result.Min.AsTime().Add(diff / 2) + + q := &queries.MetricsViewComparison{ + MetricsViewName: "ad_bids_metrics", + DimensionName: "dom", + Measures: []*runtimev1.MetricsViewAggregationMeasure{ + { + Name: "measure_1", + }, + }, + TimeRange: &runtimev1.TimeRange{ + Start: ctr.Result.Min, + End: timestamppb.New(maxTime), + }, + Having: &runtimev1.Expression{ + Expression: &runtimev1.Expression_Cond{ + Cond: &runtimev1.Condition{ + Op: runtimev1.Operation_OPERATION_GT, + Exprs: []*runtimev1.Expression{ + { + Expression: &runtimev1.Expression_Ident{ + Ident: "measure_1", + }, + }, + { + Expression: &runtimev1.Expression_Val{ + Val: structpb.NewNumberValue(1), + }, + }, + }, + }, + }, + }, + Sort: []*runtimev1.MetricsViewComparisonSort{ + { + Name: "measure_1", + SortType: runtimev1.MetricsViewComparisonMeasureType_METRICS_VIEW_COMPARISON_MEASURE_TYPE_BASE_VALUE, + Desc: true, + }, + }, + Limit: 10, + SecurityClaims: testClaims(), + } + + err = q.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + require.NotEmpty(t, q.Result) +} + func TestMetricsViewsComparison_dim_order_comparison_toplist_vs_general_toplist(t *testing.T) { rt, instanceID := testruntime.NewInstanceForProject(t, "ad_bids") diff --git a/runtime/queries/metricsview_toplist_test.go b/runtime/queries/metricsview_toplist_test.go index 6a47394b387..b18ad48f311 100644 --- a/runtime/queries/metricsview_toplist_test.go +++ b/runtime/queries/metricsview_toplist_test.go @@ -6,6 +6,7 @@ import ( "testing" runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime" "github.com/rilldata/rill/runtime/queries" "github.com/rilldata/rill/runtime/testruntime" "github.com/rilldata/rill/runtime/testruntime/testmode" @@ -42,6 +43,60 @@ func TestMetricsViewsToplistAgainstClickHouse(t *testing.T) { t.Run("TestMetricsViewsToplist_measure_filters", func(t *testing.T) { TestMetricsViewsToplist_measure_filters(t) }) } +func TestMetricsViewsToplistAgainstStarRocks(t *testing.T) { + testmode.Expensive(t) + + rt, instanceID := testruntime.NewInstanceWithStarRocksProject(t) + t.Run("testMetricsViewsToplist_measure_filters", func(t *testing.T) { + testMetricsViewsToplist_measure_filters(t, rt, instanceID) + }) +} + +func testMetricsViewsToplist_measure_filters(t *testing.T, rt *runtime.Runtime, instanceID string) { + ctr := &queries.ColumnTimeRange{ + DatabaseSchema: "test_db", + TableName: "ad_bids", + ColumnName: "timestamp", + } + err := ctr.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + diff := ctr.Result.Max.AsTime().Sub(ctr.Result.Min.AsTime()) + maxTime := ctr.Result.Min.AsTime().Add(diff / 2) + + lmt := int64(250) + q := &queries.MetricsViewToplist{ + MetricsViewName: "ad_bids_metrics", + DimensionName: "dom", + MeasureNames: []string{"measure_1"}, + TimeStart: ctr.Result.Min, + TimeEnd: timestamppb.New(maxTime), + Having: &runtimev1.Expression{ + Expression: &runtimev1.Expression_Cond{ + Cond: &runtimev1.Condition{ + Op: runtimev1.Operation_OPERATION_GT, + Exprs: []*runtimev1.Expression{ + { + Expression: &runtimev1.Expression_Ident{ + Ident: "measure_1", + }, + }, + { + Expression: &runtimev1.Expression_Val{ + Val: structpb.NewNumberValue(1), + }, + }, + }, + }, + }, + }, + Limit: &lmt, + SecurityClaims: testClaims(), + } + err = q.Resolve(context.Background(), rt, instanceID, 0) + require.NoError(t, err) + require.NotEmpty(t, q.Result) +} + func TestMetricsViewsToplist_measure_filters(t *testing.T) { rt, instanceID := testruntime.NewInstanceForProject(t, "ad_bids") diff --git a/runtime/testruntime/testdata/ad_bids_starrocks/dashboards/ad_bids_metrics.yaml b/runtime/testruntime/testdata/ad_bids_starrocks/dashboards/ad_bids_metrics.yaml new file mode 100644 index 00000000000..fe9c309a71b --- /dev/null +++ b/runtime/testruntime/testdata/ad_bids_starrocks/dashboards/ad_bids_metrics.yaml @@ -0,0 +1,45 @@ +table: ad_bids +database_schema: test_db +display_name: Ad bids +description: + +timeseries: timestamp +smallest_time_grain: "" + +dimensions: + - name: pub + display_name: Publisher + property: publisher + description: "" + - name: dom + display_name: Domain + property: domain + description: "" + - name: nolabel_pub + display_name: nolabel_pub + property: publisher + - name: space_label + display_name: Space Label + expression: "publisher" + - name: null_publisher + display_name: Null Publisher + expression: case when publisher is null then true else false end + +measures: + - display_name: "Number of bids" + expression: count(*) + description: "" + format_preset: humanize + - display_name: "Average bid price" + expression: avg(bid_price) + description: "" + format_preset: humanize + - name: m1 + expression: avg(bid_price) + description: "" + format_preset: humanize + - name: bid_price + expression: avg(bid_price) + description: "" + format_preset: humanize + diff --git a/runtime/testruntime/testdata/ad_bids_starrocks/metrics/ad_bids_metrics_view.yaml b/runtime/testruntime/testdata/ad_bids_starrocks/metrics/ad_bids_metrics_view.yaml new file mode 100644 index 00000000000..05af8715a1c --- /dev/null +++ b/runtime/testruntime/testdata/ad_bids_starrocks/metrics/ad_bids_metrics_view.yaml @@ -0,0 +1,30 @@ +# Metrics view YAML +# Reference documentation: https://docs.rilldata.com/reference/project-files/metrics-views + +version: 1 +type: metrics_view + +display_name: Ad Bids +table: ad_bids +database_schema: test_db +timeseries: timestamp + +dimensions: + - name: publisher + display_name: Publisher + column: publisher + - name: domain + display_name: Domain + column: domain + +measures: + - name: total_records + display_name: Total records + expression: COUNT(*) + description: "" + format_preset: humanize + - name: bid_price_sum + display_name: Sum of Bid Price + expression: SUM(bid_price) + description: "" + format_preset: humanize diff --git a/runtime/testruntime/testdata/ad_bids_starrocks/rill.yaml b/runtime/testruntime/testdata/ad_bids_starrocks/rill.yaml new file mode 100644 index 00000000000..ec73f556c13 --- /dev/null +++ b/runtime/testruntime/testdata/ad_bids_starrocks/rill.yaml @@ -0,0 +1 @@ +olap_connector: starrocks diff --git a/runtime/testruntime/testruntime.go b/runtime/testruntime/testruntime.go index 6b231ba6756..efedd24633c 100644 --- a/runtime/testruntime/testruntime.go +++ b/runtime/testruntime/testruntime.go @@ -17,6 +17,7 @@ import ( "github.com/rilldata/rill/runtime" "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/drivers/clickhouse/testclickhouse" + "github.com/rilldata/rill/runtime/drivers/starrocks/teststarrocks" "github.com/rilldata/rill/runtime/pkg/activity" "github.com/rilldata/rill/runtime/pkg/email" "github.com/rilldata/rill/runtime/storage" @@ -42,6 +43,7 @@ import ( _ "github.com/rilldata/rill/runtime/drivers/s3" _ "github.com/rilldata/rill/runtime/drivers/snowflake" _ "github.com/rilldata/rill/runtime/drivers/sqlite" + _ "github.com/rilldata/rill/runtime/drivers/starrocks" _ "github.com/rilldata/rill/runtime/reconcilers" ) @@ -419,3 +421,55 @@ func NewInstanceWithClickhouseProject(t TestingT, withCluster bool) (*runtime.Ru return rt, inst.ID } + +func NewInstanceWithStarRocksProject(t TestingT) (*runtime.Runtime, string) { + dsn := teststarrocks.StartWithData(t) + + rt := New(t, true) + ctx := t.Context() + + _, currentFile, _, _ := goruntime.Caller(0) + projectPath := filepath.Join(currentFile, "..", "testdata", "ad_bids_starrocks") + + inst := &drivers.Instance{ + Environment: "test", + OLAPConnector: "starrocks", + RepoConnector: "repo", + CatalogConnector: "catalog", + Connectors: []*runtimev1.Connector{ + { + Type: "file", + Name: "repo", + Config: Must(structpb.NewStruct(map[string]any{"dsn": projectPath})), + }, + { + Type: "starrocks", + Name: "starrocks", + Config: Must(structpb.NewStruct(map[string]any{"dsn": dsn, "database": "test_db"})), + }, + { + Type: "sqlite", + Name: "catalog", + // Setting a test-specific name ensures a unique connection when "cache=shared" is enabled. + // "cache=shared" is needed to prevent threading problems. + Config: Must(structpb.NewStruct(map[string]any{"dsn": fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())})), + }, + }, + Variables: map[string]string{"rill.stage_changes": "false"}, + } + + err := rt.CreateInstance(ctx, inst) + require.NoError(t, err) + require.NotEmpty(t, inst.ID) + + ctrl, err := rt.Controller(ctx, inst.ID) + require.NoError(t, err) + + _, err = ctrl.Get(ctx, runtime.GlobalProjectParserName, false) + require.NoError(t, err) + + err = ctrl.WaitUntilIdle(ctx, false) + require.NoError(t, err) + + return rt, inst.ID +}