diff --git a/mssql_test.go b/mssql_test.go index 45d59b1..520849a 100644 --- a/mssql_test.go +++ b/mssql_test.go @@ -1978,3 +1978,154 @@ func TestMSSQLQueryContextCancel(t *testing.T) { t.Fatalf("Unexpected error value: should=%s, is=%s", context.Canceled, err) } } + +// https://github.com/alexbrainman/odbc/issues/178 +// verify that inserting unicode text into an NVARCHAR column +// with a specified collation preserves the original characters when +// the parameter is sent from Go. The collation on NVARCHAR only affects +// comparisons and sort order, not storage. This test reproduces +// behavior originally reported in issue #178. +func TestMSSQLNVarcharCollationPreservesUnicode(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + type testStruct struct { + name string + collate string + poem string + } + + testCases := []testStruct{ + { + name: "zh_tw_big5", + collate: "Chinese_Taiwan_Bopomofo_CI_AS", + poem: "花間一壺酒,獨酌無相親。", + }, + { + name: "zh_cn_gbk", + collate: "Chinese_PRC_90_CI_AS", + poem: "花间一壶酒,独酌无相亲。", + }, + { + name: "ja_jp", + collate: "Japanese_CI_AS", + poem: "月日は百代の過客にして、行かふ年も又旅人也", + }, + { + name: "ko_kr", + collate: "Korean_Wansung_CI_AS", + poem: "꽃 사이 놓인 한 동이 술을 친한 이 없이 혼자 마시네.", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (poem nvarchar(200) collate "+tc.collate+")") + + // when the column is VARCHAR with a specific collation, converting + // the parameter on the server side using that collation ensures the + // NVARCHAR value we send is mapped into the correct code page. if + // we simply send the parameter as NVARCHAR and insert directly into a + // VARCHAR column the conversion uses the database default collation and + // we lose characters for languages not covered by that code page. + // + // cast the incoming nvarchar parameter to varchar and give the + // cast the desired collation. this way the conversion from + // Unicode to the column code page uses the correct collation, + // not the database default. + // + // column is NVARCHAR so we can insert directly without worrying + // about code page conversions; the collation on nvarchar only + // affects comparisons and sort order, not storage. + stmt, err := db.Prepare("insert into dbo.temp (poem) values (?)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if _, err := stmt.Exec(tc.poem); err != nil { + t.Fatal(err) + } + + var got string + if err := db.QueryRow("select cast(poem as nvarchar(max)) from dbo.temp").Scan(&got); err != nil { + t.Fatal(err) + } + if got != tc.poem { + t.Fatalf("poem mismatch, want=%v, got=%v", tc.poem, got) + } + + exec(t, db, "drop table dbo.temp") + }) + } +} + +// https://github.com/alexbrainman/odbc/issues/178 +// verify that an SQL statement which provides a parameter description +// causes BindValue to honor the described size rather than the length of +// the actual string. this test will fail if the size override lines in +// param.go are removed (they were added in #178). +func TestMSSQLDescribeParameterSize(t *testing.T) { + if testing.Short() { + t.Skip("skip in short mode") + } + db, _, err := mssqlConnectWithParams(newConnParams()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // first, make sure the driver actually describes the parameter and returns + // the expected size. we use the Raw method to gain access to the + // underlying *Conn and call PrepareODBCStmt directly. + conn, err2 := db.Conn(context.Background()) + if err2 != nil { + t.Fatalf("failed to get raw connection: %v", err2) + } + defer conn.Close() + err = conn.Raw(func(dc interface{}) error { + c := dc.(*Conn) + stmt, err := c.PrepareODBCStmt("select cast(? as varchar(5))") + if err != nil { + return err + } + if len(stmt.Parameters) != 1 { + return fmt.Errorf("unexpected param count %d", len(stmt.Parameters)) + } + p := stmt.Parameters[0] + if !p.isDescribed || p.Size != 5 { + return fmt.Errorf("expected described size 5, got isDescribed=%v size=%d", p.isDescribed, p.Size) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + // intercept SQLBindParameter to capture the size argument used by BindValue + var boundSize api.SQLULEN + orig := sqlBindParameter + sqlBindParameter = func(h api.SQLHSTMT, paramNumber api.SQLUSMALLINT, + inputOutputType api.SQLSMALLINT, cType api.SQLSMALLINT, sqlType api.SQLSMALLINT, + size api.SQLULEN, decimal api.SQLSMALLINT, + buffer api.SQLPOINTER, bufferLength api.SQLLEN, strLenOrIndPtr *api.SQLLEN, + ) api.SQLRETURN { + boundSize = size + return orig(h, paramNumber, inputOutputType, cType, sqlType, + size, decimal, buffer, bufferLength, strLenOrIndPtr) + } + defer func() { sqlBindParameter = orig }() + + // execute the statement once; the cast will raise a truncation error when + // the input exceeds five characters, but that is irrelevant to capturing + // the bound size. + _, _ = db.Exec("select cast(? as varchar(5))", "abcdef") + if boundSize != 5 { + t.Fatalf("expected bound size 5, got %d", boundSize) + } +} diff --git a/param.go b/param.go index e5dd8ba..ea279f8 100644 --- a/param.go +++ b/param.go @@ -32,6 +32,9 @@ func (p *Parameter) StoreStrLen_or_IndPtr(v api.SQLLEN) *api.SQLLEN { } +// exposeable hook for binding; tests override this to capture arguments. +var sqlBindParameter = api.SQLBindParameter + func (p *Parameter) BindValue(h api.SQLHSTMT, idx int, v driver.Value, conn *Conn) error { // TODO(brainman): Reuse memory for previously bound values. If memory // is reused, we, probably, do not need to call SQLBindParameter either. @@ -70,6 +73,9 @@ func (p *Parameter) BindValue(h api.SQLHSTMT, idx int, v driver.Value, conn *Con sqltype = api.SQL_WLONGVARCHAR case p.isDescribed: sqltype = p.SQLType + if p.Size != 0 { + size = p.Size + } case size <= 1: sqltype = api.SQL_WVARCHAR default: @@ -163,7 +169,7 @@ func (p *Parameter) BindValue(h api.SQLHSTMT, idx int, v driver.Value, conn *Con default: return fmt.Errorf("unsupported type %T", v) } - ret := api.SQLBindParameter(h, api.SQLUSMALLINT(idx+1), + ret := sqlBindParameter(h, api.SQLUSMALLINT(idx+1), api.SQL_PARAM_INPUT, ctype, sqltype, size, decimal, api.SQLPOINTER(buf), buflen, plen) if IsError(ret) {