84 changes: 55 additions & 29 deletions mssql_python/pybind/ddbc_bindings.cpp
@@ -2924,7 +2924,9 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(
FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, charEncoding));
} else {
uint64_t fetchBufferSize = columnSize + 1 /* null-termination */;
// Multiply by 4 because utf8 conversion by the driver might
// turn varchar(x) into up to 3*x (maybe 4*x?) bytes.
uint64_t fetchBufferSize = 4 * columnSize + 1 /* null-termination */;
Contributor:
Is there any specific reason for this number, or is it random?

ffelixg (Author) commented on Jan 9, 2026:

If we are fetching varchar(x), the driver gives us x in the columnSize variable. The driver (msodbcsql18), however, does not tell us which collation the varchar(x) column uses, and the non-Windows drivers convert the data to UTF-8 no matter which collation the column uses. If it is a UTF-8 collation, x already equals the maximum number of bytes, so we're fine. If varchar(x) uses any other single-byte collation, we must allocate a buffer large enough that the result of the conversion to UTF-8 fits. From my understanding, UTF-8 encodes most characters of single-byte collations in 2 bytes; some later additions, like "€" in latin1, may require 3 bytes.

I've also verified this by iterating over all possible SQL Server collations and converting the result to UTF-8: if there were a single-byte collation containing a character that requires more than 3 bytes, the following script would throw. (I also posted this in a Copilot review comment.)

SET NOCOUNT ON;
DECLARE @collation_name NVARCHAR(128);
DECLARE collation_cursor CURSOR FOR SELECT name FROM fn_helpcollations();
OPEN collation_cursor;
FETCH NEXT FROM collation_cursor INTO @collation_name;

WHILE @@FETCH_STATUS = 0
BEGIN
    IF @collation_name NOT LIKE N'%_UTF8%'
    BEGIN
        DECLARE @sql NVARCHAR(MAX) = N'
        declare @t1 table (a varchar(1) collate ' + @collation_name + N')
        declare @t2 table (a varchar(4) collate Latin1_General_100_CI_AI_SC_UTF8)
        insert into @t1 select top 256 cast(row_number() over(order by (select 1)) - 1 as binary(1)) a from sys.objects
        insert into @t2 select cast(a as nvarchar(10)) from @t1
        if (select max(datalength(a)) from @t2) > 3
            throw 50000, ''datalength too big'', 1
        ';
        EXEC sp_executesql @sql;
    END
    
    FETCH NEXT FROM collation_cursor INTO @collation_name;
END
CLOSE collation_cursor;
DEALLOCATE collation_cursor;

Therefore x needs to be multiplied by at least 3. The +1 for the null terminator is fine, since the null terminator takes 1 byte regardless of encoding/collation.
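To make the bound concrete, here is a minimal, standalone C++ sketch (illustrative only, not part of this patch; the varchar(100) figure is hypothetical). It shows a character that occupies one byte in a single-byte code page but three bytes in UTF-8, and the resulting buffer-size arithmetic:

#include <cstdint>
#include <cstdio>
#include <string>

int main() {
    // U+20AC (euro sign) is one byte (0x80) in Windows code page 1252,
    // but its UTF-8 encoding is three bytes: 0xE2 0x82 0xAC.
    const std::string euroUtf8 = "\xE2\x82\xAC";
    std::printf("1 single-byte-collation byte -> %zu UTF-8 bytes\n", euroUtf8.size());

    // Buffer sizing as done in the patch: a varchar(x) value can grow to
    // 3*x bytes after the driver's UTF-8 conversion, so scale by 4 to be
    // safe and add 1 for the null terminator.
    const uint64_t columnSize = 100; // hypothetical varchar(100)
    const uint64_t fetchBufferSize = 4 * columnSize + 1;
    std::printf("varchar(%llu) -> %llu-byte fetch buffer\n",
                (unsigned long long)columnSize,
                (unsigned long long)fetchBufferSize);
    return 0;
}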

std::vector<SQLCHAR> dataBuffer(fetchBufferSize);
SQLLEN dataLen;
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(),
@@ -2953,12 +2955,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(raw_bytes);
}
} else {
// Buffer too small, fallback to streaming
LOG("SQLGetData: CHAR column %d data truncated "
"(buffer_size=%zu), using streaming LOB",
i, dataBuffer.size());
row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false,
charEncoding));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for CHAR column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("SQLGetData: Column %d is NULL (CHAR)", i);
@@ -2995,7 +3000,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
case SQL_WCHAR:
case SQL_WVARCHAR:
case SQL_WLONGVARCHAR: {
if (columnSize == SQL_NO_TOTAL || columnSize > 4000) {
if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 4000) {
LOG("SQLGetData: Streaming LOB for column %d (SQL_C_WCHAR) "
"- columnSize=%lu",
i, (unsigned long)columnSize);
@@ -3024,12 +3029,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
"length=%lu for column %d",
(unsigned long)numCharsInData, i);
} else {
// Buffer too small, fallback to streaming
LOG("SQLGetData: NVARCHAR column %d data "
"truncated, using streaming LOB",
i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false,
"utf-16le"));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for WCHAR column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("SQLGetData: Column %d is NULL (NVARCHAR)", i);
@@ -3291,8 +3299,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(py::bytes(
reinterpret_cast<const char*>(dataBuffer.data()), dataLen));
} else {
row.append(
FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true, ""));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for BINARY column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
row.append(py::none());
@@ -3434,7 +3449,9 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// TODO: handle variable length data correctly. This logic wont
// suffice
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
// Multiply by 4 because utf8 conversion by the driver might
// turn varchar(x) into up to 3*x (maybe 4*x?) bytes.
uint64_t fetchBufferSize = 4 * columnSize + 1 /*null-terminator*/;
// TODO: For LONGVARCHAR/BINARY types, columnSize is returned as
// 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB.
// fetchSize=1 if columnSize>1GB. So we'll allocate a vector of
@@ -3580,8 +3597,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// Fetch rows in batches
// TODO: Move to anonymous namespace, since it is not used outside this file
SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched,
const std::vector<SQLUSMALLINT>& lobColumns) {
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) {
LOG("FetchBatchData: Fetching data in batches");
SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0);
if (ret == SQL_NO_DATA) {
@@ -3600,19 +3616,28 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
SQLULEN columnSize;
SQLULEN processedColumnSize;
uint64_t fetchBufferSize;
bool isLob;
};
std::vector<ColumnInfo> columnInfos(numCols);
for (SQLUSMALLINT col = 0; col < numCols; col++) {
const auto& columnMeta = columnNames[col].cast<py::dict>();
columnInfos[col].dataType = columnMeta["DataType"].cast<SQLSMALLINT>();
columnInfos[col].columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
columnInfos[col].isLob =
std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end();
columnInfos[col].processedColumnSize = columnInfos[col].columnSize;
HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize);
columnInfos[col].fetchBufferSize =
columnInfos[col].processedColumnSize + 1; // +1 for null terminator
switch (columnInfos[col].dataType) {
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
// Multiply by 4 because utf8 conversion by the driver might
// turn varchar(x) into up to 3*x (maybe 4*x?) bytes.
columnInfos[col].fetchBufferSize =
4 * columnInfos[col].processedColumnSize + 1; // +1 for null terminator
break;
default:
columnInfos[col].fetchBufferSize =
columnInfos[col].processedColumnSize + 1; // +1 for null terminator
break;
}
}

std::string decimalSeparator = GetDecimalSeparator(); // Cache decimal separator
@@ -3630,7 +3655,6 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
columnInfosExt[col].columnSize = columnInfos[col].columnSize;
columnInfosExt[col].processedColumnSize = columnInfos[col].processedColumnSize;
columnInfosExt[col].fetchBufferSize = columnInfos[col].fetchBufferSize;
columnInfosExt[col].isLob = columnInfos[col].isLob;

// Map data type to processor function (switch executed once per column,
// not per cell)
@@ -3739,7 +3763,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
// types) to just 10 (setup only) Note: Processor functions no
// longer need to check for NULL since we do it above
if (columnProcessors[col - 1] != nullptr) {
columnProcessors[col - 1](row, buffers, &columnInfosExt[col - 1], col, i, hStmt);
columnProcessors[col - 1](row, buffers, &columnInfosExt[col - 1], col, i);
continue;
}

@@ -3916,7 +3940,9 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
rowSize += columnSize;
// Multiply by 4 because utf8 conversion by the driver might
// turn varchar(x) into up to 3*x (maybe 4*x?) bytes.
rowSize += 4 * columnSize;
break;
case SQL_SS_XML:
case SQL_WCHAR:
@@ -4070,7 +4096,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("FetchMany_wrap: Error when fetching data - SQLRETURN=%d", ret);
return ret;
@@ -4203,7 +4229,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,

while (ret != SQL_NO_DATA) {
ret =
FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("FetchAll_wrap: Error when fetching data - SQLRETURN=%d", ret);
return ret;