diff --git a/py/src/braintrust/wrappers/cassettes/test_embed_content.yaml b/py/src/braintrust/wrappers/cassettes/test_embed_content.yaml new file mode 100644 index 00000000..a3fd3a8d --- /dev/null +++ b/py/src/braintrust/wrappers/cassettes/test_embed_content.yaml @@ -0,0 +1,127 @@ +interactions: +- request: + body: '{"requests": [{"content": {"parts": [{"text": "This is a test"}], "role": + "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": 32, "model": + "models/text-embedding-004"}, {"content": {"parts": [{"text": "This is another + test"}], "role": "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": + 32, "model": "models/text-embedding-004"}]}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '360' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents + response: + body: + string: !!binary | + H4sIAAAAAAAC/01Pu4oCQRDM9yuKidVbYSMz8TgQ7tTAi2X2ptWB2WmZ7l0E8d+vdQ1MGupBVdet + AhyVwsUtcDNg8I8DGWrqZjISHYn404NznWlJPpSuOqWupRBiPk3rukEUZFYcuc/BbsFyt8ZARSJn + DPOW1E9g9Msn/eXCRWn0PqNWnJWyzrDyKeE7iv4826AMIYKeCclY8BF+8DH5NhHGj+Ct1QyxvCV3 + pGcOMnOvHaJee3nM2Gz3h6/t7+bTmXKv7tU/iyzygQgBAAA= + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 20 Mar 2026 17:24:34 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=59 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + status: + code: 404 + message: Not Found +- request: + body: '{"requests": [{"content": {"parts": [{"text": "This is a test"}], "role": + "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": 32, "model": + "models/gemini-embedding-001"}, {"content": {"parts": [{"text": "This is another + test"}], "role": "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": + 32, "model": "models/gemini-embedding-001"}]}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '364' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents + response: + body: + string: !!binary | + H4sIAAAAAAAC/41Uy24TQRC8+ytWOduo3w9+BeUAioWQgAuCS5R/T60VBU8vB/Zgr2p6q7urq+f5 + tG0P1x9frk9P335+/fXwcfsEZNueb784+/P5++/rX3x/LvSBRFU8us7vKECyiLBSX1Bmt7TM80LQ + ns0sK0hsGZLeKy1V7RwzWNhdsgcxk0iZDQpnAY2OWCRMUcPBwE2lcDiYucHCszh1d/M+UDBlcOp9 + NAd1NK+qsZjIAt6+DyWweqzZsivUD0qYq5MNNMMr6p7h8jakVAmdzWW5Ua6ldahGHdKVOEXaWhoq + qDIfsaUdaG8JFY22nDZRQ3Oso4Bs+sfkLahSpmT4XEhi0LKxar9hj7f/l/P/mRwuVc61TmKHDXhN + EsisQ+hObqqpvieL1zQtNTub2kgV7uGH1WHo3DF96GIKnuMySBYNXoSW6EFWeEV7dSxqaIFrLSYx + U4+5QphWaL1KYwJpDqsEBEMZZZk2fDENy74Xa3O/oG33fWOxd1C4lq4Xmi6E33bTrS4krIHbFAGj + RG09zd1Omof7w1xcRsfYLSjs42ps3a28rhdGE5bHYtEy9ZzkHpxsw9x4nHM192l/ezm9AuHl5j/b + BQAA + headers: + Alt-Svc: + - h3=":443"; 
ma=2592000,h3-29=":443"; ma=2592000 + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 20 Mar 2026 17:28:26 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=145 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/wrappers/cassettes/test_embed_content_async.yaml b/py/src/braintrust/wrappers/cassettes/test_embed_content_async.yaml new file mode 100644 index 00000000..c1818fc0 --- /dev/null +++ b/py/src/braintrust/wrappers/cassettes/test_embed_content_async.yaml @@ -0,0 +1,110 @@ +interactions: +- request: + body: '{"requests": [{"content": {"parts": [{"text": "This is a test"}], "role": + "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": 32, "model": + "models/text-embedding-004"}, {"content": {"parts": [{"text": "This is another + test"}], "role": "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": + 32, "model": "models/text-embedding-004"}]}' + headers: + Content-Type: + - application/json + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents + response: + body: + string: "{\n \"error\": {\n \"code\": 404,\n \"message\": \"models/text-embedding-004 + is not found for API version v1beta, or is not supported for embedContent. + Call ListModels to see the list of available models and their supported methods.\",\n + \ \"status\": \"NOT_FOUND\"\n }\n}\n" + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 20 Mar 2026 17:24:34 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=61 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + status: + code: 404 + message: Not Found +- request: + body: '{"requests": [{"content": {"parts": [{"text": "This is a test"}], "role": + "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": 32, "model": + "models/gemini-embedding-001"}, {"content": {"parts": [{"text": "This is another + test"}], "role": "user"}, "taskType": "RETRIEVAL_DOCUMENT", "outputDimensionality": + 32, "model": "models/gemini-embedding-001"}]}' + headers: + Content-Type: + - application/json + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents + response: + body: + string: "{\n \"embeddings\": [\n {\n \"values\": [\n -0.023325698,\n + \ 0.0046664835,\n 0.011547477,\n -0.09579112,\n -0.0014762759,\n + \ 0.0008815472,\n -0.0021552797,\n -0.010228449,\n 0.0051200837,\n + \ -0.00017234083,\n -0.004328001,\n -0.011912019,\n 0.00035554593,\n + \ -0.0041076173,\n 0.16096918,\n 0.012422918,\n -0.0063045956,\n + \ 0.007986352,\n -0.002453504,\n -0.0076586856,\n -0.0046673263,\n + \ -0.011785407,\n 0.019633682,\n -0.0028250674,\n 0.003508845,\n + \ -0.008396229,\n 0.023694735,\n 0.013479813,\n 0.019790472,\n + \ -0.0024608728,\n -0.009812026,\n 0.013141339\n ]\n + \ },\n 
{\n \"values\": [\n -0.020953175,\n 0.00159121,\n + \ 0.01681236,\n -0.09719086,\n -0.0057125897,\n -0.0091514345,\n + \ 0.0065565477,\n -0.010889619,\n 0.005243009,\n -0.0102227805,\n + \ 0.0024382372,\n -0.008563973,\n 0.0069260946,\n -0.010109229,\n + \ 0.15993133,\n 0.014212301,\n -0.013013395,\n 0.0043994756,\n + \ -0.015382343,\n -0.009089905,\n 6.0088325e-05,\n -0.006942369,\n + \ 0.020732542,\n -0.0068101394,\n 0.0039503737,\n -0.004525233,\n + \ 0.032600258,\n 0.009397907,\n 0.022764705,\n -0.006015099,\n + \ -0.012276714,\n 0.013333517\n ]\n }\n ]\n}\n" + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 20 Mar 2026 17:28:26 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=142 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/wrappers/google_genai/__init__.py b/py/src/braintrust/wrappers/google_genai/__init__.py index 61df30ab..3bdae565 100644 --- a/py/src/braintrust/wrappers/google_genai/__init__.py +++ b/py/src/braintrust/wrappers/google_genai/__init__.py @@ -1,9 +1,17 @@ import logging import time -from collections.abc import Iterable -from typing import Any +from collections.abc import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any from braintrust.bt_json import bt_safe_deep_copy + + +if TYPE_CHECKING: + from google.genai.types import ( + EmbedContentResponse, + GenerateContentResponse, + GenerateContentResponseUsageMetadata, + ) from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span from braintrust.span_types import SpanTypeAttribute from wrapt import wrap_function_wrapper @@ -54,47 +62,40 @@ def wrap_models(Models: Any): return Models def wrap_generate_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"]) - - input = _serialize_input(instance._api_client, input) - - clean_kwargs["model"] = input["model"] - - start = time.time() - with start_span( - name="generate_content", type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs - ) as span: - result = wrapped(*args, **kwargs) - metrics = _extract_generate_content_metrics(result, start) - span.log(output=result, metrics=metrics) - return result + return _run_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_gc_process_result, + ) wrap_function_wrapper(Models, "_generate_content", wrap_generate_content) def wrap_generate_content_stream(wrapped: Any, instance: Any, args: Any, kwargs: Any): - input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"]) - - input = _serialize_input(instance._api_client, input) - - clean_kwargs["model"] = input["model"] + return _run_stream_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content_stream", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_generate_content_chunks, + ) - start = time.time() - first_token_time = None - with start_span( - name="generate_content_stream", type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs - ) as span: - chunks = [] - for chunk in wrapped(*args, **kwargs): - if first_token_time is None: - first_token_time = 
time.time() - chunks.append(chunk) - yield chunk + wrap_function_wrapper(Models, "generate_content_stream", wrap_generate_content_stream) - aggregated, metrics = _aggregate_generate_content_chunks(chunks, start, first_token_time) - span.log(output=aggregated, metrics=metrics) - return aggregated + def wrap_embed_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return _run_traced_call( + instance._api_client, + args, + kwargs, + name="embed_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_embed_process_result, + ) - wrap_function_wrapper(Models, "generate_content_stream", wrap_generate_content_stream) + wrap_function_wrapper(Models, "embed_content", wrap_embed_content) mark_patched(Models) return Models @@ -105,49 +106,40 @@ def wrap_async_models(AsyncModels: Any): return AsyncModels async def wrap_generate_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"]) - - input = _serialize_input(instance._api_client, input) - - clean_kwargs["model"] = input["model"] - - start = time.time() - with start_span( - name="generate_content", type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs - ) as span: - result = await wrapped(*args, **kwargs) - metrics = _extract_generate_content_metrics(result, start) - span.log(output=result, metrics=metrics) - return result + return await _run_async_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_gc_process_result, + ) wrap_function_wrapper(AsyncModels, "generate_content", wrap_generate_content) async def wrap_generate_content_stream(wrapped: Any, instance: Any, args: Any, kwargs: Any): - input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"]) - - input = _serialize_input(instance._api_client, input) - - clean_kwargs["model"] = input["model"] - - async def stream_generator(): - start = time.time() - first_token_time = None - with start_span( - name="generate_content_stream", type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs - ) as span: - chunks = [] - async for chunk in await wrapped(*args, **kwargs): - if first_token_time is None: - first_token_time = time.time() - chunks.append(chunk) - yield chunk + return _run_async_stream_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content_stream", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_generate_content_chunks, + ) - aggregated, metrics = _aggregate_generate_content_chunks(chunks, start, first_token_time) - span.log(output=aggregated, metrics=metrics) + wrap_function_wrapper(AsyncModels, "generate_content_stream", wrap_generate_content_stream) - return stream_generator() + async def wrap_embed_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return await _run_async_traced_call( + instance._api_client, + args, + kwargs, + name="embed_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_embed_process_result, + ) - wrap_function_wrapper(AsyncModels, "generate_content_stream", wrap_generate_content_stream) + wrap_function_wrapper(AsyncModels, "embed_content", wrap_embed_content) mark_patched(AsyncModels) return AsyncModels @@ -171,6 +163,113 @@ def _serialize_input(api_client: Any, input: dict[str, Any]): return input +def _gc_process_result(result: "GenerateContentResponse", start: float) -> tuple[Any, dict[str, Any]]: + return result, 
_extract_generate_content_metrics(result, start) + + +def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple[Any, dict[str, Any]]: + return _extract_embed_content_output(result), _extract_embed_content_metrics(result, start) + + +def _prepare_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + """Split call args into serialized span input and metadata kwargs, always recording the model.""" + input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"], ["contents", "config"]) + input = _serialize_input(api_client, input) + clean_kwargs["model"] = input["model"] + return input, clean_kwargs + + +def _run_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Any], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], +): + """Run a non-streaming sync call inside an LLM span, logging its output and metrics.""" + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + result = invoke() + output, metrics = process_result(result, start) + span.log(output=output, metrics=metrics) + return result + + +async def _run_async_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Awaitable[Any]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], +): + """Run a non-streaming async call inside an LLM span, logging its output and metrics.""" + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + result = await invoke() + output, metrics = process_result(result, start) + span.log(output=output, metrics=metrics) + return result + + +def _run_stream_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Any], + aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], +): + """Trace a sync streaming call, yielding chunks and logging the aggregated output and metrics.""" + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + first_token_time = None + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + chunks = [] + for chunk in invoke(): + if first_token_time is None: + first_token_time = time.time() + chunks.append(chunk) + yield chunk + + output, metrics = aggregate(chunks, start, first_token_time) + span.log(output=output, metrics=metrics) + return output + + +def _run_async_stream_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Awaitable[Any]], + aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], +): + """Return an async generator that traces the streamed call and logs the aggregated output and metrics.""" + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + async def stream_generator(): + start = time.time() + first_token_time = None + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + chunks = [] + async for chunk in await invoke(): + if first_token_time is None: + first_token_time = time.time() + chunks.append(chunk) + yield chunk + + output, metrics = aggregate(chunks, start, first_token_time) + span.log(output=output, metrics=metrics) + + return stream_generator() + + def _serialize_contents(contents: Any) -> Any: """Serialize contents, converting binary data to base64-encoded data URLs.""" if contents is None: @@ -259,11 +358,29 @@ def mark_patched(obj: Any): return setattr(obj, "_braintrust_patched", True) -def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: Iterable[str]): - return {k: args[i] if args else 
kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys) +def get_args_kwargs( + args: list[Any], kwargs: dict[str, Any], keys: Iterable[str], omit_keys: Iterable[str] | None = None +): + """Map leading positional args onto keys, falling back to kwargs, and return the remaining kwargs.""" + return {k: args[i] if i < len(args) else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, omit_keys or keys) + + +def _extract_usage_metadata_metrics( + usage_metadata: "GenerateContentResponseUsageMetadata", metrics: dict[str, Any] +) -> None: + """Mutate metrics in-place with token counts from a usage_metadata object.""" + if hasattr(usage_metadata, "prompt_token_count"): + metrics["prompt_tokens"] = usage_metadata.prompt_token_count + if hasattr(usage_metadata, "candidates_token_count"): + metrics["completion_tokens"] = usage_metadata.candidates_token_count + if hasattr(usage_metadata, "total_token_count"): + metrics["tokens"] = usage_metadata.total_token_count + if hasattr(usage_metadata, "cached_content_token_count"): + metrics["prompt_cached_tokens"] = usage_metadata.cached_content_token_count + if hasattr(usage_metadata, "thoughts_token_count"): + metrics["completion_reasoning_tokens"] = usage_metadata.thoughts_token_count -def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, Any]: +def _extract_generate_content_metrics(response: "GenerateContentResponse", start: float) -> dict[str, Any]: """Extract metrics from a non-streaming generate_content response.""" end_time = time.time() metrics = dict( @@ -272,37 +389,55 @@ def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, duration=end_time - start, ) - # Extract usage metadata if available if hasattr(response, "usage_metadata") and response.usage_metadata: - usage_metadata = response.usage_metadata - - # Extract token metrics - if hasattr(usage_metadata, "prompt_token_count"): - metrics["prompt_tokens"] = usage_metadata.prompt_token_count - if hasattr(usage_metadata, "candidates_token_count"): - metrics["completion_tokens"] = usage_metadata.candidates_token_count - if hasattr(usage_metadata, "total_token_count"): - metrics["tokens"] = usage_metadata.total_token_count - if hasattr(usage_metadata, "cached_content_token_count"): - metrics["prompt_cached_tokens"] = usage_metadata.cached_content_token_count - - # Extract additional metrics for thinking/reasoning tokens - if hasattr(usage_metadata, "thoughts_token_count"): - metrics["completion_reasoning_tokens"] = usage_metadata.thoughts_token_count - - # Extract tool use prompt tokens if available - if hasattr(usage_metadata, "tool_use_prompt_token_count"): - # Add to prompt_tokens if not already counted - tool_tokens = usage_metadata.tool_use_prompt_token_count - if tool_tokens and "prompt_tokens" in metrics: - # Tool tokens are typically part of prompt tokens, but track separately if needed - pass + _extract_usage_metadata_metrics(response.usage_metadata, metrics) return clean(dict(metrics)) +def _extract_embed_content_output(response: "EmbedContentResponse") -> dict[str, Any]: + """Summarize an embed_content response by embedding dimensionality and count.""" + embeddings = getattr(response, "embeddings", None) or [] + first_embedding = embeddings[0] if embeddings else None + first_values = getattr(first_embedding, "values", None) or [] + + return clean( + { + "embedding_length": len(first_values) if first_values else None, + "embeddings_count": len(embeddings) if embeddings else None, + } + ) + + +def _extract_embed_content_metrics(response: "EmbedContentResponse", start: float) -> dict[str, Any]: + """Extract timing, token, and billable-character metrics from an embed_content response.""" + end_time = time.time() + metrics = dict( + start=start, + end=end_time, + duration=end_time - start, + ) + + embeddings = 
getattr(response, "embeddings", None) or [] + token_counts = [] + for embedding in embeddings: + statistics = getattr(embedding, "statistics", None) + token_count = getattr(statistics, "token_count", None) + if token_count is not None: + token_counts.append(token_count) + + if token_counts: + metrics["prompt_tokens"] = sum(token_counts) + metrics["tokens"] = metrics["prompt_tokens"] + + metadata = getattr(response, "metadata", None) + billable_character_count = getattr(metadata, "billable_character_count", None) + if billable_character_count is not None: + metrics["billable_characters"] = billable_character_count + + return clean(metrics) + + def _aggregate_generate_content_chunks( - chunks: list[Any], start: float, first_token_time: float | None = None + chunks: "list[GenerateContentResponse]", start: float, first_token_time: float | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: """Aggregate streaming chunks into a single response with metrics.""" end_time = time.time() @@ -383,28 +518,7 @@ def _aggregate_generate_content_chunks( # Add usage metadata if usage_metadata: aggregated["usage_metadata"] = usage_metadata - - # Extract token metrics - if hasattr(usage_metadata, "prompt_token_count"): - metrics["prompt_tokens"] = usage_metadata.prompt_token_count - if hasattr(usage_metadata, "candidates_token_count"): - metrics["completion_tokens"] = usage_metadata.candidates_token_count - if hasattr(usage_metadata, "total_token_count"): - metrics["tokens"] = usage_metadata.total_token_count - if hasattr(usage_metadata, "cached_content_token_count"): - metrics["prompt_cached_tokens"] = usage_metadata.cached_content_token_count - - # Extract additional metrics for thinking/reasoning tokens - if hasattr(usage_metadata, "thoughts_token_count"): - metrics["completion_reasoning_tokens"] = usage_metadata.thoughts_token_count - - # Extract tool use prompt tokens if available - if hasattr(usage_metadata, "tool_use_prompt_token_count"): - # Add to prompt_tokens if not already counted - tool_tokens = usage_metadata.tool_use_prompt_token_count - if tool_tokens and "prompt_tokens" in metrics: - # Tool tokens are typically part of prompt tokens, but track separately if needed - pass + _extract_usage_metadata_metrics(usage_metadata, metrics) # Add convenience text property if text: diff --git a/py/src/braintrust/wrappers/test_google_genai.py b/py/src/braintrust/wrappers/test_google_genai.py index 51b3b090..73a31e71 100644 --- a/py/src/braintrust/wrappers/test_google_genai.py +++ b/py/src/braintrust/wrappers/test_google_genai.py @@ -13,6 +13,7 @@ PROJECT_NAME = "test-genai-app" MODEL = "gemini-2.0-flash-001" +EMBEDDING_MODEL = "gemini-embedding-001" FIXTURES_DIR = Path(__file__).parent.parent.parent.parent.parent / "internal/golden/fixtures" @@ -61,6 +62,14 @@ def _assert_metrics_are_valid(metrics, start=None, end=None): assert metrics["start"] <= metrics["end"] +def _assert_timing_metrics_are_valid(metrics, start=None, end=None): + assert metrics["duration"] >= 0 + if start and end: + assert start <= metrics["start"] <= metrics["end"] <= end + else: + assert metrics["start"] <= metrics["end"] + + # Test 1: Basic Completion (Sync) @pytest.mark.vcr @pytest.mark.parametrize( @@ -164,6 +173,71 @@ async def test_basic_completion_async(memory_logger, mode): _assert_metrics_are_valid(span["metrics"], start, end) +@pytest.mark.vcr +def test_embed_content(memory_logger): + assert not memory_logger.pop() + + client = Client() + start = time.time() + response = client.models.embed_content( + 
model=EMBEDDING_MODEL, + contents=["This is a test", "This is another test"], + config=types.EmbedContentConfig( + task_type="RETRIEVAL_DOCUMENT", + output_dimensionality=32, + ), + ) + end = time.time() + + assert response.embeddings + assert len(response.embeddings) == 2 + assert response.embeddings[0].values + assert len(response.embeddings[0].values) == 32 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == EMBEDDING_MODEL + assert "RETRIEVAL_DOCUMENT" in str(span["input"]) + assert "This is a test" in str(span["input"]) + assert span["output"]["embedding_length"] == 32 + assert span["output"]["embeddings_count"] == 2 + _assert_timing_metrics_are_valid(span["metrics"], start, end) + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_embed_content_async(memory_logger): + assert not memory_logger.pop() + + client = Client() + start = time.time() + response = await client.aio.models.embed_content( + model=EMBEDDING_MODEL, + contents=["This is a test", "This is another test"], + config=types.EmbedContentConfig( + task_type="RETRIEVAL_DOCUMENT", + output_dimensionality=32, + ), + ) + end = time.time() + + assert response.embeddings + assert len(response.embeddings) == 2 + assert response.embeddings[0].values + assert len(response.embeddings[0].values) == 32 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == EMBEDDING_MODEL + assert "RETRIEVAL_DOCUMENT" in str(span["input"]) + assert "This is a test" in str(span["input"]) + assert span["output"]["embedding_length"] == 32 + assert span["output"]["embeddings_count"] == 2 + _assert_timing_metrics_are_valid(span["metrics"], start, end) + + # Test 2: Mixed Content (Sync) @pytest.mark.skip @pytest.mark.vcr