From a7e220f51c05c19eedbad6b49e2be3595f9ae730 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 21:06:07 +0100 Subject: [PATCH 01/15] Add arrow fetch support --- mssql_python/cursor.py | 86 +++ mssql_python/pybind/ddbc_bindings.cpp | 1028 +++++++++++++++++++++++++ requirements.txt | 1 + tests/test_004_cursor.py | 246 ++++++ 4 files changed, 1361 insertions(+) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 4e8815bd..cbd66cb3 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -31,7 +31,10 @@ from mssql_python import get_settings if TYPE_CHECKING: + import pyarrow # type: ignore from mssql_python.connection import Connection +else: + pyarrow = None # Constants for string handling MAX_INLINE_CHAR: int = ( @@ -2317,6 +2320,89 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e + def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch": + """ + Fetch a single pyarrow Record Batch of the specified size from the + query result set. + + Args: + batch_size: Maximum number of rows to fetch in the Record Batch. + + Returns: + A pyarrow RecordBatch object containing up to batch_size rows. + """ + self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() + + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for arrow_batch(). Please install pyarrow." + ) from e + + capsules = [] + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0)) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules) + return batch + + def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": + """ + Fetch the entire result as a pyarrow Table. + + Args: + batch_size: Size of the Record Batches which make up the Table. + + Returns: + A pyarrow Table containing all remaining rows from the result set. + """ + try: + import pyarrow + except ImportError as e: + raise ImportError("pyarrow is required for arrow(). Please install pyarrow.") from e + + batches: list["pyarrow.RecordBatch"] = [] + while True: + batch = self.arrow_batch(batch_size) + if batch.num_rows < batch_size or batch_size <= 0: + if not batches or batch.num_rows > 0: + batches.append(batch) + break + batches.append(batch) + return pyarrow.Table.from_batches(batches, schema=batches[0].schema) + + def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": + """ + Fetch the result as a pyarrow RecordBatchReader, which yields Record + Batches of the specified size until the current result set is + exhausted. + + Args: + batch_size: Size of the Record Batches produced by the reader. + + Returns: + A pyarrow RecordBatchReader for the result set. + """ + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for arrow_reader(). Please install pyarrow." + ) from e + + # Fetch schema without advancing cursor + schema_batch = self.arrow_batch(0) + schema = schema_batch.schema + + def batch_generator(): + while (batch := self.arrow_batch(batch_size)).num_rows > 0: + yield batch + + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) + def nextset(self) -> Union[bool, None]: """ Skip to the next available result set. 
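For orientation before the native code below, the three new cursor methods compose as follows. A minimal usage sketch, assuming pyarrow is installed, that `mssql_python.connect()` takes a connection string, and a hypothetical table `t` — the chained `execute(...).arrow(...)` style is the one the patch's own tests use:

    import mssql_python

    conn = mssql_python.connect(conn_str)  # conn_str is assumed to be defined elsewhere
    cur = conn.cursor()

    # One bounded RecordBatch: at most 8192 rows, fewer once the result set runs out.
    batch = cur.execute("select * from t").arrow_batch(8192)

    # Entire result set materialized as a Table assembled from fixed-size batches.
    table = cur.execute("select * from t").arrow(batch_size=8192)

    # Streaming: the reader yields batches until the result set is exhausted.
    reader = cur.execute("select * from t").arrow_reader(batch_size=4096)
    for b in reader:
        ...  # consume each pyarrow.RecordBatch incrementally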
diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 63696f91..a9c385d1 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -157,6 +157,83 @@ struct NumericData { } }; +// Struct to hold data buffers and indicators for each column +struct ColumnBuffersArrow { + std::vector> uint8; + std::vector> int16; + std::vector> int32; + std::vector> int64; + std::vector> float64; + std::vector> bit; + std::vector> var; + std::vector> date; + std::vector> ts_micro; + std::vector> time_second; + std::vector> decimal; + + std::vector> valid; + std::vector> var_data; + + ColumnBuffersArrow(SQLSMALLINT numCols) + : + uint8(numCols), + int16(numCols), + int32(numCols), + int64(numCols), + float64(numCols), + bit(numCols), + var(numCols), + date(numCols), + ts_micro(numCols), + time_second(numCols), + decimal(numCols), + + valid(numCols), + var_data(numCols) {} +}; + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + //------------------------------------------------------------------------------------------------- // Function pointer initialization //------------------------------------------------------------------------------------------------- @@ -4087,6 +4164,956 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } +// GetDataVar - Progressively fetches variable-length column data using SQLGetData. +// +// Calls SQLGetData repeatedly, reallocating the buffer as needed, until all data is retrieved. +// Handles both fixed-size and unknown-size (SQL_NO_TOTAL) responses from the driver. 
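+// When the driver reports the remaining length exactly, the buffer is grown once to fit;
+// under SQL_NO_TOTAL it is doubled instead, and the next read resumes at the position of
+// the previous null terminator so no bytes are lost between calls.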
+// +// @param hStmt: Statement handle +// @param colNumber: 1-based column index +// @param cType: SQL C data type (SQL_C_CHAR, SQL_C_WCHAR, or SQL_C_BINARY) +// @param dataVec: Reference to vector that will hold the fetched data (will be resized as needed) +// @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) +// +// @return SQLRETURN: SQL_SUCCESS on success, or error code on failure +template +SQLRETURN GetDataVar(SQLHSTMT hStmt, + SQLUSMALLINT colNumber, + SQLSMALLINT cType, + std::vector& dataVec, + SQLLEN* indicator) { + if (!SQLGetData_ptr) { + ThrowStdException("SQLGetData function not loaded"); + } + + size_t start = 0; + size_t end = 0; + + // Determine null terminator size based on data type + size_t sizeNullTerminator = 0; + switch (cType) { + case SQL_C_WCHAR: + case SQL_C_CHAR: + sizeNullTerminator = 1; + break; + case SQL_C_BINARY: + sizeNullTerminator = 0; + break; + default: + ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); + } + + // Ensure initial buffer has space for at least the null terminator + if (dataVec.size() < sizeNullTerminator) { + dataVec.resize(sizeNullTerminator); + } + + while (true) { + SQLLEN localInd = 0; + SQLRETURN ret = SQLGetData_ptr( + hStmt, + colNumber, + cType, + reinterpret_cast(dataVec.data() + start), + sizeof(T) * (dataVec.size() - start), // Available buffer size from start position + &localInd + ); + + // Handle NULL data + if (localInd == SQL_NULL_DATA) { + *indicator = SQL_NULL_DATA; + return SQL_SUCCESS; + } + + // Check for errors (excluding SQL_SUCCESS_WITH_INFO which means more data available) + if (ret == SQL_ERROR || ret == SQL_INVALID_HANDLE) { + return ret; + } + + // SQL_SUCCESS or SQL_NO_DATA means we got all the data + if (ret == SQL_SUCCESS || ret == SQL_NO_DATA) { + if (localInd >= 0) { + *indicator = static_cast(start) * sizeof(T) + localInd; + } else { + *indicator = localInd; // Preserve SQL_NO_TOTAL or other negative values + } + break; + } + + // SQL_SUCCESS_WITH_INFO means buffer was too small, need to continue fetching + if (ret == SQL_SUCCESS_WITH_INFO) { + // Determine how much more space we need + if (localInd < 0) { + // SQL_NO_TOTAL: driver doesn't know total size, double the buffer + end = dataVec.size() * 2; + } else { + // Driver returned total size: allocate exactly what we need + assert(localInd % sizeof(T) == 0); + end = start + static_cast(localInd) / sizeof(T) + sizeNullTerminator; + } + + // The next read starts where the null terminator would have been placed + start = dataVec.size() - sizeNullTerminator; + + // Resize buffer for next iteration + dataVec.resize(end); + } else { + // Unexpected return code + return ret; + } + } + + return SQL_SUCCESS; +} + +void ArrowSchema_release(struct ArrowSchema* schema) { + assert (schema != nullptr); + assert (schema->release != nullptr); + schema->release = nullptr; + delete[] schema->name; + for (int i = 0; i < schema->n_children; i++) { + assert (schema->children != nullptr); + if (schema->children[i]) { + schema->children[i]->release(schema->children[i]); + delete schema->children[i]; + } + } + delete[] schema->children; + delete[] schema->format; +} + +void ArrowArray_release(struct ArrowArray* array) { + assert (array != nullptr); + assert (array->release != nullptr); + array->release = nullptr; + + uint32_t buffers_freed = 0; + uint32_t current_buffer = 0; + while (buffers_freed < array->n_buffers) { + if (array->buffers[current_buffer]) { + 
free((void*)array->buffers[current_buffer]); + buffers_freed++; + } + current_buffer++; + assert (current_buffer <= 3); + } + delete[] array->buffers; + + for (int i = 0; i < array->n_children; i++) { + assert (array->children != nullptr); + assert (array->children[i] != nullptr); + array->children[i]->release(array->children[i]); + delete array->children[i]; + } + delete[] array->children; + +} + +int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { + // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + std::tm tm_date = {}; + tm_date.tm_year = year - 1900; // tm_year is years since 1900 + tm_date.tm_mon = month - 1; // tm_mon is 0-11 + tm_date.tm_mday = day; + + std::time_t time_since_epoch = std::mktime(&tm_date); + if (time_since_epoch == -1) { + LOG("Failed to convert SQL_DATE_STRUCT to time_t"); + ThrowStdException("Date conversion error"); + } + // Calculate days since epoch + return time_since_epoch / 86400; +} + +SQLRETURN FetchArrowBatch_wrap( + SqlHandlePtr StatementHandle, + py::list& capsules, + ssize_t arrowBatchSize +) { + ssize_t fetchSize = arrowBatchSize; + SQLRETURN ret; + SQLHSTMT hStmt = StatementHandle->get(); + // Retrieve column count + SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); + if (numCols <= 0) { + ThrowStdException("No active result set. Cannot fetch Arrow batch."); + } + + // Retrieve column metadata + py::list columnNames; + ret = SQLDescribeCol_wrap(StatementHandle, columnNames); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to get column descriptions"); + return ret; + } + + bool hasLobColumns = false; + + std::vector dataTypes(numCols); + std::vector columnSizes(numCols); + std::vector columnNullable(numCols); + std::vector> columnFormats(numCols); + std::vector> columnNamesCStr(numCols); + + ColumnBuffersArrow buffersArrow(numCols); + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + SQLSMALLINT nullable = colMeta["Nullable"].cast(); + dataTypes[i] = dataType; + columnSizes[i] = columnSize; + columnNullable[i] = (nullable != SQL_NO_NULLS); + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + hasLobColumns = true; + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } + } + + std::string columnName = colMeta["ColumnName"].cast(); + size_t nameLen = columnName.length() + 1; + columnNamesCStr[i] = std::make_unique(nameLen); + std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); + + const char* format = nullptr; + switch(dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + format = "u"; + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + buffersArrow.var[i][0] = 0; + break; + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + format = "z"; + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + buffersArrow.var[i][0] = 0; + break; + case SQL_TINYINT: + format = "C"; + 
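+                // Arrow format "C" denotes uint8; SQL Server's tinyint is unsigned 0-255, so the mapping is exact.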
buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SMALLINT: + format = "s"; + buffersArrow.int16[i] = std::make_unique(arrowBatchSize); + break; + case SQL_INTEGER: + format = "i"; + buffersArrow.int32[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIGINT: + format = "l"; + buffersArrow.int64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + format = "g"; + buffersArrow.float64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + std::ostringstream formatStream; + formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); + std::string formatStr = formatStream.str(); + size_t formatLen = formatStr.length() + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); + format = columnFormats[i].get(); + buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + format = "tsu:"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SS_TIMESTAMPOFFSET: + format = "tsu:+00:00"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TYPE_DATE: + format = "tdD"; + buffersArrow.date[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + format = "tts"; + buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIT: + format = "b"; + buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); + break; + default: + std::wstring columnName = colMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << (i + 1); + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + + // Store format string if not already stored (for non-decimal types) + if (!columnFormats[i]) { + size_t formatLen = std::strlen(format) + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), format, formatLen); + } + + buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); + // Initialize validity bitmap to all valid + std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); + } + + if (fetchSize > 1) { + // An overly large fetch size doesn't seem to help performance + SQLSMALLINT searchStart = 64; + if (arrowBatchSize < 64) { + searchStart = static_cast(arrowBatchSize); + } + for (SQLSMALLINT maybeNewSize = searchStart; maybeNewSize >= 1; maybeNewSize -= 1) { + if (arrowBatchSize % maybeNewSize == 0) { + fetchSize = maybeNewSize; + break; + } + } + } + + // Initialize column buffers + ColumnBuffers buffers(numCols, fetchSize); + + if (!hasLobColumns && fetchSize > 0) { + // Bind columns + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when binding columns"); + return ret; + } + } + + SQLULEN numRowsFetched; + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); + + + size_t idxRowArrow = 0; + // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches + // start with a fresh batch + assert(fetchSize == 0 || arrowBatchSize % fetchSize 
== 0); + assert(fetchSize <= arrowBatchSize); + + while (idxRowArrow < arrowBatchSize) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) { + ret = SQL_SUCCESS; // Normal completion + break; + } + if (!SQL_SUCCEEDED(ret)) { + LOG("Error while fetching rows in batches"); + return ret; + } + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. + // It'll be populated by SQLFetch + assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); + for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { + for (SQLUSMALLINT col = 1; col <= numCols; col++) { + auto dataType = dataTypes[col - 1]; + auto columnSize = columnSizes[col - 1]; + + if (hasLobColumns) { + assert(idxRowSql == 0 && "GetData only works one row at a time"); + + switch(dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + GetDataVar( + hStmt, + col, + SQL_C_BINARY, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_CHAR, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_WCHAR, + buffers.wcharBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_INTEGER: { + buffers.intBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SLONG, + buffers.intBuffers[col - 1].data(), + sizeof(SQLINTEGER), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SMALLINT: { + buffers.smallIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SSHORT, + buffers.smallIntBuffers[col - 1].data(), + sizeof(SQLSMALLINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TINYINT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TINYINT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_BIT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_REAL: { + buffers.realBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_FLOAT, + buffers.realBuffers[col - 1].data(), + sizeof(SQLREAL), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + SQLGetData_ptr( + hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + buffers.doubleBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_DOUBLE, + buffers.doubleBuffers[col - 1].data(), + sizeof(SQLDOUBLE), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + buffers.timestampBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[col - 1].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIGINT: { + buffers.bigIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SBIGINT, + buffers.bigIntBuffers[col - 1].data(), + sizeof(SQLBIGINT), + buffers.indicators[col 
- 1].data() + ); + break; + } + case SQL_TYPE_DATE: { + buffers.dateBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_DATE, + buffers.dateBuffers[col - 1].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + buffers.timeBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIME, + buffers.timeBuffers[col - 1].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_GUID: { + buffers.guidBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_GUID, + buffers.guidBuffers[col - 1].data(), + sizeof(SQLGUID), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + buffers.datetimeoffsetBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[col - 1].data(), + sizeof(DateTimeOffset), + buffers.indicators[col - 1].data() + ); + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG("SQLGetData: %s", errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + + SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + + if (dataLen == SQL_NULL_DATA) { + // Mark as null in validity bitmap + size_t bytePos = idxRowArrow / 8; + size_t bitPos = idxRowArrow % 8; + buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + break; + default: + break; + } + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. 
Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + ThrowStdException("Unexpected negative data length."); + } + + switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &buffersArrow.var_data[col - 1]; +#if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLenConverted; +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); +#endif + break; + } + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } + + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + guidValue.Data1, + guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + break; + } + case SQL_TINYINT: + buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + break; + 
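+                        // The remaining fixed-width types copy element-for-element from the
+                        // bound ODBC buffers; temporal and decimal values are converted below.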
case SQL_SMALLINT: + buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + break; + case SQL_INTEGER: + buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + break; + case SQL_BIGINT: + buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + buffersArrow.float64[col - 1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; + int sign = 1; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[col - 1][idx]; + if (digitChar == '-') { + sign = -1; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + } + buffersArrow.decimal[col - 1][idxRowArrow] = decimalValue * sign; + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[col - 1][idxRowSql].year, + buffers.dateBuffers[col - 1][idxRowSql].month, + buffers.dateBuffers[col - 1][idxRowSql].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. + // To fully support SQL_SS_TIME2, the corresponding c-type should be used. 
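+                        // Until then, any fractional seconds are lost in the whole-second conversion.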
+ const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; + buffersArrow.time_second[col - 1][idxRowArrow] = + static_cast(timeValue.hour) * 3600 + + static_cast(timeValue.minute) * 60 + + static_cast(timeValue.second); + break; + } + case SQL_BIT: { + // SQL_BIT is stored as a single bit in Arrow's bitmap format + // Get the boolean value from the buffer + bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + + // Set the bit in the Arrow bitmap + size_t byteIndex = idxRowArrow / 8; + size_t bitIndex = idxRowArrow % 8; + + if (bitValue) { + // Set bit to 1 + buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + } else { + // Clear bit to 0 + buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + } + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + idxRowArrow++; + } + } + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + + // Transfer ownerhip of buffers to Arrow structures + // Exceptions beyond this point would cause memory leaks + auto batch_children = new ArrowSchema* [numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto arrow_schema = new ArrowSchema({ + .format = columnFormats[i].release(), + .name = columnNamesCStr[i].release(), + .flags = columnNullable[i] ? 2 : 0, // ARROW_FLAG_NULLABLE + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; + } + + auto arrow_schema_batch = new ArrowSchema({ + .format = strdup("+s"), + .name = strdup(""), + .n_children = numCols, + .children = batch_children, + .release = ArrowSchema_release, + }); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { + auto arrow_schema = static_cast(ptr); + if (arrow_schema->release) { + arrow_schema->release(arrow_schema); + } + delete arrow_schema; + }); + capsules.append(caps); + + auto arrow_array_batch_buffers = new const void* [3]; + memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); + auto arrow_array_batch = new ArrowArray({ + .length = static_cast(idxRowArrow), + .n_buffers = 1, + .n_children = numCols, + .buffers = arrow_array_batch_buffers, + .children = new ArrowArray* [numCols], + .release = ArrowArray_release, + }); + // Necessary dummy buffer + arrow_array_batch->buffers[1] = new int[1]; + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto dataType = dataTypes[col]; + auto arrow_array_col_buffers = new const void* [3]; + memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); + // Allocate new memory and copy the data + switch (dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + assert(buffersArrow.var[col][0] == 0); + // length of string at index i is the difference between values at i and i+1 + // so total length is value at index idxRowArrow + auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; + uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; + std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); + arrow_array_col_buffers[2] = dataBuffer; + arrow_array_col_buffers[1] = 
buffersArrow.var[col].release(); + } + break; + case SQL_TINYINT: + arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); + break; + case SQL_SMALLINT: + arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); + break; + case SQL_INTEGER: + arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); + break; + case SQL_BIGINT: + arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_SS_TIMESTAMPOFFSET: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_TYPE_DATE: + arrow_array_col_buffers[1] = buffersArrow.date[col].release(); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); + break; + case SQL_BIT: + arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); + break; + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << (col + 1) + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + + auto arrow_array_col = new ArrowArray({ + .length = static_cast(idxRowArrow), + .null_count = 0, + .offset = 0, + .n_buffers = arrow_array_col_buffers[2] ? 3 : 2, + .n_children = 0, + .buffers = arrow_array_col_buffers, + .children = nullptr, + .release = ArrowArray_release, + }); + + arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); + arrow_array_batch->children[col] = arrow_array_col; + } + + capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { + auto arrow_array = static_cast(ptr); + if (arrow_array->release) { + arrow_array->release(arrow_array); + } + delete arrow_array; + })); + + return ret; +} + + // FetchAll_wrap - Fetches all rows of data from the result set. 
// // @param StatementHandle: Handle to the statement from which data is to be @@ -4406,6 +5433,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set", py::arg("StatementHandle"), py::arg("rows"), py::arg("charEncoding") = "utf-8", py::arg("wcharEncoding") = "utf-16le"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, diff --git a/requirements.txt b/requirements.txt index 0951f7d0..4cd60771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pytest-cov coverage unittest-xml-reporting psutil +pyarrow # Build dependencies pybind11 diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index a54cffda..c2058c63 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -18,6 +18,11 @@ import re from conftest import is_azure_sql_connection +try: + import pyarrow as pa +except ImportError: + pa = None + # Setup test table TEST_TABLE = """ @@ -15018,3 +15023,244 @@ def test_close(db_connection): pytest.fail(f"Cursor close test failed: {e}") finally: cursor = db_connection.cursor() + + +def get_arrow_test_data(include_lobs: bool, batch_length: int): + arrow_test_data = [ + (pa.uint8(), "tinyint", [1, 2, None, 4, 5, 0, 2**8 - 1]), + (pa.int16(), "smallint", [1, 2, None, 4, 5, -(2**15), 2**15 - 1]), + (pa.int32(), "int", [1, 2, None, 4, 5, 0, -(2**31), 2**31 - 1]), + (pa.int64(), "bigint", [1, 2, None, 4, 5, 0, -(2**63), 2**63 - 1]), + (pa.float64(), "float", [1.0, 2.5, None, 4.25, 5.125]), + ( + pa.decimal128(precision=10, scale=2), + "decimal(10, 2)", + [ + decimal.Decimal("1.23"), + None, + decimal.Decimal("0.25"), + decimal.Decimal("-99999999.99"), + decimal.Decimal("99999999.99"), + ], + ), + ( + pa.decimal128(precision=38, scale=10), + "decimal(38, 10)", + [ + decimal.Decimal("1.1234567890"), + None, + decimal.Decimal("0"), + decimal.Decimal("1.0000000001"), + decimal.Decimal("-9999999999999999999999999999.9999999999"), + decimal.Decimal("9999999999999999999999999999.9999999999"), + ], + ), + (pa.bool_(), "bit", [True, None, False]), + (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]), + (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]), + ( + pa.time32("s"), + "time(0)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.time32("s"), + "time(7)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(0)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(3)", + [datetime(2025, 1, 1, 12, 0, 5, 123_000), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(6)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(7)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2145, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(2)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2145, 12, 31, 23, 59, 59, 0)], + ), + ] + + if include_lobs: + 
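+        # (max) types force FetchArrowBatch_wrap onto the row-by-row SQLGetData path.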
arrow_test_data += [ + (pa.string(), "nvarchar(max)", ["hey", None, "ho"]), + (pa.string(), "varchar(max)", ["hey", None, "ho"]), + (pa.binary(), "varbinary(max)", [b"hey", None, b"ho"]), + ] + + for ix in range(len(arrow_test_data)): + while True: + T, sql_type, vals = arrow_test_data[ix] + if len(vals) >= batch_length: + arrow_test_data[ix] = (T, sql_type, vals[:batch_length]) + break + arrow_test_data[ix] = (T, sql_type, vals + vals) + + return arrow_test_data + + +def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_length=500): + cols = [] + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + rows = [] + for value in values: + if type(value) is bool: + value = int(value) + if type(value) is bytes: + value = value.decode() + if value is None: + value = "null" + else: + value = f"'{value}'" + rows.append(f"col_{i_col} = cast({value} as {sql_type})") + cols.append(rows) + + selects = [] + for row in zip(*cols): + selects.append(f"select {', '.join(col for col in row)}") + full_query = "\nunion all\n".join(selects) + ret = cursor.execute(full_query).arrow_batch(fetch_length) + for i_col, col in enumerate(ret): + for i_row, (v_expected, v_actual) in enumerate( + zip(arrow_test_data[i_col][2][:fetch_length], col.to_pylist(), strict=True) + ): + assert ( + v_expected == v_actual + ), f"Mismatch in column {i_col}, row {i_row}: expected {v_expected}, got {v_actual}" + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + field = ret.schema.field(i_col) + assert ( + field.name == f"col_{i_col}" + ), f"Column {i_col} name mismatch: expected col_{i_col}, got {field.name}" + assert field.type.equals( + pa_type + ), f"Column {i_col} type mismatch: expected {pa_type}, got {field.type}" + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_lob_wide(cursor: mssql_python.Cursor): + "Take the SQLGetData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_nolob_wide(cursor: mssql_python.Cursor): + "Test the SQLBindData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=False, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_single_column(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data]) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_fetch(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch of length 0." 
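+    # A zero-length fetch returns no rows but still carries the schema;
+    # arrow_reader() relies on this to obtain the schema without consuming rows.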
+ arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data], fetch_length=0) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table_batchsize_negative(cursor: mssql_python.Cursor): + tbl = cursor.execute("select 1 a").arrow(batch_size=-42) + assert type(tbl) is pa.Table + assert tbl.num_rows == 0 + assert tbl.num_columns == 1 + assert cursor.fetchone()[0] == 1 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_result_set(cursor: mssql_python.Cursor): + "Test fetching from an empty result set." + cursor.execute("select 1 where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select 1, cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 2 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_no_result_set(cursor: mssql_python.Cursor): + "Test fetching when there is no result set." + cursor.execute("declare @a int") + with pytest.raises(Exception, match=".*No active result set.*"): + cursor.arrow_batch(10) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_datetimeoffset(cursor: mssql_python.Cursor): + "Datetimeoffset converts correctly to utc" + cursor.execute( + "declare @dt datetimeoffset(0) = '2345-02-03 12:34:56 +00:00';\n" + "select @dt, @dt at time zone 'Pacific Standard Time';\n" + ) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + for col in batch.columns: + assert pa.types.is_timestamp(col.type) + assert col.type.tz == "+00:00", col.type.tz + assert col.to_pylist() == [ + datetime(2345, 2, 3, 12, 34, 56, tzinfo=timezone.utc), + ] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_schema_nullable(cursor: mssql_python.Cursor): + "Test that the schema is nullable." 
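+    # Nullability comes from SQLDescribeCol: the literal 1 is reported NOT NULL,
+    # while the bare NULL column is reported nullable (ARROW_FLAG_NULLABLE).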
+ cursor.execute("select 1 a, null b") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + assert not batch.schema.field(0).nullable + assert batch.schema.field(1).nullable + assert batch.schema.field(0).name == "a" + assert batch.schema.field(1).name == "b" + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table(cursor: mssql_python.Cursor): + tbl = cursor.execute("select top 11 1 a from sys.objects").arrow(batch_size=5) + assert type(tbl) is pa.Table + assert tbl.num_rows == 11 + assert tbl.num_columns == 1 + assert [len(b) for b in tbl.to_batches()] == [5, 5, 1] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_reader(cursor: mssql_python.Cursor): + reader = cursor.execute("select top 11 1 a from sys.objects").arrow_reader(batch_size=4) + assert type(reader) is pa.RecordBatchReader + batches = list(reader) + assert [len(b) for b in batches] == [4, 4, 3] + assert sum(len(b) for b in batches) == 11 From 68482fb22adb21e8ce8f2530dce5045ceaae2fe4 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:32:04 +0100 Subject: [PATCH 02/15] Copilot suggestion: Fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index a9c385d1..6771c82d 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4967,7 +4967,7 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); - // Transfer ownerhip of buffers to Arrow structures + // Transfer ownership of buffers to Arrow structures // Exceptions beyond this point would cause memory leaks auto batch_children = new ArrowSchema* [numCols]; for (SQLSMALLINT i = 0; i < numCols; i++) { From 36455784b226d669576a5ef739d6fe236147a11e Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:34:37 +0100 Subject: [PATCH 03/15] Copilot suggestion: Fix missing buffer resize Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 6771c82d..b6b28b5f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4812,6 +4812,9 @@ SQLRETURN FetchArrowBatch_wrap( #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + while (target_vec->size() < start + utf8str.size()) { + target_vec->resize(target_vec->size() * 2); + } std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); #endif From 9c8c3e8152723818e1a5f552f2c502d87db41856 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:35:57 +0100 Subject: [PATCH 04/15] Copilot suggestion: Initialize bool value buffer Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 1 + 1 file changed, 1 insertion(+) diff --git 
a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b6b28b5f..7c56406c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4462,6 +4462,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: format = "b"; buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); + std::memset(buffersArrow.bit[i].get(), 0, (arrowBatchSize + 7) / 8); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); From 5267a339b731785a5a4d1cd5cb4c2c34328509d9 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:39:48 +0100 Subject: [PATCH 05/15] Add test for long data --- tests/test_004_cursor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index c2058c63..deece655 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -15264,3 +15264,14 @@ def test_arrow_reader(cursor: mssql_python.Cursor): batches = list(reader) assert [len(b) for b in batches] == [4, 4, 3] assert sum(len(b) for b in batches) == 11 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_long_string(cursor: mssql_python.Cursor): + "Make sure resizing the data buffer works" + long_string = "A" * 100000 # 100k characters + cursor.execute("select cast(? as nvarchar(max))", (long_string,)) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 1 + assert batch.column(0).to_pylist() == [long_string] From b81f245fb2edaffb676cec7b15ad17c2a79a2276 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:45:32 +0100 Subject: [PATCH 06/15] Copilot suggestion: Uppercase uuids --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- tests/test_004_cursor.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 7c56406c..26766b5f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4837,7 +4837,7 @@ SQLRETURN FetchArrowBatch_wrap( // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast(&target_vec->data()[start]), 37, - "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", guidValue.Data1, guidValue.Data2, guidValue.Data3, diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index deece655..f8a5daa9 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -15059,6 +15059,7 @@ def get_arrow_test_data(include_lobs: bool, batch_length: int): (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]), (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "uniqueidentifier", ["58185E0D-3A91-44D8-BC46-7107217E0A6D", None]), (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]), ( pa.time32("s"), From 532672c15d6337a26574961f33772dabaf75a5ca Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:52:35 +0100 Subject: [PATCH 07/15] Copilot suggestion: use new for batch schema format/name Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 26766b5f..f1002afb 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4985,8 +4985,8 @@ SQLRETURN FetchArrowBatch_wrap( } auto arrow_schema_batch = new ArrowSchema({ - .format = strdup("+s"), - .name = strdup(""), + .format = []{ char* f = new char[3]; std::strcpy(f, "+s"); return f; }(), + .name = []{ char* n = new char[1]; n[0] = '\0'; return n; }(), .n_children = numCols, .children = batch_children, .release = ArrowSchema_release, From 590fdf68747a27425a41d471fc46d1f22b85cb14 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 2 Dec 2025 00:46:30 +0100 Subject: [PATCH 08/15] Replace free calls in release callbacks with unique pointers tracked by private_data --- mssql_python/pybind/ddbc_bindings.cpp | 231 +++++++++++++++++--------- 1 file changed, 151 insertions(+), 80 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f1002afb..8a88afd2 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -192,6 +192,28 @@ struct ColumnBuffersArrow { var_data(numCols) {} }; +struct ArrowArrayPrivateData { + std::unique_ptr buffer_uint8; + std::unique_ptr buffer_int16; + std::unique_ptr buffer_int32; + std::unique_ptr buffer_int64; + std::unique_ptr buffer_float64; + std::unique_ptr buffer_bit; + std::unique_ptr buffer_var; + std::unique_ptr buffer_date; + std::unique_ptr buffer_ts_micro; + std::unique_ptr buffer_time_second; + std::unique_ptr<__int128_t[]> buffer_decimal; + + std::unique_ptr buffer_valid; + std::unique_ptr buffer_var_data; +}; + +struct ArrowSchemaPrivateData { + std::unique_ptr name; + std::unique_ptr format; +}; + #ifndef ARROW_C_DATA_INTERFACE #define ARROW_C_DATA_INTERFACE @@ -212,7 +234,8 @@ struct ArrowSchema { // Release callback void (*release)(struct ArrowSchema*); // Opaque producer-specific data - void* private_data; + // Only our child-arrays will set this, so we can give it the correct type + ArrowSchemaPrivateData* private_data; }; struct ArrowArray { @@ -229,7 +252,8 @@ struct ArrowArray { // Release callback void (*release)(struct ArrowArray*); // Opaque producer-specific data - void* private_data; + // Only our child-arrays will set this, so we can give it the correct type + ArrowArrayPrivateData* private_data; }; #endif // ARROW_C_DATA_INTERFACE @@ -4182,10 +4206,6 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, SQLSMALLINT cType, std::vector& dataVec, SQLLEN* indicator) { - if (!SQLGetData_ptr) { - ThrowStdException("SQLGetData function not loaded"); - } - size_t start = 0; size_t end = 0; @@ -4266,49 +4286,6 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, return SQL_SUCCESS; } -void ArrowSchema_release(struct ArrowSchema* schema) { - assert (schema != nullptr); - assert (schema->release != nullptr); - schema->release = nullptr; - delete[] schema->name; - for (int i = 0; i < schema->n_children; i++) { - assert (schema->children != nullptr); - if (schema->children[i]) { - schema->children[i]->release(schema->children[i]); - delete schema->children[i]; - } - } - delete[] schema->children; - delete[] schema->format; -} - -void ArrowArray_release(struct ArrowArray* array) { - assert (array != nullptr); - assert (array->release != nullptr); - array->release = nullptr; - - uint32_t buffers_freed = 0; - uint32_t current_buffer = 0; - while (buffers_freed < array->n_buffers) { - if (array->buffers[current_buffer]) 
{ - free((void*)array->buffers[current_buffer]); - buffers_freed++; - } - current_buffer++; - assert (current_buffer <= 3); - } - delete[] array->buffers; - - for (int i = 0; i < array->n_children; i++) { - assert (array->children != nullptr); - assert (array->children[i] != nullptr); - array->children[i]->release(array->children[i]); - delete array->children[i]; - } - delete[] array->children; - -} - int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) std::tm tm_date = {}; @@ -4321,6 +4298,8 @@ int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) LOG("Failed to convert SQL_DATE_STRUCT to time_t"); ThrowStdException("Date conversion error"); } + // Sanity check against timezone issues. Since we only provide the date, this has to be true + assert(time_since_epoch % 86400 == 0); // Calculate days since epoch return time_since_epoch / 86400; } @@ -4380,7 +4359,7 @@ SQLRETURN FetchArrowBatch_wrap( columnNamesCStr[i] = std::make_unique(nameLen); std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); - const char* format = nullptr; + std::string format = ""; switch(dataType) { case SQL_CHAR: case SQL_VARCHAR: @@ -4476,9 +4455,9 @@ SQLRETURN FetchArrowBatch_wrap( // Store format string if not already stored (for non-decimal types) if (!columnFormats[i]) { - size_t formatLen = std::strlen(format) + 1; + size_t formatLen = format.length() + 1; columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), format, formatLen); + std::memcpy(columnFormats[i].get(), format.c_str(), formatLen); } buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); @@ -4973,24 +4952,64 @@ SQLRETURN FetchArrowBatch_wrap( // Transfer ownership of buffers to Arrow structures // Exceptions beyond this point would cause memory leaks - auto batch_children = new ArrowSchema* [numCols]; + + auto batch_children = new ArrowSchema*[numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto col_private_data = new ArrowSchemaPrivateData(); + col_private_data->format = std::move(columnFormats[i]); + col_private_data->name = std::move(columnNamesCStr[i]); + auto arrow_schema = new ArrowSchema({ - .format = columnFormats[i].release(), - .name = columnNamesCStr[i].release(), - .flags = columnNullable[i] ? 2 : 0, // ARROW_FLAG_NULLABLE - .release = ArrowSchema_release, + .format = col_private_data->format.get(), + .name = col_private_data->name.get(), + .metadata = nullptr, + .flags = static_cast(columnNullable[i] ? 
ARROW_FLAG_NULLABLE : 0), + .n_children = 0, + .children = nullptr, + .dictionary = nullptr, + .release = [](ArrowSchema* schema) { + assert(schema != nullptr); + assert(schema->release != nullptr); + assert(schema->private_data != nullptr); + assert(schema->children == nullptr && schema->n_children == 0); + delete schema->private_data; // Frees format and name + schema->release = nullptr; + }, + .private_data = col_private_data, }); batch_children[i] = arrow_schema; } auto arrow_schema_batch = new ArrowSchema({ - .format = []{ char* f = new char[3]; std::strcpy(f, "+s"); return f; }(), - .name = []{ char* n = new char[1]; n[0] = '\0'; return n; }(), + .format = "+s", + .name = "", + .metadata = nullptr, + .flags = 0, .n_children = numCols, .children = batch_children, - .release = ArrowSchema_release, + .dictionary = nullptr, + .release = [](ArrowSchema* schema) { + // format and name are string literals, no need to free + assert(schema != nullptr); + assert(schema->release != nullptr); + assert(schema->private_data == nullptr); + assert(schema->children != nullptr); + assert(schema->n_children > 0); + for (int64_t i = 0; i < schema->n_children; ++i) { + if (schema->children[i]) { + if (schema->children[i]->release) { + schema->children[i]->release(schema->children[i]); + } + delete schema->children[i]; + } + } + delete[] schema->children; + schema->release = nullptr; + }, + .private_data = nullptr, }); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { auto arrow_schema = static_cast(ptr); if (arrow_schema->release) { @@ -5002,21 +5021,48 @@ SQLRETURN FetchArrowBatch_wrap( auto arrow_array_batch_buffers = new const void* [3]; memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); + // Necessary dummy buffer, pyarrow will error without it + arrow_array_batch_buffers[1] = new uint8_t[1]{0}; auto arrow_array_batch = new ArrowArray({ .length = static_cast(idxRowArrow), + // only the non null dummy buffer counts .n_buffers = 1, .n_children = numCols, .buffers = arrow_array_batch_buffers, .children = new ArrowArray* [numCols], - .release = ArrowArray_release, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data == nullptr); + assert(array->release != nullptr); + assert(array->children != nullptr); + assert(array->n_children > 0); + for (int64_t i = 0; i < array->n_children; ++i) { + if (array->children[i]) { + if (array->children[i]->release) { + array->children[i]->release(array->children[i]); + } + delete array->children[i]; + } + } + delete[] array->children; + assert(array->buffers != nullptr); + assert(array->n_buffers == 1); + assert(array->buffers[0] == nullptr); + assert(array->buffers[1] != nullptr); + assert(array->buffers[2] == nullptr); + // Delete dummy buffer + delete[] const_cast(static_cast(array->buffers[1])); + + delete[] array->buffers; + array->release = nullptr; + }, }); - // Necessary dummy buffer - arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { auto dataType = dataTypes[col]; auto arrow_array_col_buffers = new const void* [3]; memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); + auto private_data = new ArrowArrayPrivateData(); // Allocate new memory and copy the data switch (dataType) { case SQL_CHAR: @@ -5034,52 +5080,65 @@ SQLRETURN FetchArrowBatch_wrap( // length of string at index i is the difference between values at i and i+1 // so total length is value at index idxRowArrow auto data_buf_len_total = 
buffersArrow.var[col][idxRowArrow]; - uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; - std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); - arrow_array_col_buffers[2] = dataBuffer; - arrow_array_col_buffers[1] = buffersArrow.var[col].release(); + auto dataBuffer = std::make_unique(data_buf_len_total); + std::memcpy(dataBuffer.get(), buffersArrow.var_data[col].data(), data_buf_len_total); + private_data->buffer_var_data = std::move(dataBuffer); + arrow_array_col_buffers[2] = private_data->buffer_var_data.get(); + private_data->buffer_var = std::move(buffersArrow.var[col]); + arrow_array_col_buffers[1] = private_data->buffer_var.get(); } break; case SQL_TINYINT: - arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); + private_data->buffer_uint8 = std::move(buffersArrow.uint8[col]); + arrow_array_col_buffers[1] = private_data->buffer_uint8.get(); break; case SQL_SMALLINT: - arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); + private_data->buffer_int16 = std::move(buffersArrow.int16[col]); + arrow_array_col_buffers[1] = private_data->buffer_int16.get(); break; case SQL_INTEGER: - arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); + private_data->buffer_int32 = std::move(buffersArrow.int32[col]); + arrow_array_col_buffers[1] = private_data->buffer_int32.get(); break; case SQL_BIGINT: - arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); + private_data->buffer_int64 = std::move(buffersArrow.int64[col]); + arrow_array_col_buffers[1] = private_data->buffer_int64.get(); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); + private_data->buffer_float64 = std::move(buffersArrow.float64[col]); + arrow_array_col_buffers[1] = private_data->buffer_float64.get(); break; case SQL_DECIMAL: case SQL_NUMERIC: { - arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); + private_data->buffer_decimal = std::move(buffersArrow.decimal[col]); + arrow_array_col_buffers[1] = private_data->buffer_decimal.get(); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: - arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); + arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); break; case SQL_SS_TIMESTAMPOFFSET: - arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); + arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); break; case SQL_TYPE_DATE: - arrow_array_col_buffers[1] = buffersArrow.date[col].release(); + private_data->buffer_date = std::move(buffersArrow.date[col]); + arrow_array_col_buffers[1] = private_data->buffer_date.get(); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); + private_data->buffer_time_second = std::move(buffersArrow.time_second[col]); + arrow_array_col_buffers[1] = private_data->buffer_time_second.get(); break; case SQL_BIT: - arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); + private_data->buffer_bit = std::move(buffersArrow.bit[col]); + arrow_array_col_buffers[1] = private_data->buffer_bit.get(); break; default: { std::ostringstream errorString; @@ -5099,10 +5158,22 @@ SQLRETURN FetchArrowBatch_wrap( .n_children = 0, .buffers = arrow_array_col_buffers, .children = nullptr, - .release = ArrowArray_release, + .release = 
[](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data != nullptr); + assert(array->release != nullptr); + assert(array->children == nullptr); + assert(array->n_children == 0); + delete array->private_data; // Frees all buffer entries + assert(array->buffers != nullptr); + delete[] array->buffers; + array->release = nullptr; + }, + .private_data = private_data, }); - arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); + private_data->buffer_valid = std::move(buffersArrow.valid[col]); + arrow_array_col->buffers[0] = private_data->buffer_valid.get(); arrow_array_batch->children[col] = arrow_array_col; } From ad188bd9fecdaba9450aedd5d488e2f4231cd950 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:08:10 +0100 Subject: [PATCH 09/15] Eliminate potential memory leaks on allocation failures when transferring ownership to arrow --- mssql_python/pybind/ddbc_bindings.cpp | 640 ++++++++++++-------------- 1 file changed, 289 insertions(+), 351 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 8a88afd2..c706bb2b 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -157,56 +157,31 @@ struct NumericData { } }; -// Struct to hold data buffers and indicators for each column -struct ColumnBuffersArrow { - std::vector> uint8; - std::vector> int16; - std::vector> int32; - std::vector> int64; - std::vector> float64; - std::vector> bit; - std::vector> var; - std::vector> date; - std::vector> ts_micro; - std::vector> time_second; - std::vector> decimal; - - std::vector> valid; - std::vector> var_data; - - ColumnBuffersArrow(SQLSMALLINT numCols) - : - uint8(numCols), - int16(numCols), - int32(numCols), - int64(numCols), - float64(numCols), - bit(numCols), - var(numCols), - date(numCols), - ts_micro(numCols), - time_second(numCols), - decimal(numCols), - - valid(numCols), - var_data(numCols) {} -}; - struct ArrowArrayPrivateData { - std::unique_ptr buffer_uint8; - std::unique_ptr buffer_int16; - std::unique_ptr buffer_int32; - std::unique_ptr buffer_int64; - std::unique_ptr buffer_float64; - std::unique_ptr buffer_bit; - std::unique_ptr buffer_var; - std::unique_ptr buffer_date; - std::unique_ptr buffer_ts_micro; - std::unique_ptr buffer_time_second; - std::unique_ptr<__int128_t[]> buffer_decimal; - - std::unique_ptr buffer_valid; - std::unique_ptr buffer_var_data; + std::unique_ptr valid; + + std::unique_ptr uint8Val; + std::unique_ptr int16Val; + std::unique_ptr int32Val; + std::unique_ptr int64Val; + std::unique_ptr float64Val; + std::unique_ptr bitVal; + std::unique_ptr varVal; + std::unique_ptr dateVal; + std::unique_ptr tsMicroVal; + std::unique_ptr timeSecondVal; + std::unique_ptr<__int128_t[]> decimalVal; + + std::vector varData; + + // first buffer will be the valid bitmap + // second buffer will be one of the value buffers above + // third buffer will be the varData buffer for variable length types + std::array buffers; + + // Points to one of the typed *Val buffers above. Since the buffer pointers + // don't change, this can be set once during batch initialization. 
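This private_data arrangement is the standard Arrow C Data Interface ownership contract: the producer parks every allocation behind one pointer, the consumer (here pyarrow) calls release exactly once when it is done, and release marks itself spent by nulling the callback. A minimal sketch of the pattern with hypothetical names, reusing the ArrowArray declaration from earlier in this patch rather than this module's per-type buffers:

    #include <array>
    #include <cstdint>
    #include <memory>

    // Hypothetical producer-side state: everything the exported array points into.
    struct ProducerState {
        std::unique_ptr<uint8_t[]> validity;  // validity bitmap, 1 bit per row
        std::unique_ptr<int64_t[]> values;    // value buffer
        std::array<const void*, 2> buffers;   // handed to ArrowArray::buffers
    };

    void ExportInt64Array(ArrowArray* out, std::unique_ptr<uint8_t[]> validity,
                          std::unique_ptr<int64_t[]> values, int64_t length) {
        auto state = std::make_unique<ProducerState>();
        state->validity = std::move(validity);
        state->values = std::move(values);
        state->buffers = {state->validity.get(), state->values.get()};
        *out = ArrowArray{
            .length = length,
            .null_count = -1,  // unknown; let the consumer compute it
            .offset = 0,
            .n_buffers = 2,
            .n_children = 0,
            .buffers = state->buffers.data(),
            .children = nullptr,
            .dictionary = nullptr,
            .release = [](ArrowArray* array) {
                // One delete frees the bitmap, the values and the buffers array.
                delete static_cast<ProducerState*>(array->private_data);
                array->release = nullptr;  // mark as released, per the spec
            },
            .private_data = state.release(),
        };
    }

The payoff of this shape is that there is exactly one owner at any moment: the unique_ptrs until export, then the ArrowArray via private_data, so no exception path can double-free or leak.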
+ void* ptrValueBuffer; }; struct ArrowSchemaPrivateData { @@ -4331,15 +4306,20 @@ SQLRETURN FetchArrowBatch_wrap( std::vector dataTypes(numCols); std::vector columnSizes(numCols); std::vector columnNullable(numCols); - std::vector> columnFormats(numCols); - std::vector> columnNamesCStr(numCols); + std::vector columnVarLen(numCols, false); - ColumnBuffersArrow buffersArrow(numCols); + std::vector> arrowArrayPrivateData(numCols); + std::vector> arrowSchemaPrivateData(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayPrivateData[i] = std::make_unique(); + auto& arrowColumnProducer = arrowArrayPrivateData[i]; + arrowSchemaPrivateData[i] = std::make_unique(); + auto colMeta = columnNames[i].cast(); SQLSMALLINT dataType = colMeta["DataType"].cast(); SQLULEN columnSize = colMeta["ColumnSize"].cast(); SQLSMALLINT nullable = colMeta["Nullable"].cast(); + dataTypes[i] = dataType; columnSizes[i] = columnSize; columnNullable[i] = (nullable != SQL_NO_NULLS); @@ -4356,8 +4336,8 @@ SQLRETURN FetchArrowBatch_wrap( std::string columnName = colMeta["ColumnName"].cast(); size_t nameLen = columnName.length() + 1; - columnNamesCStr[i] = std::make_unique(nameLen); - std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); + arrowSchemaPrivateData[i]->name = std::make_unique(nameLen); + std::memcpy(arrowSchemaPrivateData[i]->name.get(), columnName.c_str(), nameLen); std::string format = ""; switch(dataType) { @@ -4370,41 +4350,50 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: case SQL_GUID: format = "u"; - buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); - buffersArrow.var_data[i].resize(arrowBatchSize * 42); + arrowColumnProducer->varVal = std::make_unique(arrowBatchSize + 1); + arrowColumnProducer->varData.resize(arrowBatchSize * 42); + columnVarLen[i] = true; // start at offset 0 - buffersArrow.var[i][0] = 0; + arrowColumnProducer->varVal[0] = 0; + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get(); break; case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: format = "z"; - buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); - buffersArrow.var_data[i].resize(arrowBatchSize * 42); + arrowColumnProducer->varVal = std::make_unique(arrowBatchSize + 1); + arrowColumnProducer->varData.resize(arrowBatchSize * 42); + columnVarLen[i] = true; // start at offset 0 - buffersArrow.var[i][0] = 0; + arrowColumnProducer->varVal[0] = 0; + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get(); break; case SQL_TINYINT: format = "C"; - buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->uint8Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->uint8Val.get(); break; case SQL_SMALLINT: format = "s"; - buffersArrow.int16[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int16Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int16Val.get(); break; case SQL_INTEGER: format = "i"; - buffersArrow.int32[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int32Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int32Val.get(); break; case SQL_BIGINT: format = "l"; - buffersArrow.int64[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int64Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int64Val.get(); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: format = "g"; - 
buffersArrow.float64[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->float64Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->float64Val.get(); break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -4412,36 +4401,42 @@ SQLRETURN FetchArrowBatch_wrap( formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); std::string formatStr = formatStream.str(); size_t formatLen = formatStr.length() + 1; - columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); - format = columnFormats[i].get(); - buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); + arrowSchemaPrivateData[i]->format = std::make_unique(formatLen); + std::memcpy(arrowSchemaPrivateData[i]->format.get(), formatStr.c_str(), formatLen); + format = arrowSchemaPrivateData[i]->format.get(); + arrowColumnProducer->decimalVal = std::make_unique<__int128_t[]>(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->decimalVal.get(); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: format = "tsu:"; - buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->tsMicroVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get(); break; case SQL_SS_TIMESTAMPOFFSET: format = "tsu:+00:00"; - buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->tsMicroVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get(); break; case SQL_TYPE_DATE: format = "tdD"; - buffersArrow.date[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->dateVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->dateVal.get(); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: format = "tts"; - buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->timeSecondVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->timeSecondVal.get(); break; case SQL_BIT: format = "b"; - buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); - std::memset(buffersArrow.bit[i].get(), 0, (arrowBatchSize + 7) / 8); + arrowColumnProducer->bitVal = std::make_unique((arrowBatchSize + 7) / 8); + std::memset(arrowColumnProducer->bitVal.get(), 0, (arrowBatchSize + 7) / 8); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->bitVal.get(); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); @@ -4453,16 +4448,17 @@ SQLRETURN FetchArrowBatch_wrap( break; } - // Store format string if not already stored (for non-decimal types) - if (!columnFormats[i]) { + // Store format string if not already stored. + // For non-decimal types, format is now a static string. 
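The single characters assigned to format in this switch are Arrow C Data Interface format strings: "u" utf8, "z" binary, "C" uint8, "s" int16, "i" int32, "l" int64, "g" float64, "b" bool, "tdD" date32 in days, "tts" time32 in seconds, "tsu:" timestamp in microseconds (zone suffix optional), and "d:precision,scale" decimal128. The mapping lives inline in the switch because decimals need per-column precision and scale; for the fixed formats alone, a toy lookup (hypothetical helper, not part of the module) would be:

    #include <sql.h>
    #include <sqlext.h>
    #include <string>

    // Toy mapping from a few ODBC SQL type codes to Arrow format strings.
    std::string ArrowFormatFor(SQLSMALLINT sqlType) {
        switch (sqlType) {
            case SQL_VARCHAR:        return "u";    // utf8, int32 offsets
            case SQL_VARBINARY:      return "z";    // binary, int32 offsets
            case SQL_INTEGER:        return "i";    // int32
            case SQL_BIGINT:         return "l";    // int64
            case SQL_DOUBLE:         return "g";    // float64
            case SQL_TYPE_DATE:      return "tdD";  // date32 [days]
            case SQL_TYPE_TIMESTAMP: return "tsu:"; // timestamp [us], no zone
            case SQL_BIT:            return "b";    // bool
            default:                 return "";     // not covered in this sketch
        }
    }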
+ if (!arrowSchemaPrivateData[i]->format) { size_t formatLen = format.length() + 1; - columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), format.c_str(), formatLen); + arrowSchemaPrivateData[i]->format = std::make_unique(formatLen); + std::memcpy(arrowSchemaPrivateData[i]->format.get(), format.c_str(), formatLen); } - buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); + arrowColumnProducer->valid = std::make_unique((arrowBatchSize + 7) / 8); // Initialize validity bitmap to all valid - std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); + std::memset(arrowColumnProducer->valid.get(), 0xFF, (arrowBatchSize + 7) / 8); } if (fetchSize > 1) { @@ -4495,7 +4491,6 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - size_t idxRowArrow = 0; // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches // start with a fresh batch @@ -4516,9 +4511,10 @@ SQLRETURN FetchArrowBatch_wrap( // It'll be populated by SQLFetch assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto dataType = dataTypes[col - 1]; - auto columnSize = columnSizes[col - 1]; + for (SQLUSMALLINT idxCol = 0; idxCol < numCols; idxCol++) { + auto& arrowColumnProducer = arrowArrayPrivateData[idxCol]; + auto dataType = dataTypes[idxCol]; + auto columnSize = columnSizes[idxCol]; if (hasLobColumns) { assert(idxRowSql == 0 && "GetData only works one row at a time"); @@ -4529,10 +4525,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_LONGVARBINARY: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_BINARY, - buffers.charBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } @@ -4541,10 +4537,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_LONGVARCHAR: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_CHAR, - buffers.charBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } @@ -4554,152 +4550,152 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_WCHAR, - buffers.wcharBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.wcharBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } case SQL_INTEGER: { - buffers.intBuffers[col - 1].resize(1); + buffers.intBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SLONG, - buffers.intBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SLONG, + buffers.intBuffers[idxCol].data(), sizeof(SQLINTEGER), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_SMALLINT: { - buffers.smallIntBuffers[col - 1].resize(1); + buffers.smallIntBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SSHORT, - buffers.smallIntBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SSHORT, + buffers.smallIntBuffers[idxCol].data(), sizeof(SQLSMALLINT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TINYINT: { - buffers.charBuffers[col - 1].resize(1); + buffers.charBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TINYINT, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TINYINT, + 
buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_BIT: { - buffers.charBuffers[col - 1].resize(1); + buffers.charBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_BIT, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_BIT, + buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_REAL: { - buffers.realBuffers[col - 1].resize(1); + buffers.realBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_FLOAT, - buffers.realBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_FLOAT, + buffers.realBuffers[idxCol].data(), sizeof(SQLREAL), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_DECIMAL: case SQL_NUMERIC: { - buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); SQLGetData_ptr( - hStmt, col, SQL_C_CHAR, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol].data(), MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_DOUBLE: case SQL_FLOAT: { - buffers.doubleBuffers[col - 1].resize(1); + buffers.doubleBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_DOUBLE, - buffers.doubleBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_DOUBLE, + buffers.doubleBuffers[idxCol].data(), sizeof(SQLDOUBLE), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - buffers.timestampBuffers[col - 1].resize(1); + buffers.timestampBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_TIMESTAMP, - buffers.timestampBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[idxCol].data(), sizeof(SQL_TIMESTAMP_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_BIGINT: { - buffers.bigIntBuffers[col - 1].resize(1); + buffers.bigIntBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SBIGINT, - buffers.bigIntBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SBIGINT, + buffers.bigIntBuffers[idxCol].data(), sizeof(SQLBIGINT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TYPE_DATE: { - buffers.dateBuffers[col - 1].resize(1); + buffers.dateBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_DATE, - buffers.dateBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_DATE, + buffers.dateBuffers[idxCol].data(), sizeof(SQL_DATE_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { - buffers.timeBuffers[col - 1].resize(1); + buffers.timeBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_TIME, - buffers.timeBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_TIME, + buffers.timeBuffers[idxCol].data(), sizeof(SQL_TIME_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_GUID: { - buffers.guidBuffers[col - 1].resize(1); + buffers.guidBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_GUID, - buffers.guidBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_GUID, + buffers.guidBuffers[idxCol].data(), sizeof(SQLGUID), 
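Each branch in this region is the same SQLGetData idiom specialized to a C type: bind a one-element buffer, pass its byte size, then read the length/indicator to learn whether the value was NULL. Reduced to one fixed-size column, the pattern looks roughly like this sketch, calling the plain ODBC entry point rather than the module's loaded SQLGetData_ptr, with diagnostics left to the caller:

    #include <sql.h>
    #include <sqlext.h>

    // Sketch: read one SQL_INTEGER from column `col` of the current row.
    // Assumes SQLFetch has already positioned the cursor.
    bool GetOneInt(SQLHSTMT hStmt, SQLUSMALLINT col, SQLINTEGER& value, bool& isNull) {
        SQLLEN indicator = 0;
        SQLRETURN rc = SQLGetData(hStmt, col, SQL_C_SLONG, &value,
                                  sizeof(SQLINTEGER), &indicator);
        if (!SQL_SUCCEEDED(rc)) {
            return false;  // caller fetches the diagnostic record
        }
        isNull = (indicator == SQL_NULL_DATA);
        return true;
    }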
- buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_SS_TIMESTAMPOFFSET: { - buffers.datetimeoffsetBuffers[col - 1].resize(1); + buffers.datetimeoffsetBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, - buffers.datetimeoffsetBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[idxCol].data(), sizeof(DateTimeOffset), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } default: { std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << col + errorString << "Unsupported data type for column ID - " << (idxCol + 1) << ", Type - " << dataType; LOG("SQLGetData: %s", errorString.str().c_str()); ThrowStdException(errorString.str()); @@ -4708,13 +4704,13 @@ SQLRETURN FetchArrowBatch_wrap( } } - SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + SQLLEN dataLen = buffers.indicators[idxCol][idxRowSql]; if (dataLen == SQL_NULL_DATA) { // Mark as null in validity bitmap size_t bytePos = idxRowArrow / 8; size_t bitPos = idxRowArrow % 8; - buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + arrowColumnProducer->valid[bytePos] &= ~(1 << bitPos); // Value buffer for variable length data types needs to be set appropriately // as it will be used by the next non null value @@ -4731,7 +4727,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + arrowColumnProducer->varVal[idxRowArrow + 1] = arrowColumnProducer->varVal[idxRowArrow]; break; default: break; @@ -4739,7 +4735,7 @@ SQLRETURN FetchArrowBatch_wrap( continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + LOG("Unexpected negative data length. 
Column ID - {}, SQL Type - {}, Data Length - {}", idxCol + 1, dataType, dataLen); ThrowStdException("Unexpected negative data length."); } @@ -4748,28 +4744,28 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_VARBINARY: case SQL_LONGVARBINARY: { uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; while (target_vec->size() < start + dataLen) { target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; while (target_vec->size() < start + dataLen) { target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } case SQL_SS_XML: @@ -4778,9 +4774,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: { assert(dataLen % sizeof(SQLWCHAR) == 0); auto dataLenW = dataLen / sizeof(SQLWCHAR); - auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; - auto target_vec = &buffersArrow.var_data[col - 1]; + auto wcharSource = &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; #if defined(_WIN32) // Convert wide string int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); @@ -4788,7 +4784,7 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLenConverted; + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); @@ -4796,15 +4792,15 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + utf8str.size(); #endif break; } case SQL_GUID: { // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") // Each GUID is exactly 36 bytes in UTF-8 - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col 
- 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; // Ensure buffer has space for the GUID string + null terminator while (target_vec->size() < start + 37) { @@ -4812,7 +4808,7 @@ SQLRETURN FetchArrowBatch_wrap( } // Get the GUID from the buffer - const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + const SQLGUID& guidValue = buffers.guidBuffers[idxCol][idxRowSql]; // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast(&target_vec->data()[start]), 37, @@ -4826,25 +4822,25 @@ SQLRETURN FetchArrowBatch_wrap( guidValue.Data4[6], guidValue.Data4[7]); // Update offset for next row, ignoring null terminator - buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + arrowColumnProducer->varVal[idxRowArrow + 1] = start + 36; break; } case SQL_TINYINT: - buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + arrowColumnProducer->uint8Val[idxRowArrow] = buffers.charBuffers[idxCol][idxRowSql]; break; case SQL_SMALLINT: - buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int16Val[idxRowArrow] = buffers.smallIntBuffers[idxCol][idxRowSql]; break; case SQL_INTEGER: - buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int32Val[idxRowArrow] = buffers.intBuffers[idxCol][idxRowSql]; break; case SQL_BIGINT: - buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int64Val[idxRowArrow] = buffers.bigIntBuffers[idxCol][idxRowSql]; break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - buffersArrow.float64[col - 1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + arrowColumnProducer->float64Val[idxRowArrow] = buffers.doubleBuffers[idxCol][idxRowSql]; break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -4853,26 +4849,26 @@ SQLRETURN FetchArrowBatch_wrap( auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; int sign = 1; for (SQLULEN idx = start; idx < start + dataLen; idx++) { - char digitChar = buffers.charBuffers[col - 1][idx]; + char digitChar = buffers.charBuffers[idxCol][idx]; if (digitChar == '-') { sign = -1; } else if (digitChar >= '0' && digitChar <= '9') { decimalValue = decimalValue * 10 + (digitChar - '0'); } } - buffersArrow.decimal[col - 1][idxRowArrow] = decimalValue * sign; + arrowColumnProducer->decimalVal[idxRowArrow] = decimalValue * sign; break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, sql_value.day ); - buffersArrow.ts_micro[col - 1][idxRowArrow] = + arrowColumnProducer->tsMicroVal[idxRowArrow] = days * 86400 * 1000000 + static_cast(sql_value.hour) * 3600 * 1000000 + static_cast(sql_value.minute) * 60 * 1000000 + @@ -4881,13 +4877,13 @@ SQLRETURN FetchArrowBatch_wrap( break; } case SQL_SS_TIMESTAMPOFFSET: { - DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, sql_value.day ); - buffersArrow.ts_micro[col - 1][idxRowArrow] = + arrowColumnProducer->tsMicroVal[idxRowArrow] = days * 86400 * 1000000 + (static_cast(sql_value.hour) - 
static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + @@ -4896,10 +4892,10 @@ SQLRETURN FetchArrowBatch_wrap( break; } case SQL_TYPE_DATE: - buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( - buffers.dateBuffers[col - 1][idxRowSql].year, - buffers.dateBuffers[col - 1][idxRowSql].month, - buffers.dateBuffers[col - 1][idxRowSql].day + arrowColumnProducer->dateVal[idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[idxCol][idxRowSql].year, + buffers.dateBuffers[idxCol][idxRowSql].month, + buffers.dateBuffers[idxCol][idxRowSql].day ); break; case SQL_TIME: @@ -4907,8 +4903,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_SS_TIME2: { // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. // To fully support SQL_SS_TIME2, the corresponding c-type should be used. - const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; - buffersArrow.time_second[col - 1][idxRowArrow] = + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[idxCol][idxRowSql]; + arrowColumnProducer->timeSecondVal[idxRowArrow] = static_cast(timeValue.hour) * 3600 + static_cast(timeValue.minute) * 60 + static_cast(timeValue.second); @@ -4917,7 +4913,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: { // SQL_BIT is stored as a single bit in Arrow's bitmap format // Get the boolean value from the buffer - bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + bool bitValue = buffers.charBuffers[idxCol][idxRowSql] != 0; // Set the bit in the Arrow bitmap size_t byteIndex = idxRowArrow / 8; @@ -4925,16 +4921,16 @@ SQLRETURN FetchArrowBatch_wrap( if (bitValue) { // Set bit to 1 - buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + arrowColumnProducer->bitVal[byteIndex] |= (1 << bitIndex); } else { // Clear bit to 0 - buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + arrowColumnProducer->bitVal[byteIndex] &= ~(1 << bitIndex); } break; } default: { std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << col + errorString << "Unsupported data type for column ID - " << (idxCol + 1) << ", Type - " << dataType; LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); @@ -4950,19 +4946,23 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); - // Transfer ownership of buffers to Arrow structures - // Exceptions beyond this point would cause memory leaks - - auto batch_children = new ArrowSchema*[numCols]; + // Transfer ownership of buffers to batch ArrowSchema + // First, allocate memory for the necessary structures + auto arrowSchemaBatch = std::make_unique(); + auto arrowSchemaBatchChildren = std::make_unique(numCols); + auto arrowSchemaBatchChildPointers = std::make_unique[]>(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { - auto col_private_data = new ArrowSchemaPrivateData(); - col_private_data->format = std::move(columnFormats[i]); - col_private_data->name = std::move(columnNamesCStr[i]); + arrowSchemaBatchChildPointers[i] = std::make_unique(); + } - auto arrow_schema = new ArrowSchema({ - .format = col_private_data->format.get(), - .name = col_private_data->name.get(), + // Second, transfer ownership to arrowSchemaBatch + // No unhandled exceptions until the pycapsule owns the arrowSchemaBatch to avoid memory leaks + + for (SQLSMALLINT i = 0; i < numCols; i++) { + *arrowSchemaBatchChildPointers[i] 
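The timestamp branches above all boil down to the same arithmetic: convert the date part to days since the Unix epoch, scale every component to microseconds, and, for datetimeoffset, first subtract the timezone fields so the stored instant is UTC. Factored out as a hypothetical helper (the fraction field is nanoseconds per the ODBC definition of SQL_TIMESTAMP_STRUCT, hence the division by 1000):

    #include <cstdint>

    // Sketch: combine a day count (days since 1970-01-01) with time-of-day
    // fields into Arrow timestamp[us] units.
    int64_t ToArrowMicros(int64_t days, int64_t hour, int64_t minute,
                          int64_t second, int64_t fraction_ns) {
        return days * 86400 * 1000000 +
               hour * 3600 * 1000000 +
               minute * 60 * 1000000 +
               second * 1000000 +
               fraction_ns / 1000;  // nanoseconds -> microseconds
    }

    // datetimeoffset normalizes to UTC by offsetting the wall-clock fields:
    //   ToArrowMicros(days, hour - tz_hour, minute - tz_minute, second, frac)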
= { + .format = arrowSchemaPrivateData[i]->format.get(), + .name = arrowSchemaPrivateData[i]->name.get(), .metadata = nullptr, .flags = static_cast(columnNullable[i] ? ARROW_FLAG_NULLABLE : 0), .n_children = 0, @@ -4976,18 +4976,21 @@ SQLRETURN FetchArrowBatch_wrap( delete schema->private_data; // Frees format and name schema->release = nullptr; }, - .private_data = col_private_data, - }); - batch_children[i] = arrow_schema; + .private_data = arrowSchemaPrivateData[i].release(), + }; + } + + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowSchemaBatchChildren[i] = arrowSchemaBatchChildPointers[i].release(); } - auto arrow_schema_batch = new ArrowSchema({ + *arrowSchemaBatch = ArrowSchema{ .format = "+s", .name = "", .metadata = nullptr, .flags = 0, .n_children = numCols, - .children = batch_children, + .children = arrowSchemaBatchChildren.release(), .dictionary = nullptr, .release = [](ArrowSchema* schema) { // format and name are string literals, no need to free @@ -5008,28 +5011,79 @@ SQLRETURN FetchArrowBatch_wrap( schema->release = nullptr; }, .private_data = nullptr, - }); + }; - auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { - auto arrow_schema = static_cast(ptr); - if (arrow_schema->release) { - arrow_schema->release(arrow_schema); - } - delete arrow_schema; - }); - capsules.append(caps); + // Finally, transfer ownership of arrowSchemaBatch and its pointer to pycapsule + py::capsule arrowSchemaBatchCapsule; + try { + arrowSchemaBatchCapsule = py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { + auto arrowSchema = static_cast(ptr); + if (arrowSchema->release) { + arrowSchema->release(arrowSchema); + } + delete arrowSchema; + }); + } catch (...) { + arrowSchemaBatch->release(arrowSchemaBatch.get()); + throw; + } + arrowSchemaBatch.release(); + capsules.append(arrowSchemaBatchCapsule); + + // Transfer ownership of buffers to batch ArrowArray + // First, allocate memory for the necessary structures + auto arrowArrayBatch = std::make_unique(); + + auto arrowArrayBatchBuffers = std::make_unique(1); + arrowArrayBatchBuffers[0] = nullptr; - auto arrow_array_batch_buffers = new const void* [3]; - memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); - // Necessary dummy buffer, pyarrow will error without it - arrow_array_batch_buffers[1] = new uint8_t[1]{0}; - auto arrow_array_batch = new ArrowArray({ + auto arrowArrayBatchChildren = std::make_unique(numCols); + auto arrowArrayBatchChildPointers = std::make_unique[]>(numCols); + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayBatchChildPointers[i] = std::make_unique(); + } + + // Second, transfer ownership to arrowArrayBatch + // No unhandled exceptions until the pycapsule owns the arrowArrayBatch to avoid memory leaks + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto dataType = dataTypes[col]; + arrowArrayPrivateData[col]->buffers[0] = arrowArrayPrivateData[col]->valid.get(); + arrowArrayPrivateData[col]->buffers[1] = arrowArrayPrivateData[col]->ptrValueBuffer; + arrowArrayPrivateData[col]->buffers[2] = arrowArrayPrivateData[col]->varData.data(); + + *arrowArrayBatchChildPointers[col] = { + .length = static_cast(idxRowArrow), + .null_count = 0, + .offset = 0, + .n_buffers = columnVarLen[col] ? 
3 : 2, + .n_children = 0, + .buffers = (const void**)arrowArrayPrivateData[col]->buffers.data(), + .children = nullptr, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data != nullptr); + assert(array->release != nullptr); + assert(array->children == nullptr); + assert(array->n_children == 0); + delete array->private_data; // Frees all buffer entries + assert(array->buffers != nullptr); + array->release = nullptr; + }, + .private_data = arrowArrayPrivateData[col].release(), + }; + } + + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayBatchChildren[i] = arrowArrayBatchChildPointers[i].release(); + } + + *arrowArrayBatch = ArrowArray{ .length = static_cast(idxRowArrow), - // only the non null dummy buffer counts .n_buffers = 1, .n_children = numCols, - .buffers = arrow_array_batch_buffers, - .children = new ArrowArray* [numCols], + .buffers = arrowArrayBatchBuffers.release(), + .children = arrowArrayBatchChildren.release(), .release = [](ArrowArray* array) { assert(array != nullptr); assert(array->private_data == nullptr); @@ -5048,147 +5102,31 @@ SQLRETURN FetchArrowBatch_wrap( assert(array->buffers != nullptr); assert(array->n_buffers == 1); assert(array->buffers[0] == nullptr); - assert(array->buffers[1] != nullptr); - assert(array->buffers[2] == nullptr); - // Delete dummy buffer - delete[] const_cast(static_cast(array->buffers[1])); - delete[] array->buffers; array->release = nullptr; }, - }); + }; - for (SQLUSMALLINT col = 0; col < numCols; col++) { - auto dataType = dataTypes[col]; - auto arrow_array_col_buffers = new const void* [3]; - memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); - auto private_data = new ArrowArrayPrivateData(); - // Allocate new memory and copy the data - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: - case SQL_SS_XML: - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: - case SQL_GUID: - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - assert(buffersArrow.var[col][0] == 0); - // length of string at index i is the difference between values at i and i+1 - // so total length is value at index idxRowArrow - auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; - auto dataBuffer = std::make_unique(data_buf_len_total); - std::memcpy(dataBuffer.get(), buffersArrow.var_data[col].data(), data_buf_len_total); - private_data->buffer_var_data = std::move(dataBuffer); - arrow_array_col_buffers[2] = private_data->buffer_var_data.get(); - private_data->buffer_var = std::move(buffersArrow.var[col]); - arrow_array_col_buffers[1] = private_data->buffer_var.get(); - } - break; - case SQL_TINYINT: - private_data->buffer_uint8 = std::move(buffersArrow.uint8[col]); - arrow_array_col_buffers[1] = private_data->buffer_uint8.get(); - break; - case SQL_SMALLINT: - private_data->buffer_int16 = std::move(buffersArrow.int16[col]); - arrow_array_col_buffers[1] = private_data->buffer_int16.get(); - break; - case SQL_INTEGER: - private_data->buffer_int32 = std::move(buffersArrow.int32[col]); - arrow_array_col_buffers[1] = private_data->buffer_int32.get(); - break; - case SQL_BIGINT: - private_data->buffer_int64 = std::move(buffersArrow.int64[col]); - arrow_array_col_buffers[1] = private_data->buffer_int64.get(); - break; - case SQL_REAL: - case SQL_FLOAT: - case SQL_DOUBLE: - private_data->buffer_float64 = std::move(buffersArrow.float64[col]); - arrow_array_col_buffers[1] = private_data->buffer_float64.get(); - break; - case SQL_DECIMAL: - case SQL_NUMERIC: 
{ - private_data->buffer_decimal = std::move(buffersArrow.decimal[col]); - arrow_array_col_buffers[1] = private_data->buffer_decimal.get(); - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: - private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); - arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); - break; - case SQL_SS_TIMESTAMPOFFSET: - private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); - arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); - break; - case SQL_TYPE_DATE: - private_data->buffer_date = std::move(buffersArrow.date[col]); - arrow_array_col_buffers[1] = private_data->buffer_date.get(); - break; - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: - private_data->buffer_time_second = std::move(buffersArrow.time_second[col]); - arrow_array_col_buffers[1] = private_data->buffer_time_second.get(); - break; - case SQL_BIT: - private_data->buffer_bit = std::move(buffersArrow.bit[col]); - arrow_array_col_buffers[1] = private_data->buffer_bit.get(); - break; - default: { - std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << (col + 1) - << ", Type - " << dataType; - LOG(errorString.str().c_str()); - ThrowStdException(errorString.str()); - break; + // Finally, transfer ownership of arrowArrayBatch and its pointer to pycapsule + py::capsule arrowArrayBatchCapsule; + try { + arrowArrayBatchCapsule = py::capsule(arrowArrayBatch.get(), "arrow_array", [](void* ptr) { + auto arrowArray = static_cast(ptr); + if (arrowArray->release) { + arrowArray->release(arrowArray); } - } - - auto arrow_array_col = new ArrowArray({ - .length = static_cast(idxRowArrow), - .null_count = 0, - .offset = 0, - .n_buffers = arrow_array_col_buffers[2] ? 3 : 2, - .n_children = 0, - .buffers = arrow_array_col_buffers, - .children = nullptr, - .release = [](ArrowArray* array) { - assert(array != nullptr); - assert(array->private_data != nullptr); - assert(array->release != nullptr); - assert(array->children == nullptr); - assert(array->n_children == 0); - delete array->private_data; // Frees all buffer entries - assert(array->buffers != nullptr); - delete[] array->buffers; - array->release = nullptr; - }, - .private_data = private_data, + delete arrowArray; }); - - private_data->buffer_valid = std::move(buffersArrow.valid[col]); - arrow_array_col->buffers[0] = private_data->buffer_valid.get(); - arrow_array_batch->children[col] = arrow_array_col; + } catch (...) { + arrowArrayBatch->release(arrowArrayBatch.get()); + throw; } - - capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { - auto arrow_array = static_cast(ptr); - if (arrow_array->release) { - arrow_array->release(arrow_array); - } - delete arrow_array; - })); + arrowArrayBatch.release(); + capsules.append(arrowArrayBatchCapsule); return ret; } - // FetchAll_wrap - Fetches all rows of data from the result set. 
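The two capsules appended here follow the PyCapsule flavor of the Arrow C Data Interface: one capsule named "arrow_schema" and one named "arrow_array", each carrying a destructor that calls release if the consumer never imported it. These are what pyarrow's RecordBatch._import_from_c_capsule unpacks on the Python side. The export half, stripped to its shape (a sketch with a hypothetical WrapSchema name, assuming an ArrowSchema whose release callback is already wired up):

    #include <pybind11/pybind11.h>
    #include <memory>

    namespace py = pybind11;

    // Sketch: hand a heap-allocated ArrowSchema to Python. If capsule creation
    // itself throws, release immediately so nothing leaks; once the capsule
    // exists it is the sole owner and its destructor cleans up.
    py::capsule WrapSchema(std::unique_ptr<ArrowSchema> schema) {
        py::capsule cap;
        try {
            cap = py::capsule(schema.get(), "arrow_schema", [](void* ptr) {
                auto* s = static_cast<ArrowSchema*>(ptr);
                if (s->release) {
                    s->release(s);  // consumer may already have released it
                }
                delete s;
            });
        } catch (...) {
            schema->release(schema.get());
            throw;
        }
        schema.release();  // the capsule owns the pointer from here on
        return cap;
    }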
// // @param StatementHandle: Handle to the statement from which data is to be From d2c488153ae7f9b4fe7e5277774742c2b4c0c745 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:12:34 +0100 Subject: [PATCH 10/15] Check returncode for SQLGetData --- mssql_python/pybind/ddbc_bindings.cpp | 96 ++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c706bb2b..06b06b6f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4523,174 +4523,238 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_BINARY, buffers.charBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } break; } case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_CHAR, buffers.charBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } break; } case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_WCHAR, buffers.wcharBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching binary data for column %d", idxCol + 1); + return ret; + } break; } case SQL_INTEGER: { buffers.intBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SLONG, buffers.intBuffers[idxCol].data(), sizeof(SQLINTEGER), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SLONG data for column %d", idxCol + 1); + return ret; + } break; } case SQL_SMALLINT: { buffers.smallIntBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SSHORT, buffers.smallIntBuffers[idxCol].data(), sizeof(SQLSMALLINT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SSHORT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TINYINT: { buffers.charBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TINYINT, buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TINYINT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_BIT: { buffers.charBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_BIT, buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching BIT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_REAL: { buffers.realBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_FLOAT, buffers.realBuffers[idxCol].data(), sizeof(SQLREAL), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching FLOAT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_DECIMAL: case SQL_NUMERIC: { buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_CHAR, buffers.charBuffers[idxCol].data(), 
MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching CHAR data for column %d", idxCol + 1); + return ret; + } break; } case SQL_DOUBLE: case SQL_FLOAT: { buffers.doubleBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_DOUBLE, buffers.doubleBuffers[idxCol].data(), sizeof(SQLDOUBLE), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching DOUBLE data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { buffers.timestampBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, buffers.timestampBuffers[idxCol].data(), sizeof(SQL_TIMESTAMP_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIMESTAMP data for column %d", idxCol + 1); + return ret; + } break; } case SQL_BIGINT: { buffers.bigIntBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SBIGINT, buffers.bigIntBuffers[idxCol].data(), sizeof(SQLBIGINT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SBIGINT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TYPE_DATE: { buffers.dateBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_DATE, buffers.dateBuffers[idxCol].data(), sizeof(SQL_DATE_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_DATE data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { buffers.timeBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_TIME, buffers.timeBuffers[idxCol].data(), sizeof(SQL_TIME_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIME data for column %d", idxCol + 1); + return ret; + } break; } case SQL_GUID: { buffers.guidBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_GUID, buffers.guidBuffers[idxCol].data(), sizeof(SQLGUID), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching GUID data for column %d", idxCol + 1); + return ret; + } break; } case SQL_SS_TIMESTAMPOFFSET: { buffers.datetimeoffsetBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, buffers.datetimeoffsetBuffers[idxCol].data(), sizeof(DateTimeOffset), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", idxCol + 1); + return ret; + } break; } default: { From 05c204aabc406c40c737482cd7a2816c30896175 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:31:21 +0100 Subject: [PATCH 11/15] Fix null count array attribute --- mssql_python/pybind/ddbc_bindings.cpp | 5 ++++- tests/test_004_cursor.py | 26 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 06b06b6f..ad960e73 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4307,6 +4307,7 @@ SQLRETURN FetchArrowBatch_wrap( std::vector columnSizes(numCols); std::vector 
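Arrow represents NULLs out of band: the value slot is simply ignored while the corresponding bit in the validity bitmap is cleared, and ArrowArray::null_count must agree with that bitmap (or be -1 for unknown). The per-column counters introduced here keep that invariant cheaply while rows stream in; in isolation the bookkeeping is just this sketch (hypothetical names):

    #include <cstddef>
    #include <cstdint>

    // Sketch: validity bitmaps start all-ones (every row valid); a NULL
    // clears one bit and bumps the running counter reported to Arrow.
    void MarkNull(uint8_t* validity, int64_t& nullCount, size_t row) {
        validity[row / 8] &= static_cast<uint8_t>(~(1u << (row % 8)));
        ++nullCount;
    }

    bool IsValid(const uint8_t* validity, size_t row) {
        return (validity[row / 8] >> (row % 8)) & 1;
    }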
columnNullable(numCols); std::vector columnVarLen(numCols, false); + std::vector nullCounts(numCols, 0); std::vector> arrowArrayPrivateData(numCols); std::vector> arrowSchemaPrivateData(numCols); @@ -4796,6 +4797,8 @@ SQLRETURN FetchArrowBatch_wrap( default: break; } + + nullCounts[idxCol] += 1; continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception @@ -5118,7 +5121,7 @@ SQLRETURN FetchArrowBatch_wrap( *arrowArrayBatchChildPointers[col] = { .length = static_cast(idxRowArrow), - .null_count = 0, + .null_count = nullCounts[col], .offset = 0, .n_buffers = columnVarLen[col] ? 3 : 2, .n_children = 0, diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index f8a5daa9..829806f8 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -20,8 +20,11 @@ try: import pyarrow as pa + import pyarrow.parquet as pq + import io except ImportError: pa = None + pq = None # Setup test table @@ -15138,12 +15141,17 @@ def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_le full_query = "\nunion all\n".join(selects) ret = cursor.execute(full_query).arrow_batch(fetch_length) for i_col, col in enumerate(ret): + expected_data = arrow_test_data[i_col][2][:fetch_length] for i_row, (v_expected, v_actual) in enumerate( - zip(arrow_test_data[i_col][2][:fetch_length], col.to_pylist(), strict=True) + zip(expected_data, col.to_pylist(), strict=True) ): assert ( v_expected == v_actual ), f"Mismatch in column {i_col}, row {i_row}: expected {v_expected}, got {v_actual}" + # check that null counts match + expected_null_count = sum(1 for v in expected_data if v is None) + actual_null_count = col.null_count + assert expected_null_count == actual_null_count, (expected_null_count, actual_null_count) for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): field = ret.schema.field(i_col) assert ( @@ -15153,6 +15161,22 @@ def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_le pa_type ), f"Column {i_col} type mismatch: expected {pa_type}, got {field.type}" + # Validate that Parquet serialization/deserialization does not detect any issues + tbl = pa.Table.from_batches([ret]) + # for some reason parquet converts seconds to milliseconds in time32 + for i_col, col in enumerate(tbl.columns): + if col.type == pa.time32("s"): + tbl = tbl.set_column( + i_col, + tbl.schema.field(i_col).name, + col.cast(pa.time32("ms")), + ) + buffer = io.BytesIO() + pq.write_table(tbl, buffer) + buffer.seek(0) + read_table = pq.read_table(buffer) + assert read_table.equals(tbl) + @pytest.mark.skipif(pa is None, reason="pyarrow is not installed") def test_arrow_lob_wide(cursor: mssql_python.Cursor): From 582f366cb0cbb32d5ab3a08d2bacf70d6528d4bd Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 9 Dec 2025 21:51:51 +0100 Subject: [PATCH 12/15] Replace __int128_t by custom Int128_t for compatibility --- mssql_python/pybind/ddbc_bindings.cpp | 55 ++++++++++++++++++++++++--- tests/test_004_cursor.py | 16 ++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index ad960e73..c03dd1cf 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -157,6 +157,50 @@ struct NumericData { } }; +struct Int128_t { + uint64_t low; + int64_t high; + + Int128_t() : low(0), high(0) {} + Int128_t(uint64_t l, int64_t h) : low(l), high(h) {} + + Int128_t 
multiply_by_10() const {
+        // value * 10 = (value * 8) + (value * 2)
+        Int128_t shift3 = *this << 3;
+        Int128_t shift1 = *this << 1;
+        return shift3 + shift1;
+    }
+
+    Int128_t operator<<(int shift) const {
+        // Shifts of 0 or >= 64 would require special cases; only shifts by
+        // 1 and 3 are needed for multiply_by_10.
+        assert(shift > 0);
+        assert(shift < 64);
+        uint64_t new_low = low << shift;
+        uint64_t new_high = (static_cast<uint64_t>(high) << shift) | (low >> (64 - shift));
+        return {new_low, static_cast<int64_t>(new_high)};
+    }
+
+    Int128_t operator+(const Int128_t& other) const {
+        uint64_t sum_low = low + other.low;
+        uint64_t carry = (sum_low < low) ? 1 : 0;
+        int64_t sum_high = high + other.high + carry;
+        return {sum_low, sum_high};
+    }
+
+    Int128_t operator+(uint64_t digit) const {
+        uint64_t sum_low = low + digit;
+        uint64_t carry = (sum_low < low) ? 1 : 0;
+        int64_t sum_high = high + carry;
+        return {sum_low, sum_high};
+    }
+
+    Int128_t operator-() const {
+        // Two's complement negation: flip all bits, then add one
+        uint64_t new_low = ~low + 1;
+        uint64_t new_high = ~high + (new_low == 0 ? 1 : 0);
+        return {new_low, static_cast<int64_t>(new_high)};
+    }
+};
+
 struct ArrowArrayPrivateData {
     std::unique_ptr<uint8_t[]> valid;
@@ -170,7 +214,7 @@ struct ArrowArrayPrivateData {
     std::unique_ptr<int32_t[]> dateVal;
     std::unique_ptr<int64_t[]> tsMicroVal;
     std::unique_ptr<int32_t[]> timeSecondVal;
-    std::unique_ptr<__int128_t[]> decimalVal;
+    std::unique_ptr<Int128_t[]> decimalVal;
 
     std::vector<uint8_t> varData;
 
@@ -4405,7 +4449,7 @@ SQLRETURN FetchArrowBatch_wrap(
             arrowSchemaPrivateData[i]->format = std::make_unique<char[]>(formatLen);
             std::memcpy(arrowSchemaPrivateData[i]->format.get(), formatStr.c_str(), formatLen);
             format = arrowSchemaPrivateData[i]->format.get();
-            arrowColumnProducer->decimalVal = std::make_unique<__int128_t[]>(arrowBatchSize);
+            arrowColumnProducer->decimalVal = std::make_unique<Int128_t[]>(arrowBatchSize);
             arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->decimalVal.get();
             break;
         }
@@ -4911,8 +4955,9 @@ SQLRETURN FetchArrowBatch_wrap(
                     break;
                 case SQL_DECIMAL:
                 case SQL_NUMERIC: {
+                    // Relies on the overloaded operators defined in the Int128_t struct
                     assert(dataLen <= MAX_DIGITS_IN_NUMERIC);
-                    __int128_t decimalValue = 0;
+                    Int128_t decimalValue(0, 0);
                     auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC;
                     int sign = 1;
                     for (SQLULEN idx = start; idx < start + dataLen; idx++) {
@@ -4920,10 +4965,10 @@ SQLRETURN FetchArrowBatch_wrap(
                         if (digitChar == '-') {
                             sign = -1;
                         } else if (digitChar >= '0' && digitChar <= '9') {
-                            decimalValue = decimalValue * 10 + (digitChar - '0');
+                            decimalValue = decimalValue.multiply_by_10() + (uint64_t)(digitChar - '0');
                         }
                     }
-                    arrowColumnProducer->decimalVal[idxRowArrow] = decimalValue * sign;
+                    arrowColumnProducer->decimalVal[idxRowArrow] = (sign > 0) ? decimalValue : -decimalValue;
                     break;
                 }
                 case SQL_TIMESTAMP:
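Arrow's decimal128 buffer stores 128-bit two's-complement integers as two 64-bit limbs with the low limb first, which is exactly the {low, high} layout of Int128_t; the digit loop rebuilds that integer with shift-and-add because MSVC offers no __int128. A self-contained sketch of the same limb arithmetic (hypothetical names, sign handling left to the caller):

    #include <cassert>
    #include <cstdint>
    #include <string>

    // Two-limb model of a non-negative decimal128 value: low limb first,
    // matching the little-endian layout Arrow expects in the value buffer.
    struct Limbs { uint64_t lo; int64_t hi; };

    // value = value * 10 + digit, done as (value << 3) + (value << 1) + digit
    // with the bits shifted out of the low limb carried into the high limb.
    Limbs mul10_add(Limbs v, uint64_t digit) {
        uint64_t lo8 = v.lo << 3, lo2 = v.lo << 1;
        int64_t hi = (v.hi << 3) + (v.hi << 1)
                   + static_cast<int64_t>(v.lo >> 61)   // bits lost by << 3
                   + static_cast<int64_t>(v.lo >> 63);  // bit lost by << 1
        uint64_t lo = lo8 + lo2;
        if (lo < lo8) ++hi;                             // carry from lo8 + lo2
        uint64_t lo_d = lo + digit;
        if (lo_d < lo) ++hi;                            // carry from + digit
        return {lo_d, hi};
    }

    int main() {
        // "18446744073709551616" is 2**64, so it lands exactly in the high limb.
        Limbs v{0, 0};
        for (char c : std::string("18446744073709551616")) {
            v = mul10_add(v, static_cast<uint64_t>(c - '0'));
        }
        assert(v.lo == 0 && v.hi == 1);
        return 0;
    }

The 2**63..2**96 values added to the test data below probe exactly these limb boundaries.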
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index 829806f8..567a8cc2 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -15058,6 +15058,22 @@ def get_arrow_test_data(include_lobs: bool, batch_length: int):
                 decimal.Decimal("9999999999999999999999999999.9999999999"),
             ],
         ),
+        (
+            pa.decimal128(precision=38, scale=0),
+            "decimal(38, 0)",
+            [
+                decimal.Decimal(str(2**63)),
+                decimal.Decimal(str(-(2**63))),
+                decimal.Decimal(str(2**64)),
+                decimal.Decimal(str(-(2**64))),
+                decimal.Decimal(str(2**64 - 1)),
+                decimal.Decimal(str(-(2**64 - 1))),
+                decimal.Decimal(str(2**64 + 1)),
+                decimal.Decimal(str(-(2**64 + 1))),
+                decimal.Decimal(str(2**96)),
+                decimal.Decimal(str(-(2**96))),
+            ],
+        ),
         (pa.bool_(), "bit", [True, None, False]),
         (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]),
         (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]),

From 65c66be44d4e53633533e3c5091aa9df028871fc Mon Sep 17 00:00:00 2001
From: ffelixg <142172984+ffelixg@users.noreply.github.com>
Date: Tue, 9 Dec 2025 22:12:12 +0100
Subject: [PATCH 13/15] Vendor days_from_civil to replace std::mktime

---
 mssql_python/pybind/ddbc_bindings.cpp | 31 +++++++++++----------------
 tests/test_004_cursor.py              | 15 ++++++++++++-
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index c03dd1cf..07bee848 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -4305,22 +4305,15 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt,
     return SQL_SUCCESS;
 }
 
-int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) {
-    // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch)
-    std::tm tm_date = {};
-    tm_date.tm_year = year - 1900; // tm_year is years since 1900
-    tm_date.tm_mon = month - 1; // tm_mon is 0-11
-    tm_date.tm_mday = day;
-
-    std::time_t time_since_epoch = std::mktime(&tm_date);
-    if (time_since_epoch == -1) {
-        LOG("Failed to convert SQL_DATE_STRUCT to time_t");
-        ThrowStdException("Date conversion error");
-    }
-    // Sanity check against timezone issues. Since we only provide the date, this has to be true
-    assert(time_since_epoch % 86400 == 0);
-    // Calculate days since epoch
-    return time_since_epoch / 86400;
+int32_t days_from_civil(int y, int m, int d) {
+    // Implements the "days_from_civil" algorithm by Howard Hinnant
+    // Returns number of days since Unix epoch (1970-01-01)
+    y -= m <= 2;
+    const int era = (y >= 0 ? y : y - 399) / 400;
+    const unsigned yoe = static_cast<unsigned>(y - era * 400);            // [0, 399]
+    const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1;  // [0, 365]
+    const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;           // [0, 146096]
+    return era * 146097 + static_cast<int>(doe) - 719468;
 }
 
 SQLRETURN FetchArrowBatch_wrap(
@@ -4975,7 +4968,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 case SQL_TYPE_TIMESTAMP:
                 case SQL_DATETIME: {
                     SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql];
-                    int64_t days = dateAsDayCount(
+                    int64_t days = days_from_civil(
                         sql_value.year,
                         sql_value.month,
                         sql_value.day
@@ -4990,7 +4983,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 }
                 case SQL_SS_TIMESTAMPOFFSET: {
                     DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql];
-                    int64_t days = dateAsDayCount(
+                    int64_t days = days_from_civil(
                         sql_value.year,
                         sql_value.month,
                         sql_value.day
@@ -5004,7 +4997,7 @@ SQLRETURN FetchArrowBatch_wrap(
                     break;
                 }
                 case SQL_TYPE_DATE:
-                    arrowColumnProducer->dateVal[idxRowArrow] = dateAsDayCount(
+                    arrowColumnProducer->dateVal[idxRowArrow] = days_from_civil(
                         buffers.dateBuffers[idxCol][idxRowSql].year,
                         buffers.dateBuffers[idxCol][idxRowSql].month,
                         buffers.dateBuffers[idxCol][idxRowSql].day
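Because days_from_civil is pure integer arithmetic over the proleptic Gregorian calendar, it is immune to the process timezone and to platform time_t limits that made the std::mktime version fragile (mktime interprets the struct in local time and cannot represent year 1 or year 9999 everywhere). A few spot checks, assuming the function as defined above:

    #include <cassert>

    int main() {
        assert(days_from_civil(1970, 1, 1) == 0);       // Unix epoch
        assert(days_from_civil(1969, 12, 31) == -1);    // day before the epoch
        assert(days_from_civil(2000, 2, 29) == 11016);  // leap day
        assert(days_from_civil(1, 1, 1) == -719162);    // smallest SQL Server date
        return 0;
    }

The date values added to the test data below (epoch boundary, leap day, year 1, year 9999) exercise the same edges.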
From 65c66be44d4e53633533e3c5091aa9df028871fc Mon Sep 17 00:00:00 2001
From: ffelixg <142172984+ffelixg@users.noreply.github.com>
Date: Tue, 9 Dec 2025 22:12:12 +0100
Subject: [PATCH 13/15] Vendor days_from_civil to replace std::mktime

---
 mssql_python/pybind/ddbc_bindings.cpp | 31 +++++++++++----------------
 tests/test_004_cursor.py              | 15 ++++++++++++-
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index c03dd1cf..07bee848 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -4305,22 +4305,15 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt,
     return SQL_SUCCESS;
 }
 
-int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) {
-    // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch)
-    std::tm tm_date = {};
-    tm_date.tm_year = year - 1900;  // tm_year is years since 1900
-    tm_date.tm_mon = month - 1;     // tm_mon is 0-11
-    tm_date.tm_mday = day;
-
-    std::time_t time_since_epoch = std::mktime(&tm_date);
-    if (time_since_epoch == -1) {
-        LOG("Failed to convert SQL_DATE_STRUCT to time_t");
-        ThrowStdException("Date conversion error");
-    }
-    // Sanity check against timezone issues. Since we only provide the date, this has to be true
-    assert(time_since_epoch % 86400 == 0);
-    // Calculate days since epoch
-    return time_since_epoch / 86400;
+int32_t days_from_civil(int y, int m, int d) {
+    // Implements the "days_from_civil" algorithm by Howard Hinnant
+    // Returns the number of days since the Unix epoch (1970-01-01)
+    y -= m <= 2;
+    const int era = (y >= 0 ? y : y - 399) / 400;
+    const unsigned yoe = static_cast<unsigned>(y - era * 400);            // [0, 399]
+    const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1;  // [0, 365]
+    const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;           // [0, 146096]
+    return era * 146097 + static_cast<int>(doe) - 719468;
 }
 
 SQLRETURN FetchArrowBatch_wrap(
@@ -4975,7 +4968,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 case SQL_TYPE_TIMESTAMP:
                 case SQL_DATETIME: {
                     SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql];
-                    int64_t days = dateAsDayCount(
+                    int64_t days = days_from_civil(
                         sql_value.year,
                         sql_value.month,
                         sql_value.day
@@ -4990,7 +4983,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 }
                 case SQL_SS_TIMESTAMPOFFSET: {
                     DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql];
-                    int64_t days = dateAsDayCount(
+                    int64_t days = days_from_civil(
                         sql_value.year,
                         sql_value.month,
                         sql_value.day
@@ -5004,7 +4997,7 @@ SQLRETURN FetchArrowBatch_wrap(
                     break;
                 }
                 case SQL_TYPE_DATE:
-                    arrowColumnProducer->dateVal[idxRowArrow] = dateAsDayCount(
+                    arrowColumnProducer->dateVal[idxRowArrow] = days_from_civil(
                         buffers.dateBuffers[idxCol][idxRowSql].year,
                         buffers.dateBuffers[idxCol][idxRowSql].month,
                         buffers.dateBuffers[idxCol][idxRowSql].day
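A few spot checks of the vendored algorithm, mirroring the boundary dates added to the test data below (standalone sketch with a local copy of the function):

    #include <cassert>
    #include <cstdint>

    // Local copy of days_from_civil from the patch above, for a standalone check.
    static int32_t days_from_civil(int y, int m, int d) {
        y -= m <= 2;
        const int era = (y >= 0 ? y : y - 399) / 400;
        const unsigned yoe = static_cast<unsigned>(y - era * 400);
        const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1;
        const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
        return era * 146097 + static_cast<int>(doe) - 719468;
    }

    int main() {
        assert(days_from_civil(1970, 1, 1) == 0);       // Unix epoch
        assert(days_from_civil(1969, 12, 31) == -1);    // day before the epoch
        assert(days_from_civil(2000, 2, 29) == 11016);  // Gregorian leap day
        assert(days_from_civil(1, 1, 1) == -719162);    // SQL Server's minimum date
        return 0;
    }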
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index 567a8cc2..85ea93aa 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -15079,7 +15079,20 @@ def get_arrow_test_data(include_lobs: bool, batch_length: int):
         (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]),
         (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]),
         (pa.string(), "uniqueidentifier", ["58185E0D-3A91-44D8-BC46-7107217E0A6D", None]),
-        (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]),
+        (
+            pa.date32(),
+            "date",
+            [
+                date(1, 1, 1),
+                None,
+                date(2345, 12, 31),
+                date(9999, 12, 31),
+                date(1970, 1, 1),
+                date(1969, 12, 31),
+                date(2000, 2, 29),
+                date(2001, 2, 28),
+            ],
+        ),
         (
             pa.time32("s"),
             "time(0)",

From 62056c8f3748b278a9ce025ab1f2cc24236258dc Mon Sep 17 00:00:00 2001
From: ffelixg <142172984+ffelixg@users.noreply.github.com>
Date: Tue, 9 Dec 2025 22:14:59 +0100
Subject: [PATCH 14/15] Expand test to make sure datetimeoffset via SQLGetData is covered

---
 tests/test_004_cursor.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index 85ea93aa..9431c0a2 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -15274,19 +15274,21 @@ def test_arrow_no_result_set(cursor: mssql_python.Cursor):
 
 @pytest.mark.skipif(pa is None, reason="pyarrow is not installed")
 def test_arrow_datetimeoffset(cursor: mssql_python.Cursor):
     "Datetimeoffset converts correctly to utc"
-    cursor.execute(
-        "declare @dt datetimeoffset(0) = '2345-02-03 12:34:56 +00:00';\n"
-        "select @dt, @dt at time zone 'Pacific Standard Time';\n"
-    )
-    batch = cursor.arrow_batch(10)
-    assert batch.num_rows == 1
-    assert batch.num_columns == 2
-    for col in batch.columns:
-        assert pa.types.is_timestamp(col.type)
-        assert col.type.tz == "+00:00", col.type.tz
-        assert col.to_pylist() == [
-            datetime(2345, 2, 3, 12, 34, 56, tzinfo=timezone.utc),
-        ]
+    for force_sqlgetdata in (False, True):
+        str_val = "cast('asdf' as nvarchar(max))" if force_sqlgetdata else "'asdf'"
+        cursor.execute(
+            "declare @dt datetimeoffset(0) = '2345-02-03 12:34:56 +00:00';\n"
+            f"select {str_val}, @dt, @dt at time zone 'Pacific Standard Time';\n"
+        )
+        batch = cursor.arrow_batch(10)
+        assert batch.num_rows == 1
+        assert batch.num_columns == 3
+        for col in batch.columns[1:]:
+            assert pa.types.is_timestamp(col.type)
+            assert col.type.tz == "+00:00", col.type.tz
+            assert col.to_pylist() == [
+                datetime(2345, 2, 3, 12, 34, 56, tzinfo=timezone.utc),
+            ]
 
 
 @pytest.mark.skipif(pa is None, reason="pyarrow is not installed")
From b4c44fdddace4036565872fc3071d716616a0b48 Mon Sep 17 00:00:00 2001
From: ffelixg <142172984+ffelixg@users.noreply.github.com>
Date: Wed, 7 Jan 2026 00:27:42 +0100
Subject: [PATCH 15/15] Fix compilation on windows

---
 mssql_python/pybind/ddbc_bindings.cpp | 94 ++++++++++++++-------------
 1 file changed, 49 insertions(+), 45 deletions(-)

diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index 07bee848..cad1509e 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -4319,9 +4319,9 @@ int32_t days_from_civil(int y, int m, int d) {
 SQLRETURN FetchArrowBatch_wrap(
     SqlHandlePtr StatementHandle,
     py::list& capsules,
-    ssize_t arrowBatchSize
+    int arrowBatchSize
 ) {
-    ssize_t fetchSize = arrowBatchSize;
+    int fetchSize = arrowBatchSize;
     SQLRETURN ret;
     SQLHSTMT hStmt = StatementHandle->get();
     // Retrieve column count
@@ -4477,7 +4477,6 @@ SQLRETURN FetchArrowBatch_wrap(
                 arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->bitVal.get();
                 break;
             default:
-                std::wstring columnName = colMeta["ColumnName"].cast<std::wstring>();
                 std::ostringstream errorString;
                 errorString << "Unsupported data type for Arrow batch fetch for column - "
                             << columnName.c_str() << ", Type - " << dataType << ", column ID - " << (i + 1);
@@ -4806,9 +4805,9 @@ SQLRETURN FetchArrowBatch_wrap(
                 }
             }
 
-            SQLLEN dataLen = buffers.indicators[idxCol][idxRowSql];
+            SQLLEN indicator = buffers.indicators[idxCol][idxRowSql];
 
-            if (dataLen == SQL_NULL_DATA) {
+            if (indicator == SQL_NULL_DATA) {
                 // Mark as null in validity bitmap
                 size_t bytePos = idxRowArrow / 8;
                 size_t bitPos = idxRowArrow % 8;
@@ -4837,11 +4836,12 @@ SQLRETURN FetchArrowBatch_wrap(
                 nullCounts[idxCol] += 1;
                 continue;
-            } else if (dataLen < 0) {
+            } else if (indicator < 0) {
                 // Negative value is unexpected, log column index, SQL type & raise exception
-                LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", idxCol + 1, dataType, dataLen);
+                LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", idxCol + 1, dataType, indicator);
                 ThrowStdException("Unexpected negative data length.");
             }
+            auto dataLen = static_cast<size_t>(indicator);
 
             switch (dataType) {
                 case SQL_BINARY:
@@ -4883,11 +4883,11 @@ SQLRETURN FetchArrowBatch_wrap(
                     auto target_vec = &arrowColumnProducer->varData;
 #if defined(_WIN32)
                     // Convert wide string
-                    int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL);
+                    int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), NULL, 0, NULL, NULL);
                     while (target_vec->size() < start + dataLenConverted) {
                         target_vec->resize(target_vec->size() * 2);
                     }
-                    WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL);
+                    WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), reinterpret_cast<char*>(&(*target_vec)[start]), dataLenConverted, NULL, NULL);
                     arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted;
 #else
                     // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8
@@ -5066,14 +5066,14 @@ SQLRETURN FetchArrowBatch_wrap(
 
     for (SQLSMALLINT i = 0; i < numCols; i++) {
         *arrowSchemaBatchChildPointers[i] = {
-            .format = arrowSchemaPrivateData[i]->format.get(),
-            .name = arrowSchemaPrivateData[i]->name.get(),
-            .metadata = nullptr,
-            .flags = static_cast<int64_t>(columnNullable[i] ? ARROW_FLAG_NULLABLE : 0),
-            .n_children = 0,
-            .children = nullptr,
-            .dictionary = nullptr,
-            .release = [](ArrowSchema* schema) {
+            arrowSchemaPrivateData[i]->format.get(),
+            arrowSchemaPrivateData[i]->name.get(),
+            nullptr,
+            static_cast<int64_t>(columnNullable[i] ? ARROW_FLAG_NULLABLE : 0),
+            0,
+            nullptr,
+            nullptr,
+            [](ArrowSchema* schema) {
                 assert(schema != nullptr);
                 assert(schema->release != nullptr);
                 assert(schema->private_data != nullptr);
@@ -5081,7 +5081,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 delete schema->private_data;  // Frees format and name
                 schema->release = nullptr;
             },
-            .private_data = arrowSchemaPrivateData[i].release(),
+            arrowSchemaPrivateData[i].release(),
         };
     }
 
@@ -5089,15 +5089,15 @@ SQLRETURN FetchArrowBatch_wrap(
         arrowSchemaBatchChildren[i] = arrowSchemaBatchChildPointers[i].release();
     }
 
-    *arrowSchemaBatch = ArrowSchema{
-        .format = "+s",
-        .name = "",
-        .metadata = nullptr,
-        .flags = 0,
-        .n_children = numCols,
-        .children = arrowSchemaBatchChildren.release(),
-        .dictionary = nullptr,
-        .release = [](ArrowSchema* schema) {
+    *arrowSchemaBatch = {
+        "+s",
+        "",
+        nullptr,
+        0,
+        numCols,
+        arrowSchemaBatchChildren.release(),
+        nullptr,
+        [](ArrowSchema* schema) {
            // format and name are string literals, no need to free
            assert(schema != nullptr);
            assert(schema->release != nullptr);
@@ -5115,7 +5115,7 @@ SQLRETURN FetchArrowBatch_wrap(
             delete[] schema->children;
             schema->release = nullptr;
         },
-        .private_data = nullptr,
+        nullptr,
     };
 
     // Finally, transfer ownership of arrowSchemaBatch and its pointer to pycapsule
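Both fixes in this patch are portability issues rather than logic changes: ssize_t is a POSIX type that MSVC does not define, and the .field = value designated-initializer syntax only became standard C++ in C++20, so MSVC rejects it in earlier language modes even though GCC and Clang accept it as an extension. Positional aggregate initialization works everywhere but must supply every member in declaration order, which is why the previously defaulted dictionary member now appears as an explicit nullptr. A minimal sketch of the two forms (hypothetical struct for illustration):

    #include <cstdint>

    struct Example {            // members in declaration order
        const char* format;
        int64_t flags;
        void* private_data;
    };

    int main() {
        // C++20 (or GCC/Clang extension): named fields, omitted ones default to {}.
        // Example a = {.format = "+s", .private_data = nullptr};

        // Portable pre-C++20: positional, every member spelled out in order.
        Example b = {"+s", 0, nullptr};
        return (b.flags == 0) ? 0 : 1;
    }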
@@ -5152,20 +5152,20 @@ SQLRETURN FetchArrowBatch_wrap(
 
     // No unhandled exceptions until the pycapsule owns the arrowArrayBatch to avoid memory leaks
     for (SQLUSMALLINT col = 0; col < numCols; col++) {
-        auto dataType = dataTypes[col];
         arrowArrayPrivateData[col]->buffers[0] = arrowArrayPrivateData[col]->valid.get();
         arrowArrayPrivateData[col]->buffers[1] = arrowArrayPrivateData[col]->ptrValueBuffer;
         arrowArrayPrivateData[col]->buffers[2] = arrowArrayPrivateData[col]->varData.data();
 
         *arrowArrayBatchChildPointers[col] = {
-            .length = static_cast<int64_t>(idxRowArrow),
-            .null_count = nullCounts[col],
-            .offset = 0,
-            .n_buffers = columnVarLen[col] ? 3 : 2,
-            .n_children = 0,
-            .buffers = (const void**)arrowArrayPrivateData[col]->buffers.data(),
-            .children = nullptr,
-            .release = [](ArrowArray* array) {
+            static_cast<int64_t>(idxRowArrow),
+            nullCounts[col],
+            0,
+            columnVarLen[col] ? 3 : 2,
+            0,
+            (const void**)arrowArrayPrivateData[col]->buffers.data(),
+            nullptr,
+            nullptr,
+            [](ArrowArray* array) {
                 assert(array != nullptr);
                 assert(array->private_data != nullptr);
                 assert(array->release != nullptr);
@@ -5175,7 +5175,7 @@ SQLRETURN FetchArrowBatch_wrap(
                 assert(array->buffers != nullptr);
                 array->release = nullptr;
             },
-            .private_data = arrowArrayPrivateData[col].release(),
+            arrowArrayPrivateData[col].release(),
         };
     }
 
@@ -5183,13 +5183,16 @@ SQLRETURN FetchArrowBatch_wrap(
     for (SQLSMALLINT i = 0; i < numCols; i++) {
         arrowArrayBatchChildren[i] = arrowArrayBatchChildPointers[i].release();
     }
 
-    *arrowArrayBatch = ArrowArray{
-        .length = static_cast<int64_t>(idxRowArrow),
-        .n_buffers = 1,
-        .n_children = numCols,
-        .buffers = arrowArrayBatchBuffers.release(),
-        .children = arrowArrayBatchChildren.release(),
-        .release = [](ArrowArray* array) {
+    *arrowArrayBatch = {
+        static_cast<int64_t>(idxRowArrow),
+        0,
+        0,
+        1,
+        numCols,
+        arrowArrayBatchBuffers.release(),
+        arrowArrayBatchChildren.release(),
+        nullptr,
+        [](ArrowArray* array) {
            assert(array != nullptr);
            assert(array->private_data == nullptr);
            assert(array->release != nullptr);
@@ -5210,6 +5213,7 @@ SQLRETURN FetchArrowBatch_wrap(
             delete[] array->buffers;
             array->release = nullptr;
         },
+        nullptr,
     };
 
     // Finally, transfer ownership of arrowArrayBatch and its pointer to pycapsule
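For context on the handoff that the comment above refers to: the Arrow PyCapsule interface wraps the ArrowSchema and ArrowArray structs in capsules named "arrow_schema" and "arrow_array" (the names are fixed by the Arrow specification), and pyarrow's RecordBatch._import_from_c_capsule, used in cursor.py, moves the structs out and clears their release callbacks, so the capsule destructor only has cleanup work if the batch is never consumed. A hedged sketch of the producer side using the CPython API (the destructor body is an assumption for illustration, not this codebase's actual code):

    #include <Python.h>

    struct ArrowArray;  // as declared by the Arrow C data interface above

    // Runs only if the capsule is garbage-collected before a consumer imports it;
    // a consumer that moves the struct sets its release callback to nullptr first.
    extern "C" void arrow_array_capsule_destructor(PyObject* capsule) {
        void* ptr = PyCapsule_GetPointer(capsule, "arrow_array");
        // ...invoke the struct's release callback and free it here if still owned...
        (void)ptr;
    }

    static PyObject* make_arrow_array_capsule(ArrowArray* array) {
        return PyCapsule_New(array, "arrow_array", arrow_array_capsule_destructor);
    }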