70 changes: 68 additions & 2 deletions examples/query_tags_example.py
@@ -7,11 +7,23 @@
Query Tags are key-value pairs that can be attached to SQL executions and will appear
in the system.query.history table for analytical purposes.

Format: "key1:value1,key2:value2,key3:value3"
There are two ways to set query tags:
1. Session-level: Set in session_configuration (applies to all queries in the session)
2. Per-query level: Pass the query_tags parameter to execute() or execute_async() (applies to a specific query)

Format: Dictionary with string keys and optional string values
Example: {"team": "engineering", "application": "etl", "priority": "high"}

Special cases:
- If a value is None, only the key is included (no colon or value)
- Special characters (:, ,, \\) in values are automatically escaped
- Keys are not escaped (should be controlled identifiers)
"""

print("=== Query Tags Example ===\n")

# Example 1: Session-level query tags (old approach)
print("Example 1: Session-level query tags")
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
@@ -21,10 +33,64 @@
'ansi_mode': False
}
) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
print(f" Result: {result[0]}")

print()

# Example 2: Per-query query tags (new approach)
print("Example 2: Per-query query tags")
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:

with connection.cursor() as cursor:
# Query 1: Tags for a critical ETL job
cursor.execute(
"SELECT 1",
query_tags={"team": "data-eng", "application": "etl", "priority": "high"}
)
result = cursor.fetchone()
print(f" ETL Query Result: {result[0]}")

# Query 2: Tags with None value (key-only tag)
cursor.execute(
"SELECT 2",
query_tags={"team": "analytics", "experimental": None}
)
result = cursor.fetchone()
print(f" Experimental Query Result: {result[0]}")

# Query 3: Tags with special characters (automatically escaped)
cursor.execute(
"SELECT 3",
query_tags={"description": "test:with:colons,and,commas"}
)
result = cursor.fetchone()
print(f" Special Chars Query Result: {result[0]}")

print()

# Example 3: Async execution with query tags
print("Example 3: Async execution with query tags")
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:

with connection.cursor() as cursor:
cursor.execute_async(
"SELECT 4",
query_tags={"team": "data-eng", "mode": "async"}
)
cursor.get_async_execution_result()
result = cursor.fetchone()
print(f" Async Query Result: {result[0]}")

print("\n=== Query Tags Example Complete ===")
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/databricks_client.py
@@ -83,6 +83,7 @@ def execute_command(
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
query_tags: Optional[Dict[str, Optional[str]]] = None,
) -> Union[ResultSet, None]:
"""
Executes a SQL command or query within the specified session.
@@ -102,6 +103,7 @@ def execute_command(
async_op: Whether to execute the command asynchronously
enforce_embedded_schema_correctness: Whether to enforce schema correctness
row_limit: Maximum number of rows in the response.
query_tags: Optional dictionary of query tags to apply for this query only.

Returns:
If async_op is False, returns a ResultSet object containing the
22 changes: 17 additions & 5 deletions src/databricks/sql/backend/thrift_backend.py
@@ -5,7 +5,7 @@
import math
import time
import threading
from typing import List, Optional, Union, Any, TYPE_CHECKING
from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING
from uuid import UUID

from databricks.sql.common.unified_http_client import UnifiedHttpClient
@@ -53,6 +53,7 @@
convert_arrow_based_set_to_arrow_table,
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
serialize_query_tags,
)
from databricks.sql.types import SSLOptions
from databricks.sql.backend.databricks_client import DatabricksClient
@@ -1003,6 +1004,7 @@ def execute_command(
async_op=False,
enforce_embedded_schema_correctness=False,
row_limit: Optional[int] = None,
query_tags: Optional[Dict[str, Optional[str]]] = None,
) -> Union["ResultSet", None]:
thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
@@ -1022,6 +1024,19 @@
# DBR should be changed to use month_day_nano_interval
intervalTypesAsArrow=False,
)

# Build confOverlay with default configs and query_tags
merged_conf_overlay = {
# We want to receive proper Timestamp arrow types.
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
}

# Serialize and add query_tags to confOverlay if provided
if query_tags:
serialized_tags = serialize_query_tags(query_tags)
if serialized_tags:
merged_conf_overlay["query_tags"] = serialized_tags

req = ttypes.TExecuteStatementReq(
sessionHandle=thrift_handle,
statement=operation,
@@ -1036,10 +1051,7 @@
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
# We want to receive proper Timestamp arrow types.
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
},
confOverlay=merged_conf_overlay,
useArrowNativeTypes=spark_arrow_types,
parameters=parameters,
enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,
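A compact sketch of the effective confOverlay that the merge above produces when per-query tags are supplied; the tag values are assumed for illustration:

# Illustrative only: confOverlay sent when query_tags={"team": "data-eng", "mode": "async"}
merged_conf_overlay = {
    "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false",
    "query_tags": "team:data-eng,mode:async",
}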
11 changes: 11 additions & 0 deletions src/databricks/sql/client.py
@@ -1263,6 +1263,7 @@ def execute(
parameters: Optional[TParameterCollection] = None,
enforce_embedded_schema_correctness=False,
input_stream: Optional[BinaryIO] = None,
query_tags: Optional[Dict[str, Optional[str]]] = None,
) -> "Cursor":
"""
Execute a query and wait for execution to complete.
Expand Down Expand Up @@ -1293,6 +1294,10 @@ def execute(
Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo'
being sent to the server

:param query_tags: Optional dictionary of query tags to apply for this query only.
Tags are key-value pairs that can be used to identify and categorize queries.
Example: {"team": "data-eng", "application": "etl"}

:returns self
"""

@@ -1333,6 +1338,7 @@
async_op=False,
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
row_limit=self.row_limit,
query_tags=query_tags,
)

if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -1349,13 +1355,17 @@ def execute_async(
operation: str,
parameters: Optional[TParameterCollection] = None,
enforce_embedded_schema_correctness=False,
query_tags: Optional[Dict[str, Optional[str]]] = None,
) -> "Cursor":
"""

Execute a query and do not wait for it to complete and just move ahead

:param operation:
:param parameters:
:param query_tags: Optional dictionary of query tags to apply for this query only.
Tags are key-value pairs that can be used to identify and categorize queries.
Example: {"team": "data-eng", "application": "etl"}
:return:
"""

@@ -1392,6 +1402,7 @@ def execute_async(
async_op=True,
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
row_limit=self.row_limit,
query_tags=query_tags,
)

return self
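As a usage sketch, the new query_tags argument composes with the existing parameters argument on execute(), following the native named-parameter style referenced in the docstring above; the table name and tag values below are hypothetical:

# Illustrative only: per-query tags alongside native named parameters
import os
from databricks import sql

with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:
    with connection.cursor() as cursor:
        cursor.execute(
            "SELECT * FROM sales.events WHERE region = :region LIMIT 10",
            parameters={"region": "us-west"},
            query_tags={"team": "data-eng", "application": "dashboard"},
        )
        rows = cursor.fetchall()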
40 changes: 40 additions & 0 deletions src/databricks/sql/utils.py
@@ -898,6 +898,46 @@ def concat_table_chunks(
return pyarrow.concat_tables(table_chunks)


def serialize_query_tags(query_tags: Optional[Dict[str, Optional[str]]]) -> Optional[str]:
"""
Serialize query_tags dictionary to a string format.

Format: "key1:value1,key2:value2"
Special cases:
- If value is None, omit the colon and value (e.g., "key1:value1,key2,key3:value3")
- Escape special characters (:, ,, \\) in values with a leading backslash
- Keys are not escaped (assumed to be controlled identifiers)

Args:
query_tags: Dictionary of query tags where keys are strings and values are optional strings

Returns:
Serialized string or None if query_tags is None or empty
"""
if not query_tags:
return None

def escape_value(value: str) -> str:
"""Escape special characters in tag values."""
# Escape backslash first to avoid double-escaping
value = value.replace("\\", "\\\\")
# Escape colon and comma
value = value.replace(":", "\\:")
value = value.replace(",", "\\,")
return value

serialized_parts = []
for key, value in query_tags.items():
if value is None:
# No colon or value when value is None
serialized_parts.append(key)
else:
escaped_value = escape_value(value)
serialized_parts.append(f"{key}:{escaped_value}")

return ",".join(serialized_parts)


def build_client_context(server_hostname: str, version: str, **kwargs):
"""Build ClientContext for HTTP client configuration."""
from databricks.sql.auth.common import ClientContext
63 changes: 63 additions & 0 deletions tests/unit/test_util.py
@@ -6,6 +6,7 @@
convert_to_assigned_datatypes_in_column_table,
ColumnTable,
concat_table_chunks,
serialize_query_tags,
)

try:
@@ -161,3 +162,65 @@ def test_concat_table_chunks__incorrect_column_names_error(self):

with pytest.raises(ValueError):
concat_table_chunks([column_table1, column_table2])

def test_serialize_query_tags_basic(self):
"""Test basic query tags serialization"""
query_tags = {"team": "data-eng", "application": "etl"}
result = serialize_query_tags(query_tags)
assert result == "team:data-eng,application:etl"

def test_serialize_query_tags_with_none_value(self):
"""Test query tags with None value (should omit colon and value)"""
query_tags = {"key1": "value1", "key2": None, "key3": "value3"}
result = serialize_query_tags(query_tags)
assert result == "key1:value1,key2,key3:value3"

def test_serialize_query_tags_with_special_chars(self):
"""Test query tags with special characters (colon, comma, backslash)"""
query_tags = {
"key1": "value:with:colons",
"key2": "value,with,commas",
"key3": "value\\with\\backslashes",
}
result = serialize_query_tags(query_tags)
assert (
result
== "key1:value\\:with\\:colons,key2:value\\,with\\,commas,key3:value\\\\with\\\\backslashes"
)

def test_serialize_query_tags_with_mixed_special_chars(self):
"""Test query tags with mixed special characters"""
query_tags = {"key1": "a:b,c\\d"}
result = serialize_query_tags(query_tags)
assert result == "key1:a\\:b\\,c\\\\d"

def test_serialize_query_tags_empty_dict(self):
"""Test serialization with empty dictionary"""
query_tags = {}
result = serialize_query_tags(query_tags)
assert result is None

def test_serialize_query_tags_none(self):
"""Test serialization with None input"""
result = serialize_query_tags(None)
assert result is None

def test_serialize_query_tags_with_special_chars_in_key(self):
"""Test query tags with special characters in keys (keys are not escaped)"""
query_tags = {
"key:with:colons": "value1",
"key,with,commas": "value2",
"key\\with\\backslashes": "value3",
}
result = serialize_query_tags(query_tags)
# Keys are not escaped, only values are
assert (
result
== "key:with:colons:value1,key,with,commas:value2,key\\with\\backslashes:value3"
)

def test_serialize_query_tags_all_none_values(self):
"""Test query tags where all values are None"""
query_tags = {"key1": None, "key2": None, "key3": None}
result = serialize_query_tags(query_tags)
assert result == "key1,key2,key3"