diff --git a/driver/columns.go b/driver/columns.go index 42b59a9..15e9cd6 100644 --- a/driver/columns.go +++ b/driver/columns.go @@ -14,6 +14,21 @@ import ( // BigQueryTimeLayout represents the TIME format: HH:MM:SS[.SSSSSS] const BigQueryTimeLayout = "15:04:05.000000" +// DateTime wraps time.Time to preserve BigQuery DATETIME +// (timezone-naive) semantics. When passed as a query parameter, +// the driver converts it back to civil.DateTime so BigQuery maps +// it to DATETIME, not TIMESTAMP. +type DateTime struct { + time.Time +} + +// Value implements driver.Valuer. Returns the underlying time.Time +// so that database/sql.Scan can transparently assign DateTime to +// *time.Time targets. +func (dt DateTime) Value() (driver.Value, error) { + return dt.Time, nil +} + type bigQuerySchema interface { ColumnNames() []string ConvertColumnValue(index int, value bigquery.Value) (driver.Value, error) @@ -64,14 +79,24 @@ func (column bigQueryColumn) ConvertValue(value bigquery.Value) (driver.Value, e } } - // Handle DATETIME type conversion from civil.DateTime to time.Time + // Handle DATETIME type conversion from civil.DateTime to + // DateTime wrapper. The wrapper preserves time.Time semantics + // for consumers while enabling correct round-trip back to + // BigQuery DATETIME via buildParameter. if column.FieldType == bigquery.DateTimeFieldType { if value != nil { if civilDateTime, ok := value.(civil.DateTime); ok { - converted := time.Date(civilDateTime.Date.Year, civilDateTime.Date.Month, civilDateTime.Date.Day, - civilDateTime.Time.Hour, civilDateTime.Time.Minute, civilDateTime.Time.Second, - civilDateTime.Time.Nanosecond, time.UTC) - return converted, nil + t := time.Date( + civilDateTime.Date.Year, + civilDateTime.Date.Month, + civilDateTime.Date.Day, + civilDateTime.Time.Hour, + civilDateTime.Time.Minute, + civilDateTime.Time.Second, + civilDateTime.Time.Nanosecond, + time.UTC, + ) + return DateTime{t}, nil } } } diff --git a/driver/columns_test.go b/driver/columns_test.go new file mode 100644 index 0000000..d2ec1ae --- /dev/null +++ b/driver/columns_test.go @@ -0,0 +1,163 @@ +package driver + +import ( + "database/sql/driver" + "testing" + "time" + + "cloud.google.com/go/bigquery" + "cloud.google.com/go/civil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type ColumnsTestSuite struct { + suite.Suite +} + +func TestColumnsTestSuite(t *testing.T) { + suite.Run(t, new(ColumnsTestSuite)) +} + +// TestDateTime_ImplementsValuer is a compile-time check that +// DateTime satisfies the driver.Valuer interface. +func (s *ColumnsTestSuite) TestDateTime_ImplementsValuer() { + var _ driver.Valuer = DateTime{} +} + +func (s *ColumnsTestSuite) TestDateTimeValue() { + testCases := map[string]struct { + input DateTime + expected time.Time + }{ + "basic_datetime": { + input: DateTime{ + time.Date(2024, 6, 15, 10, 30, 45, 0, time.UTC), + }, + expected: time.Date(2024, 6, 15, 10, 30, 45, 0, time.UTC), + }, + "zero_time": { + input: DateTime{time.Time{}}, + expected: time.Time{}, + }, + "with_nanoseconds": { + input: DateTime{ + time.Date(2024, 1, 1, 0, 0, 0, 123456000, time.UTC), + }, + expected: time.Date(2024, 1, 1, 0, 0, 0, 123456000, time.UTC), + }, + } + + for name, tc := range testCases { + s.Run(name, func() { + val, err := tc.input.Value() + assert.NoError(s.T(), err) + assert.Equal(s.T(), tc.expected, val) + }) + } +} + +func (s *ColumnsTestSuite) TestConvertValue() { + testCases := map[string]struct { + column bigQueryColumn + input bigquery.Value + expectedType string // "DateTime", "time.Time", "string", "nil" + expected interface{} + }{ + "datetime_column": { + column: bigQueryColumn{ + Name: "created_at", + FieldType: bigquery.DateTimeFieldType, + }, + input: civil.DateTime{ + Date: civil.Date{Year: 2024, Month: 6, Day: 15}, + Time: civil.Time{Hour: 10, Minute: 30, Second: 45}, + }, + expectedType: "DateTime", + expected: DateTime{ + time.Date(2024, 6, 15, 10, 30, 45, 0, time.UTC), + }, + }, + "datetime_column_with_nanoseconds": { + column: bigQueryColumn{ + Name: "event_time", + FieldType: bigquery.DateTimeFieldType, + }, + input: civil.DateTime{ + Date: civil.Date{Year: 2024, Month: 1, Day: 1}, + Time: civil.Time{ + Hour: 12, Minute: 0, Second: 0, + Nanosecond: 500000000, + }, + }, + expectedType: "DateTime", + expected: DateTime{ + time.Date(2024, 1, 1, 12, 0, 0, 500000000, time.UTC), + }, + }, + "datetime_column_nil_value": { + column: bigQueryColumn{ + Name: "created_at", + FieldType: bigquery.DateTimeFieldType, + }, + input: nil, + expectedType: "nil", + expected: nil, + }, + "date_column": { + column: bigQueryColumn{ + Name: "event_date", + FieldType: bigquery.DateFieldType, + }, + input: civil.Date{Year: 2024, Month: 3, Day: 1}, + expectedType: "time.Time", + expected: time.Date(2024, 3, 1, 0, 0, 0, 0, time.UTC), + }, + "time_column": { + column: bigQueryColumn{ + Name: "event_time", + FieldType: bigquery.TimeFieldType, + }, + input: civil.Time{Hour: 14, Minute: 30, Second: 0}, + expectedType: "string", + expected: "14:30:00.000000", + }, + "time_column_with_nanoseconds": { + column: bigQueryColumn{ + Name: "event_time", + FieldType: bigquery.TimeFieldType, + }, + input: civil.Time{ + Hour: 9, Minute: 5, Second: 3, + Nanosecond: 123000000, + }, + expectedType: "string", + expected: "09:05:03.123000", + }, + } + + for name, tc := range testCases { + s.Run(name, func() { + result, err := tc.column.ConvertValue(tc.input) + assert.NoError(s.T(), err) + + switch tc.expectedType { + case "DateTime": + dt, ok := result.(DateTime) + assert.True(s.T(), ok, "expected DateTime, got %T", result) + assert.Equal(s.T(), tc.expected, dt) + case "time.Time": + ts, ok := result.(time.Time) + assert.True(s.T(), ok, "expected time.Time, got %T", result) + assert.Equal(s.T(), tc.expected, ts) + case "string": + str, ok := result.(string) + assert.True(s.T(), ok, "expected string, got %T", result) + assert.Equal(s.T(), tc.expected, str) + case "nil": + assert.Nil(s.T(), result) + } + }) + } +} diff --git a/driver/statement.go b/driver/statement.go index fe4515b..33d164d 100644 --- a/driver/statement.go +++ b/driver/statement.go @@ -6,6 +6,7 @@ import ( "errors" "cloud.google.com/go/bigquery" + "cloud.google.com/go/civil" "github.com/sirupsen/logrus" "github.com/scaledata/bigquery/adaptor" @@ -172,6 +173,14 @@ func (statement bigQueryStatement) buildParameters(args []driver.Value) ([]bigqu } func buildParameter(arg driver.Value, parameters []bigquery.QueryParameter) []bigquery.QueryParameter { + // Convert DateTime back to civil.DateTime so the BigQuery + // client maps it to DATETIME parameter type (not TIMESTAMP). + if dt, ok := arg.(DateTime); ok { + return append(parameters, bigquery.QueryParameter{ + Value: civil.DateTimeOf(dt.Time), + }) + } + namedValue, ok := arg.(driver.NamedValue) if ok { return buildParameterFromNamedValue(namedValue, parameters) @@ -187,14 +196,21 @@ func buildParameter(arg driver.Value, parameters []bigquery.QueryParameter) []bi func buildParameterFromNamedValue(namedValue driver.NamedValue, parameters []bigquery.QueryParameter) []bigquery.QueryParameter { logrus.Debugf("-param:%s=%s", namedValue.Name, namedValue.Value) + // Convert DateTime back to civil.DateTime for DATETIME + // parameter type. + value := namedValue.Value + if dt, ok := value.(DateTime); ok { + value = civil.DateTimeOf(dt.Time) + } + if namedValue.Name == "" { return append(parameters, bigquery.QueryParameter{ - Value: namedValue.Value, + Value: value, }) } else { return append(parameters, bigquery.QueryParameter{ Name: namedValue.Name, - Value: namedValue.Value, + Value: value, }) } } diff --git a/driver/statement_test.go b/driver/statement_test.go new file mode 100644 index 0000000..bdf0e79 --- /dev/null +++ b/driver/statement_test.go @@ -0,0 +1,127 @@ +package driver + +import ( + "database/sql/driver" + "testing" + "time" + + "cloud.google.com/go/bigquery" + "cloud.google.com/go/civil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type StatementTestSuite struct { + suite.Suite +} + +func TestStatementTestSuite(t *testing.T) { + suite.Run(t, new(StatementTestSuite)) +} + +func (s *StatementTestSuite) TestBuildParameter() { + ts := time.Date(2024, 6, 15, 10, 30, 45, 0, time.UTC) + + testCases := map[string]struct { + arg driver.Value + existing []bigquery.QueryParameter + expectedLen int + expectedValue interface{} + expectedName string + }{ + "datetime_converts_to_civil": { + arg: DateTime{ts}, + expectedLen: 1, + expectedValue: civil.DateTimeOf(ts), + }, + "regular_string_passes_through": { + arg: "hello", + expectedLen: 1, + expectedValue: "hello", + }, + "regular_int64_passes_through": { + arg: int64(42), + expectedLen: 1, + expectedValue: int64(42), + }, + "named_value_with_name": { + arg: driver.NamedValue{ + Name: "param1", + Value: int64(42), + }, + expectedLen: 1, + expectedValue: int64(42), + expectedName: "param1", + }, + "accumulates_with_existing": { + arg: "second", + existing: []bigquery.QueryParameter{{Value: "first"}}, + expectedLen: 2, + expectedValue: "second", + }, + } + + for name, tc := range testCases { + s.Run(name, func() { + params := buildParameter(tc.arg, tc.existing) + assert.Equal(s.T(), tc.expectedLen, len(params)) + + last := params[len(params)-1] + assert.Equal(s.T(), tc.expectedValue, last.Value) + if tc.expectedName != "" { + assert.Equal(s.T(), tc.expectedName, last.Name) + } + }) + } +} + +func (s *StatementTestSuite) TestBuildParameterFromNamedValue() { + ts := time.Date(2024, 6, 15, 10, 30, 45, 0, time.UTC) + + testCases := map[string]struct { + namedValue driver.NamedValue + expectedValue interface{} + expectedName string + }{ + "datetime_named_converts_to_civil": { + namedValue: driver.NamedValue{ + Name: "created_at", + Value: DateTime{ts}, + }, + expectedValue: civil.DateTimeOf(ts), + expectedName: "created_at", + }, + "datetime_unnamed_converts_to_civil": { + namedValue: driver.NamedValue{ + Value: DateTime{ts}, + }, + expectedValue: civil.DateTimeOf(ts), + }, + "regular_value_named": { + namedValue: driver.NamedValue{ + Name: "count", + Value: int64(99), + }, + expectedValue: int64(99), + expectedName: "count", + }, + "regular_value_unnamed": { + namedValue: driver.NamedValue{ + Value: "hello", + }, + expectedValue: "hello", + }, + } + + for name, tc := range testCases { + s.Run(name, func() { + params := buildParameterFromNamedValue( + tc.namedValue, nil, + ) + assert.Equal(s.T(), 1, len(params)) + assert.Equal(s.T(), tc.expectedValue, params[0].Value) + assert.Equal(s.T(), tc.expectedName, params[0].Name) + }) + } +}