From 11d3bd7b62645f9aba023a303e40777c45d8ad50 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Mon, 23 Mar 2026 13:49:46 -0700 Subject: [PATCH] feat: Introduce new integrations API and convert anthropic --- .agents/skills/sdk-integrations/SKILL.md | 204 ++++++ AGENTS.md | 1 + py/noxfile.py | 10 +- py/setup.py | 1 + py/src/braintrust/__init__.py | 7 +- py/src/braintrust/auto.py | 66 +- py/src/braintrust/integrations/__init__.py | 5 + .../integrations/anthropic/__init__.py | 23 + .../anthropic/_utils.py} | 15 +- ...cSpans.test_setup_creates_async_spans.yaml | 88 +++ ...onSetupSpans.test_setup_creates_spans.yaml | 88 +++ ...t_patch_anthropic_async_creates_spans.yaml | 0 ...ns.test_patch_anthropic_creates_spans.yaml | 0 ..._anthropic_beta_messages_create_async.yaml | 0 ...t_anthropic_beta_messages_stream_sync.yaml | 0 ...thropic_beta_messages_streaming_async.yaml | 0 .../test_anthropic_beta_messages_sync.yaml | 0 .../test_anthropic_client_error.yaml | 0 .../test_anthropic_messages_create_async.yaml | 0 ...pic_messages_create_async_stream_true.yaml | 0 ...anthropic_messages_create_stream_true.yaml | 0 ...nthropic_messages_model_params_inputs.yaml | 0 ...test_anthropic_messages_stream_errors.yaml | 0 ...st_anthropic_messages_streaming_async.yaml | 0 ...est_anthropic_messages_streaming_sync.yaml | 0 .../test_anthropic_messages_sync.yaml | 0 ...thropic_messages_system_prompt_inputs.yaml | 0 .../integrations/anthropic/integration.py | 12 + .../integrations/anthropic/patchers.py | 15 + .../integrations/anthropic/test_anthropic.py | 630 +++++++++++++++++ .../integrations/anthropic/tracing.py | 367 ++++++++++ .../auto_test_scripts/test_auto_agno.py | 0 .../auto_test_scripts/test_auto_anthropic.py | 11 +- .../test_auto_anthropic_patch_config.py | 20 + .../test_auto_claude_agent_sdk.py | 0 .../auto_test_scripts/test_auto_dspy.py | 0 .../test_auto_google_genai.py | 0 .../auto_test_scripts/test_auto_litellm.py | 0 .../auto_test_scripts/test_auto_openai.py | 0 .../test_auto_pydantic_ai.py | 0 .../test_patch_litellm_aresponses.py | 0 .../test_patch_litellm_responses.py | 0 py/src/braintrust/integrations/base.py | 211 ++++++ .../integrations/test_versioning.py | 34 + py/src/braintrust/integrations/versioning.py | 55 ++ py/src/braintrust/wrappers/anthropic.py | 429 +----------- .../wrappers/claude_agent_sdk/_wrapper.py | 2 +- py/src/braintrust/wrappers/test_anthropic.py | 654 ++---------------- py/src/braintrust/wrappers/test_utils.py | 4 +- 49 files changed, 1892 insertions(+), 1060 deletions(-) create mode 100644 .agents/skills/sdk-integrations/SKILL.md create mode 100644 py/src/braintrust/integrations/__init__.py create mode 100644 py/src/braintrust/integrations/anthropic/__init__.py rename py/src/braintrust/{wrappers/_anthropic_utils.py => integrations/anthropic/_utils.py} (84%) create mode 100644 py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupAsyncSpans.test_setup_creates_async_spans.yaml create mode 100644 py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupSpans.test_setup_creates_spans.yaml rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/TestPatchAnthropicAsyncSpans.test_patch_anthropic_async_creates_spans.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/TestPatchAnthropicSpans.test_patch_anthropic_creates_spans.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_beta_messages_create_async.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_beta_messages_stream_sync.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_beta_messages_streaming_async.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_beta_messages_sync.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_client_error.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_create_async.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_create_async_stream_true.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_create_stream_true.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_model_params_inputs.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_stream_errors.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_streaming_async.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_streaming_sync.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_sync.yaml (100%) rename py/src/braintrust/{wrappers => integrations/anthropic}/cassettes/test_anthropic_messages_system_prompt_inputs.yaml (100%) create mode 100644 py/src/braintrust/integrations/anthropic/integration.py create mode 100644 py/src/braintrust/integrations/anthropic/patchers.py create mode 100644 py/src/braintrust/integrations/anthropic/test_anthropic.py create mode 100644 py/src/braintrust/integrations/anthropic/tracing.py rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_agno.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_anthropic.py (60%) create mode 100644 py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_claude_agent_sdk.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_dspy.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_google_genai.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_litellm.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_openai.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_auto_pydantic_ai.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_patch_litellm_aresponses.py (100%) rename py/src/braintrust/{wrappers => integrations}/auto_test_scripts/test_patch_litellm_responses.py (100%) create mode 100644 py/src/braintrust/integrations/base.py create mode 100644 py/src/braintrust/integrations/test_versioning.py create mode 100644 py/src/braintrust/integrations/versioning.py diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md new file mode 100644 index 00000000..b7b3aaa4 --- /dev/null +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -0,0 +1,204 @@ +--- +name: sdk-integrations +description: Create or update a Braintrust Python SDK integration using the integrations API. Use when asked to add an integration, update an existing integration, add or update patchers, update auto_instrument, add integration tests, or work in py/src/braintrust/integrations/. +--- + +# SDK Integrations + +SDK integrations define how Braintrust discovers a provider, patches it safely, and keeps provider-specific tracing local to that integration. Read the existing integration closest to your task before writing a new one. If there is no closer example, `py/src/braintrust/integrations/anthropic/` is a useful reference implementation. + +## Workflow + +1. Read the shared integration primitives and the closest provider example. +2. Choose the task shape: new provider, existing provider update, or `auto_instrument()` update. +3. Implement the smallest integration, patcher, tracing, and export changes needed. +4. Add or update VCR-backed integration tests and only re-record cassettes when behavior changed intentionally. +5. Run the narrowest provider session first, then expand to shared validation only if the change touched shared code. + +## Commands + +```bash +cd py && nox -s "test_(latest)" +cd py && nox -s "test_(latest)" -- -k "test_name" +cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" +cd py && make test-core +cd py && make lint +``` + +## Creating or Updating an Integration + +### 1. Read the nearest existing implementation + +Always inspect these first: + +- `py/src/braintrust/integrations/base.py` +- `py/src/braintrust/integrations/runtime.py` +- `py/src/braintrust/integrations/versioning.py` +- `py/src/braintrust/integrations/config.py` + +Relevant example implementation: + +- `py/src/braintrust/integrations/anthropic/` + +Read these additional files only when the task needs them: + +- changing `auto_instrument()`: `py/src/braintrust/auto.py` and `py/src/braintrust/auto_test_scripts/test_auto_anthropic_patch_config.py` +- adding or updating VCR tests: `py/src/braintrust/conftest.py` and `py/src/braintrust/integrations/anthropic/test_anthropic.py` + +Then choose the path that matches the task: + +- new provider: create `py/src/braintrust/integrations//` +- existing provider: read the provider package first and change only the affected patchers, tracing, tests, or exports +- `auto_instrument()` only: keep the integration package unchanged unless the option shape or patcher surface also changed + +### 2. Create or extend the integration module + +For a new provider, create a package under `py/src/braintrust/integrations//`. + +For an existing provider, keep the module layout unless the current structure is actively causing problems. + +Typical files: + +- `__init__.py`: public exports for the integration type and any public helpers +- `integration.py`: the `BaseIntegration` subclass, patcher registration, and high-level orchestration +- `patchers.py`: one patcher per patch target, with version gating and existence checks close to the patch +- `tracing.py`: provider-specific span creation, metadata extraction, stream handling, and output normalization +- `test_.py`: integration tests for `wrap(...)`, `setup()`, sync/async behavior, streaming, and error handling +- `cassettes/`: recorded provider traffic for VCR-backed integration tests when the provider uses HTTP + +### 3. Define the integration class + +Implement a `BaseIntegration` subclass in `integration.py`. + +Set: + +- `name` +- `import_names` +- `min_version` and `max_version` only when needed +- `patchers` + +Keep the class focused on orchestration. Provider-specific tracing logic should stay in `tracing.py`. + +### 4. Add one patcher per coherent patch target + +Put patchers in `patchers.py`. + +Use `FunctionWrapperPatcher` when patching a single import path with `wrapt.wrap_function_wrapper`. Good examples: + +- constructor patchers like `ProviderClient.__init__` +- single API surfaces like `client.responses.create` +- one sync and one async constructor patcher instead of one patcher doing both + +Keep patchers narrow. If you need to patch multiple unrelated targets, create multiple patchers rather than one large patcher. + +Patchers are responsible for: + +- stable patcher ids via `name` +- optional version gating +- existence checks +- idempotence through the base patcher marker + +### 5. Keep tracing provider-local + +Put span creation, metadata extraction, stream aggregation, error logging, and output normalization in `tracing.py`. + +This layer should: + +- preserve provider behavior +- support sync, async, and streaming paths as needed +- avoid raising from tracing-only code when that would break the provider call + +If the provider has complex streaming internals, keep that logic local instead of forcing it into shared abstractions. + +### 6. Wire public exports + +Update public exports only as needed: + +- `py/src/braintrust/integrations/__init__.py` +- `py/src/braintrust/__init__.py` + +### 7. Update auto_instrument only if this integration should be auto-patched + +If the provider belongs in `braintrust.auto.auto_instrument()`, add a branch in `py/src/braintrust/auto.py`. + +Match the current pattern: + +- plain `bool` options for simple on/off integrations +- `IntegrationPatchConfig` only when users need patcher-level selection + +## Tests + +Keep integration tests with the integration package. + +Provider behavior tests should use `@pytest.mark.vcr` whenever the provider uses network calls. Avoid mocks and fakes. + +Cover: + +- direct `wrap(...)` behavior +- `setup()` patching new clients +- sync behavior +- async behavior +- streaming behavior +- idempotence +- failure/error logging +- patcher selection if using `IntegrationPatchConfig` + +Preferred locations: + +- provider behavior tests: `py/src/braintrust/integrations//test_.py` +- version helper tests: `py/src/braintrust/integrations/test_versioning.py` +- auto-instrument subprocess tests: `py/src/braintrust/auto_test_scripts/` + +If the provider uses VCR, keep cassettes next to the integration test file under `py/src/braintrust/integrations//cassettes/`. + +Only re-record cassettes when the behavior change is intentional. + +Use mocks or fakes only for cases that are hard to drive through recorded provider traffic, such as narrowly scoped error injection, local version-routing logic, or patcher existence checks. + +## Patterns + +### Constructor patching + +If instrumenting future clients created by the SDK is the goal, patch constructors and attach traced surfaces after the real constructor runs. Anthropic is an example of this pattern. + +### Patcher selection + +Use `IntegrationPatchConfig` only when users benefit from enabling or disabling specific patchers. Validate unknown patcher ids through `BaseIntegration.resolve_patchers()` instead of silently ignoring them. + +### Versioning + +Prefer feature detection first and version checks second. + +Use: + +- `detect_module_version(...)` +- `version_in_range(...)` +- `version_matches_spec(...)` + +Do not add `packaging` just for integration routing. + +## Validation + +- Run the narrowest provider session first. +- Run `cd py && make test-core` if you changed shared integration code. +- Run `cd py && make lint` before handing off broader integration changes. +- If you changed `auto_instrument()`, run the relevant subprocess auto-instrument tests. + +## Done When + +- the provider package contains only the integration, patcher, tracing, export, and test changes required by the task +- provider behavior tests use VCR unless recorded traffic cannot cover the behavior +- cassette changes are present only when provider behavior changed intentionally +- the narrowest affected provider session passes +- `cd py && make test-core` has been run if shared integration code changed +- `cd py && make lint` has been run before handoff + +## Common Pitfalls + +- Leaving provider behavior in `BaseIntegration` instead of the provider package. +- Combining multiple unrelated patch targets into one patcher. +- Forgetting async or streaming coverage. +- Defaulting to mocks or fakes when the provider flow can be covered with VCR. +- Moving tests but not moving their cassettes. +- Adding patcher selection without tests for enabled and disabled cases. +- Editing `auto_instrument()` in a way that implies a registry exists when it does not. diff --git a/AGENTS.md b/AGENTS.md index 2744c5ae..8acef533 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -144,3 +144,4 @@ Avoid editing `py/src/braintrust/version.py` while also running build commands. - Reuse existing fixtures and cassette patterns. - If a change affects examples or integrations, update the nearest example or focused test. - For CLI/devserver changes, consider whether wheel-mode behavior also needs coverage. +- Do **not** add `from __future__ import annotations` unless it is absolutely required (e.g., a genuine forward-reference that cannot be resolved any other way). This import changes annotation evaluation semantics at runtime and can silently break `get_type_hints()`, Pydantic models, and other runtime introspection. Prefer quoted string literals (`"MyClass"`) or `TYPE_CHECKING` guards for forward references instead. diff --git a/py/noxfile.py b/py/noxfile.py index 4c25af1f..bff911db 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -40,6 +40,9 @@ def _pinned_python_version(): SRC_DIR = "braintrust" WRAPPER_DIR = "braintrust/wrappers" +INTEGRATION_DIR = "braintrust/integrations" +INTEGRATION_AUTO_TEST_DIR = "braintrust/integrations/auto_test_scripts" +ANTHROPIC_INTEGRATION_DIR = "braintrust/integrations/anthropic" CONTRIB_DIR = "braintrust/contrib" DEVSERVER_DIR = "braintrust/devserver" @@ -176,6 +179,7 @@ def test_anthropic(session, version): _install_test_deps(session) _install(session, "anthropic", version) _run_tests(session, f"{WRAPPER_DIR}/test_anthropic.py") + _run_tests(session, f"{INTEGRATION_DIR}/anthropic/test_anthropic.py") _run_core_tests(session) @@ -400,7 +404,11 @@ def _get_braintrust_wheel(): def _run_core_tests(session): """Run all tests which don't require optional dependencies.""" - _run_tests(session, SRC_DIR, ignore_paths=[WRAPPER_DIR, CONTRIB_DIR, DEVSERVER_DIR]) + _run_tests( + session, + SRC_DIR, + ignore_paths=[WRAPPER_DIR, INTEGRATION_AUTO_TEST_DIR, ANTHROPIC_INTEGRATION_DIR, CONTRIB_DIR, DEVSERVER_DIR], + ) def _run_tests(session, test_path, ignore_path="", ignore_paths=None, env=None): diff --git a/py/setup.py b/py/setup.py index c8b87d94..ee387304 100644 --- a/py/setup.py +++ b/py/setup.py @@ -19,6 +19,7 @@ "tqdm", "exceptiongroup>=1.2.0", "jsonschema", + "packaging", "python-dotenv", "sseclient-py", "python-slugify", diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 32ef4999..7a0115ae 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -63,6 +63,7 @@ def is_equal(expected, output): from .audit import * from .auto import ( + IntegrationPatchConfig, # noqa: F401 # type: ignore[reportUnusedImport] auto_instrument, # noqa: F401 # type: ignore[reportUnusedImport] ) from .framework import * @@ -70,6 +71,9 @@ def is_equal(expected, output): from .functions.invoke import * from .functions.stream import * from .generated_types import * +from .integrations.anthropic import ( + wrap_anthropic, # noqa: F401 # type: ignore[reportUnusedImport] +) from .logger import * from .logger import ( _internal_get_global_state, # noqa: F401 # type: ignore[reportUnusedImport] @@ -89,9 +93,6 @@ def is_equal(expected, output): BT_IS_ASYNC_ATTRIBUTE, # noqa: F401 # type: ignore[reportUnusedImport] MarkAsyncWrapper, # noqa: F401 # type: ignore[reportUnusedImport] ) -from .wrappers.anthropic import ( - wrap_anthropic, # noqa: F401 # type: ignore[reportUnusedImport] -) from .wrappers.litellm import ( wrap_litellm, # noqa: F401 # type: ignore[reportUnusedImport] ) diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 30dcc2b2..6c15b653 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -9,10 +9,13 @@ import logging from contextlib import contextmanager +from braintrust.integrations import AnthropicIntegration, IntegrationPatchConfig + __all__ = ["auto_instrument"] logger = logging.getLogger(__name__) +InstrumentOption = bool | IntegrationPatchConfig @contextmanager @@ -29,7 +32,7 @@ def _try_patch(): def auto_instrument( *, openai: bool = True, - anthropic: bool = True, + anthropic: InstrumentOption = True, litellm: bool = True, pydantic_ai: bool = True, google_genai: bool = True, @@ -49,7 +52,8 @@ def auto_instrument( Args: openai: Enable OpenAI instrumentation (default: True) - anthropic: Enable Anthropic instrumentation (default: True) + anthropic: Enable Anthropic instrumentation (default: True), or pass an + IntegrationPatchConfig to select Anthropic patchers explicitly. litellm: Enable LiteLLM instrumentation (default: True) pydantic_ai: Enable Pydantic AI instrumentation (default: True) google_genai: Enable Google GenAI instrumentation (default: True) @@ -104,23 +108,33 @@ def auto_instrument( """ results = {} - if openai: + openai_enabled = _normalize_bool_option("openai", openai) + anthropic_enabled, anthropic_config = _normalize_anthropic_option(anthropic) + litellm_enabled = _normalize_bool_option("litellm", litellm) + pydantic_ai_enabled = _normalize_bool_option("pydantic_ai", pydantic_ai) + google_genai_enabled = _normalize_bool_option("google_genai", google_genai) + agno_enabled = _normalize_bool_option("agno", agno) + claude_agent_sdk_enabled = _normalize_bool_option("claude_agent_sdk", claude_agent_sdk) + dspy_enabled = _normalize_bool_option("dspy", dspy) + adk_enabled = _normalize_bool_option("adk", adk) + + if openai_enabled: results["openai"] = _instrument_openai() - if anthropic: - results["anthropic"] = _instrument_anthropic() - if litellm: + if anthropic_enabled: + results["anthropic"] = _instrument_integration(AnthropicIntegration, patch_config=anthropic_config) + if litellm_enabled: results["litellm"] = _instrument_litellm() - if pydantic_ai: + if pydantic_ai_enabled: results["pydantic_ai"] = _instrument_pydantic_ai() - if google_genai: + if google_genai_enabled: results["google_genai"] = _instrument_google_genai() - if agno: + if agno_enabled: results["agno"] = _instrument_agno() - if claude_agent_sdk: + if claude_agent_sdk_enabled: results["claude_agent_sdk"] = _instrument_claude_agent_sdk() - if dspy: + if dspy_enabled: results["dspy"] = _instrument_dspy() - if adk: + if adk_enabled: results["adk"] = _instrument_adk() return results @@ -134,14 +148,34 @@ def _instrument_openai() -> bool: return False -def _instrument_anthropic() -> bool: +def _instrument_integration(integration, *, patch_config: IntegrationPatchConfig | None = None) -> bool: with _try_patch(): - from braintrust.wrappers.anthropic import patch_anthropic - - return patch_anthropic() + return integration.setup( + enabled_patchers=patch_config.enabled_patchers if patch_config is not None else None, + disabled_patchers=patch_config.disabled_patchers if patch_config is not None else None, + ) return False +def _normalize_bool_option(name: str, option: bool) -> bool: + if isinstance(option, bool): + return option + + raise TypeError(f"auto_instrument option {name!r} must be a bool, got {type(option).__name__}") + + +def _normalize_anthropic_option(option: InstrumentOption) -> tuple[bool, IntegrationPatchConfig | None]: + if isinstance(option, bool): + return option, None + + if isinstance(option, IntegrationPatchConfig): + return True, option + + raise TypeError( + f"auto_instrument option 'anthropic' must be a bool or IntegrationPatchConfig, got {type(option).__name__}" + ) + + def _instrument_litellm() -> bool: with _try_patch(): from braintrust.wrappers.litellm import patch_litellm diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py new file mode 100644 index 00000000..1dddbd91 --- /dev/null +++ b/py/src/braintrust/integrations/__init__.py @@ -0,0 +1,5 @@ +from .anthropic import AnthropicIntegration +from .base import IntegrationPatchConfig + + +__all__ = ["AnthropicIntegration", "IntegrationPatchConfig"] diff --git a/py/src/braintrust/integrations/anthropic/__init__.py b/py/src/braintrust/integrations/anthropic/__init__.py new file mode 100644 index 00000000..26b36ba8 --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/__init__.py @@ -0,0 +1,23 @@ +import warnings + +from .integration import AnthropicIntegration +from .tracing import _wrap_anthropic + + +wrap_anthropic = _wrap_anthropic + + +def wrap_anthropic_client(client): + warnings.warn( + "wrap_anthropic_client() is deprecated. Use wrap_anthropic() instead.", + DeprecationWarning, + stacklevel=2, + ) + return _wrap_anthropic(client) + + +__all__ = [ + "AnthropicIntegration", + "wrap_anthropic", + "wrap_anthropic_client", +] diff --git a/py/src/braintrust/wrappers/_anthropic_utils.py b/py/src/braintrust/integrations/anthropic/_utils.py similarity index 84% rename from py/src/braintrust/wrappers/_anthropic_utils.py rename to py/src/braintrust/integrations/anthropic/_utils.py index 12f72a4d..e117451c 100644 --- a/py/src/braintrust/wrappers/_anthropic_utils.py +++ b/py/src/braintrust/integrations/anthropic/_utils.py @@ -34,13 +34,11 @@ def extract_anthropic_usage(usage: Any) -> dict[str, float]: if not usage: return metrics - # Handle both dict and object with attributes def get_value(key: str) -> Any: if isinstance(usage, dict): return usage.get(key) return getattr(usage, key, None) - # Standard token counts input_tokens = get_value("input_tokens") if input_tokens is not None: try: @@ -55,7 +53,6 @@ def get_value(key: str) -> Any: except (ValueError, TypeError): pass - # Anthropic cache tokens cache_read_tokens = get_value("cache_read_input_tokens") if cache_read_tokens is not None: try: @@ -74,17 +71,7 @@ def get_value(key: str) -> Any: def finalize_anthropic_tokens(metrics: dict[str, float]) -> dict[str, float]: - """Finalize Anthropic token calculations. - - Anthropic doesn't include cache tokens in the total, so we need to sum them. - Updates 'prompt_tokens' to include cache tokens and adds 'tokens' field with the total. - - Args: - metrics: Dictionary with token metrics - - Returns: - Updated metrics with total prompt tokens and total tokens fields - """ + """Finalize Anthropic token calculations.""" total_prompt_tokens = ( metrics.get("prompt_tokens", 0) + metrics.get("prompt_cached_tokens", 0) diff --git a/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupAsyncSpans.test_setup_creates_async_spans.yaml b/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupAsyncSpans.test_setup_creates_async_spans.yaml new file mode 100644 index 00000000..3cda147b --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupAsyncSpans.test_setup_creates_async_spans.yaml @@ -0,0 +1,88 @@ +interactions: +- request: + body: '{"max_tokens":100,"messages":[{"role":"user","content":"Say hi async"}],"model":"claude-3-5-haiku-latest"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '106' + Content-Type: + - application/json + Host: + - api.anthropic.com + User-Agent: + - AsyncAnthropic/Python 0.84.0 + X-Stainless-Arch: + - arm64 + X-Stainless-Async: + - async:asyncio + X-Stainless-Lang: + - python + X-Stainless-OS: + - MacOS + X-Stainless-Package-Version: + - 0.84.0 + X-Stainless-Runtime: + - CPython + X-Stainless-Runtime-Version: + - 3.13.3 + anthropic-version: + - '2023-06-01' + x-stainless-read-timeout: + - '600' + x-stainless-retry-count: + - '0' + x-stainless-timeout: + - '600' + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: '{"type":"error","error":{"type":"not_found_error","message":"model: + claude-3-5-haiku-latest"},"request_id":"req_011CZLdMZnqva1tBGWuemWN9"}' + headers: + CF-RAY: + - 9e101d82cf341fdb-SJC + Connection: + - keep-alive + Content-Security-Policy: + - default-src 'none'; frame-ancestors 'none' + Content-Type: + - application/json + Date: + - Mon, 23 Mar 2026 20:16:41 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Robots-Tag: + - none + anthropic-organization-id: + - 27796668-7351-40ac-acc4-024aee8995a5 + cf-cache-status: + - DYNAMIC + content-length: + - '138' + request-id: + - req_011CZLdMZnqva1tBGWuemWN9 + server-timing: + - x-originResponse;dur=25 + set-cookie: + - _cfuvid=O35oR1Sx8_qfQk2ct3bUq6P_LfxnelQS3AuhGTtAo3Q-1774297001.4099903-1.0.1.1-l__Wn1ND363h2RHwrTFmIqpFtWzPoIUNjQXrH1hjWyY; + HttpOnly; SameSite=None; Secure; Path=/; Domain=api.anthropic.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-envoy-upstream-service-time: + - '23' + x-should-retry: + - 'false' + status: + code: 404 + message: Not Found +version: 1 diff --git a/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupSpans.test_setup_creates_spans.yaml b/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupSpans.test_setup_creates_spans.yaml new file mode 100644 index 00000000..8b6d8e58 --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/cassettes/TestAnthropicIntegrationSetupSpans.test_setup_creates_spans.yaml @@ -0,0 +1,88 @@ +interactions: +- request: + body: '{"max_tokens":100,"messages":[{"role":"user","content":"Say hi"}],"model":"claude-3-5-haiku-latest"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '100' + Content-Type: + - application/json + Host: + - api.anthropic.com + User-Agent: + - Anthropic/Python 0.84.0 + X-Stainless-Arch: + - arm64 + X-Stainless-Async: + - 'false' + X-Stainless-Lang: + - python + X-Stainless-OS: + - MacOS + X-Stainless-Package-Version: + - 0.84.0 + X-Stainless-Runtime: + - CPython + X-Stainless-Runtime-Version: + - 3.13.3 + anthropic-version: + - '2023-06-01' + x-stainless-read-timeout: + - '600' + x-stainless-retry-count: + - '0' + x-stainless-timeout: + - '600' + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: '{"type":"error","error":{"type":"not_found_error","message":"model: + claude-3-5-haiku-latest"},"request_id":"req_011CZLdMYnZkfSiZUydAPCJh"}' + headers: + CF-RAY: + - 9e101d81485f9e5c-SJC + Connection: + - keep-alive + Content-Security-Policy: + - default-src 'none'; frame-ancestors 'none' + Content-Type: + - application/json + Date: + - Mon, 23 Mar 2026 20:16:41 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Robots-Tag: + - none + anthropic-organization-id: + - 27796668-7351-40ac-acc4-024aee8995a5 + cf-cache-status: + - DYNAMIC + content-length: + - '138' + request-id: + - req_011CZLdMYnZkfSiZUydAPCJh + server-timing: + - x-originResponse;dur=33 + set-cookie: + - _cfuvid=qjfwtdMQuaH4s2MvMr53WG49MOjmHdRt83XWnDUSnZs-1774297001.1711438-1.0.1.1-6h3vTb4keZZALUiWWbTnUXgtrTuyB4XqC2sFuQc9fQk; + HttpOnly; SameSite=None; Secure; Path=/; Domain=api.anthropic.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-envoy-upstream-service-time: + - '31' + x-should-retry: + - 'false' + status: + code: 404 + message: Not Found +version: 1 diff --git a/py/src/braintrust/wrappers/cassettes/TestPatchAnthropicAsyncSpans.test_patch_anthropic_async_creates_spans.yaml b/py/src/braintrust/integrations/anthropic/cassettes/TestPatchAnthropicAsyncSpans.test_patch_anthropic_async_creates_spans.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/TestPatchAnthropicAsyncSpans.test_patch_anthropic_async_creates_spans.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/TestPatchAnthropicAsyncSpans.test_patch_anthropic_async_creates_spans.yaml diff --git a/py/src/braintrust/wrappers/cassettes/TestPatchAnthropicSpans.test_patch_anthropic_creates_spans.yaml b/py/src/braintrust/integrations/anthropic/cassettes/TestPatchAnthropicSpans.test_patch_anthropic_creates_spans.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/TestPatchAnthropicSpans.test_patch_anthropic_creates_spans.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/TestPatchAnthropicSpans.test_patch_anthropic_creates_spans.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_create_async.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_create_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_create_async.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_create_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_stream_sync.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_stream_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_stream_sync.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_stream_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_streaming_async.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_streaming_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_streaming_async.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_streaming_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_sync.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_beta_messages_sync.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_beta_messages_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_client_error.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_client_error.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_client_error.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_client_error.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_async.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_async.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_async_stream_true.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_async_stream_true.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_async_stream_true.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_async_stream_true.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_stream_true.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_stream_true.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_create_stream_true.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_create_stream_true.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_model_params_inputs.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_model_params_inputs.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_model_params_inputs.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_model_params_inputs.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_stream_errors.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_stream_errors.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_stream_errors.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_stream_errors.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_streaming_async.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_streaming_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_streaming_async.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_streaming_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_streaming_sync.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_streaming_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_streaming_sync.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_streaming_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_sync.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_sync.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_anthropic_messages_system_prompt_inputs.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_system_prompt_inputs.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_anthropic_messages_system_prompt_inputs.yaml rename to py/src/braintrust/integrations/anthropic/cassettes/test_anthropic_messages_system_prompt_inputs.yaml diff --git a/py/src/braintrust/integrations/anthropic/integration.py b/py/src/braintrust/integrations/anthropic/integration.py new file mode 100644 index 00000000..78f9fe0a --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/integration.py @@ -0,0 +1,12 @@ +from braintrust.integrations.base import BaseIntegration + +from .patchers import AnthropicAsyncInitPatcher, AnthropicSyncInitPatcher + + +class AnthropicIntegration(BaseIntegration): + """Braintrust instrumentation for the Anthropic Python SDK on anthropic>=0.48.0.""" + + name = "anthropic" + import_names = ("anthropic",) + min_version = "0.48.0" + patchers = (AnthropicSyncInitPatcher, AnthropicAsyncInitPatcher) diff --git a/py/src/braintrust/integrations/anthropic/patchers.py b/py/src/braintrust/integrations/anthropic/patchers.py new file mode 100644 index 00000000..efe9a3b8 --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/patchers.py @@ -0,0 +1,15 @@ +from braintrust.integrations.base import FunctionWrapperPatcher + +from .tracing import _anthropic_init_wrapper, _async_anthropic_init_wrapper + + +class AnthropicSyncInitPatcher(FunctionWrapperPatcher): + name = "anthropic.init.sync" + target_path = "Anthropic.__init__" + wrapper = _anthropic_init_wrapper + + +class AnthropicAsyncInitPatcher(FunctionWrapperPatcher): + name = "anthropic.init.async" + target_path = "AsyncAnthropic.__init__" + wrapper = _async_anthropic_init_wrapper diff --git a/py/src/braintrust/integrations/anthropic/test_anthropic.py b/py/src/braintrust/integrations/anthropic/test_anthropic.py new file mode 100644 index 00000000..570de9b1 --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/test_anthropic.py @@ -0,0 +1,630 @@ +""" +Tests to ensure we reliably wrap the Anthropic API. +""" + +import inspect +import time +from pathlib import Path + +import anthropic +import pytest +from braintrust import logger +from braintrust.integrations.anthropic import AnthropicIntegration, wrap_anthropic, wrap_anthropic_client +from braintrust.integrations.versioning import make_specifier, version_satisfies +from braintrust.test_helpers import init_test_logger + + +TEST_ORG_ID = "test-org-123" +PROJECT_NAME = "test-anthropic-app" +MODEL = "claude-3-haiku-20240307" # use the cheapest model since answers dont matter + + +@pytest.fixture(scope="module") +def vcr_cassette_dir(): + return str(Path(__file__).resolve().parent / "cassettes") + + +def _get_client(): + return anthropic.Anthropic() + + +def _get_async_client(): + return anthropic.AsyncAnthropic() + + +@pytest.fixture +def memory_logger(): + init_test_logger(PROJECT_NAME) + with logger._internal_with_memory_background_logger() as bgl: + yield bgl + + +@pytest.mark.vcr +def test_anthropic_messages_create_stream_true(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + kws = { + "model": MODEL, + "max_tokens": 300, + "messages": [{"role": "user", "content": "What is 3*4?"}], + "stream": True, + } + + start = time.time() + with client.messages.create(**kws) as out: + msgs = [m for m in out] + end = time.time() + + assert msgs # a very coarse grained check that this works + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == MODEL + assert span["metadata"]["provider"] == "anthropic" + assert span["metadata"]["max_tokens"] == 300 + assert span["metadata"]["stream"] == True + metrics = span["metrics"] + _assert_metrics_are_valid(metrics, start, end) + assert span["input"] == kws["messages"] + assert span["output"] + assert span["output"]["role"] == "assistant" + assert "12" in span["output"]["content"][0]["text"] + + +def test_wrap_anthropic_client_alias_wraps_client(): + client = _get_client() + + with pytest.deprecated_call(match="wrap_anthropic_client\\(\\) is deprecated"): + wrapped = wrap_anthropic_client(client) + + assert type(wrapped.messages).__module__ == "braintrust.integrations.anthropic.tracing" + + +@pytest.mark.vcr +def test_anthropic_messages_model_params_inputs(memory_logger): + assert not memory_logger.pop() + client = wrap_anthropic(_get_client()) + + kw = { + "model": MODEL, + "max_tokens": 300, + "system": "just return the number", + "messages": [{"role": "user", "content": "what is 1+1?"}], + "temperature": 0.5, + "top_p": 0.5, + } + + def _with_messages_create(): + return client.messages.create(**kw) + + def _with_messages_stream(): + with client.messages.stream(**kw) as stream: + for msg in stream: + pass + return stream.get_final_message() + + for f in [_with_messages_create, _with_messages_stream]: + msg = f() + assert msg.content[0].text == "2" + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert log["output"]["role"] == "assistant" + assert "2" in log["output"]["content"][0]["text"] + assert log["metadata"]["model"] == MODEL + assert log["metadata"]["max_tokens"] == 300 + assert log["metadata"]["temperature"] == 0.5 + assert log["metadata"]["top_p"] == 0.5 + + +@pytest.mark.vcr +def test_anthropic_messages_system_prompt_inputs(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + system = "Today's date is 2024-03-26. Only return the date" + q = [{"role": "user", "content": "what is tomorrow's date? only return the date"}] + + args = { + "messages": q, + "temperature": 0, + "max_tokens": 300, + "system": system, + "model": MODEL, + } + + def _with_messages_create(): + return client.messages.create(**args) + + def _with_messages_stream(): + with client.messages.stream(**args) as stream: + for msg in stream: + pass + return stream.get_final_message() + + for f in [_with_messages_create, _with_messages_stream]: + msg = f() + assert "2024-03-27" in msg.content[0].text + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + inputs = log["input"] + assert len(inputs) == 2 + inputs_by_role = {m["role"]: m["content"] for m in inputs} + assert inputs_by_role["system"] == system + assert inputs_by_role["user"] == q[0]["content"] + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_anthropic_messages_create_async(memory_logger): + assert not memory_logger.pop() + + params = { + "model": MODEL, + "max_tokens": 100, + "messages": [{"role": "user", "content": "what is 6+1?, just return the number"}], + } + + client = wrap_anthropic(anthropic.AsyncAnthropic()) + msg = await client.messages.create(**params) + assert "7" in msg.content[0].text + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == MODEL + assert span["metadata"]["max_tokens"] == 100 + assert span["input"] == params["messages"] + assert span["output"]["role"] == "assistant" + assert "7" in span["output"]["content"][0]["text"] + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_anthropic_messages_create_async_stream_true(memory_logger): + assert not memory_logger.pop() + + params = { + "model": MODEL, + "max_tokens": 100, + "messages": [{"role": "user", "content": "what is 6+1?, just return the number"}], + "stream": True, + } + + client = wrap_anthropic(anthropic.AsyncAnthropic()) + stream = await client.messages.create(**params) + async for event in stream: + pass + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == MODEL + assert span["metadata"]["max_tokens"] == 100 + assert span["input"] == params["messages"] + assert span["output"]["role"] == "assistant" + assert "7" in span["output"]["content"][0]["text"] + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_anthropic_messages_streaming_async(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_async_client()) + msgs_in = [{"role": "user", "content": "what is 1+1?, just return the number"}] + + start = time.time() + msg_out = None + + async with client.messages.stream(max_tokens=1024, messages=msgs_in, model=MODEL) as stream: + async for event in stream: + pass + msg_out = await stream.get_final_message() + assert msg_out.content[0].text == "2" + usage = msg_out.usage + end = time.time() + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert "user" in str(log["input"]) + assert "1+1" in str(log["input"]) + assert "2" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_attributes"]["type"] == "llm" + assert log["metadata"]["model"] == MODEL + assert log["metadata"]["max_tokens"] == 1024 + _assert_metrics_are_valid(log["metrics"], start, end) + metrics = log["metrics"] + assert metrics["prompt_tokens"] == usage.input_tokens + assert metrics["completion_tokens"] == usage.output_tokens + assert metrics["tokens"] == usage.input_tokens + usage.output_tokens + assert metrics["prompt_cached_tokens"] == usage.cache_read_input_tokens + assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens + assert log["metadata"]["model"] == MODEL + assert log["metadata"]["max_tokens"] == 1024 + + +@pytest.mark.vcr +def test_anthropic_client_error(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + + fake_model = "there-is-no-such-model" + msg_in = {"role": "user", "content": "who are you?"} + + try: + client.messages.create(model=fake_model, max_tokens=999, messages=[msg_in]) + except Exception: + pass + else: + raise Exception("should have raised an exception") + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert log["project_id"] == PROJECT_NAME + assert "404" in log["error"] + + +@pytest.mark.vcr +def test_anthropic_messages_stream_errors(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + msg_in = {"role": "user", "content": "what is 2+2? (just the number)"} + + try: + with client.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: + raise Exception("fake-error") + except Exception: + pass + else: + raise Exception("should have raised an exception") + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert "Exception: fake-error" in span["error"] + assert span["metrics"]["end"] > 0 + + +@pytest.mark.vcr +def test_anthropic_messages_streaming_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + msg_in = {"role": "user", "content": "what is 2+2? (just the number)"} + + start = time.time() + with client.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: + msgs_out = [m for m in stream] + end = time.time() + msg_out = stream.get_final_message() + usage = msg_out.usage + # crudely check that the stream is valid + assert len(msgs_out) > 3 + assert 1 <= len([m for m in msgs_out if m.type == "text"]) + assert msgs_out[0].type == "message_start" + assert msgs_out[-1].type == "message_stop" + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert "user" in str(log["input"]) + assert "2+2" in str(log["input"]) + assert "4" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_attributes"]["type"] == "llm" + _assert_metrics_are_valid(log["metrics"], start, end) + assert log["metrics"]["prompt_tokens"] == usage.input_tokens + assert log["metrics"]["completion_tokens"] == usage.output_tokens + assert log["metrics"]["tokens"] == usage.input_tokens + usage.output_tokens + assert log["metrics"]["prompt_cached_tokens"] == usage.cache_read_input_tokens + assert log["metrics"]["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens + + +@pytest.mark.vcr +def test_anthropic_messages_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + + msg_in = {"role": "user", "content": "what's 2+2?"} + + start = time.time() + msg = client.messages.create(model=MODEL, max_tokens=300, messages=[msg_in]) + end = time.time() + + text = msg.content[0].text + assert text + + # verify we generated the right spans. + logs = memory_logger.pop() + + assert len(logs) == 1 + log = logs[0] + assert "2+2" in str(log["input"]) + assert "4" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_id"] + assert log["root_span_id"] + attrs = log["span_attributes"] + assert attrs["type"] == "llm" + assert "anthropic" in attrs["name"] + metrics = log["metrics"] + _assert_metrics_are_valid(metrics, start, end) + assert log["metadata"]["model"] == MODEL + + +def _assert_metrics_are_valid(metrics, start, end): + assert metrics["tokens"] > 0 + assert metrics["prompt_tokens"] > 0 + assert metrics["completion_tokens"] > 0 + assert "time_to_first_token" in metrics + assert metrics["time_to_first_token"] >= 0 + if start and end: + assert start <= metrics["start"] <= metrics["end"] <= end + else: + assert metrics["start"] <= metrics["end"] + + +@pytest.mark.vcr +def test_anthropic_beta_messages_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + msg_in = {"role": "user", "content": "what's 3+3?"} + + start = time.time() + msg = client.beta.messages.create(model=MODEL, max_tokens=300, messages=[msg_in]) + end = time.time() + + text = msg.content[0].text + assert text + assert "6" in text + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert "3+3" in str(log["input"]) + assert "6" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_id"] + assert log["root_span_id"] + attrs = log["span_attributes"] + assert attrs["type"] == "llm" + assert "anthropic" in attrs["name"] + metrics = log["metrics"] + _assert_metrics_are_valid(metrics, start, end) + assert log["metadata"]["model"] == MODEL + + +@pytest.mark.vcr +def test_anthropic_beta_messages_stream_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_client()) + msg_in = {"role": "user", "content": "what is 5+5? (just the number)"} + + start = time.time() + with client.beta.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: + msgs_out = [m for m in stream] + end = time.time() + msg_out = stream.get_final_message() + usage = msg_out.usage + + assert len(msgs_out) > 3 + assert msgs_out[0].type == "message_start" + assert msgs_out[-1].type == "message_stop" + assert "10" in msg_out.content[0].text + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert "user" in str(log["input"]) + assert "5+5" in str(log["input"]) + assert "10" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_attributes"]["type"] == "llm" + _assert_metrics_are_valid(log["metrics"], start, end) + assert log["metrics"]["prompt_tokens"] == usage.input_tokens + assert log["metrics"]["completion_tokens"] == usage.output_tokens + assert log["metrics"]["tokens"] == usage.input_tokens + usage.output_tokens + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_anthropic_beta_messages_create_async(memory_logger): + assert not memory_logger.pop() + + params = { + "model": MODEL, + "max_tokens": 100, + "messages": [{"role": "user", "content": "what is 8+2?, just return the number"}], + } + + client = wrap_anthropic(anthropic.AsyncAnthropic()) + msg = await client.beta.messages.create(**params) + assert "10" in msg.content[0].text + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == MODEL + assert span["metadata"]["max_tokens"] == 100 + assert span["input"] == params["messages"] + assert span["output"]["role"] == "assistant" + assert "10" in span["output"]["content"][0]["text"] + + +@pytest.mark.vcr( + match_on=["method", "scheme", "host", "port", "path", "body"] +) # exclude query - varies by SDK version +@pytest.mark.asyncio +async def test_anthropic_beta_messages_streaming_async(memory_logger): + assert not memory_logger.pop() + + client = wrap_anthropic(_get_async_client()) + msgs_in = [{"role": "user", "content": "what is 9+1?, just return the number"}] + + start = time.time() + msg_out = None + + async with client.beta.messages.stream(max_tokens=1024, messages=msgs_in, model=MODEL) as stream: + async for event in stream: + pass + msg_out = await stream.get_final_message() + assert "10" in msg_out.content[0].text + usage = msg_out.usage + end = time.time() + + logs = memory_logger.pop() + assert len(logs) == 1 + log = logs[0] + assert "user" in str(log["input"]) + assert "9+1" in str(log["input"]) + assert "10" in str(log["output"]) + assert log["project_id"] == PROJECT_NAME + assert log["span_attributes"]["type"] == "llm" + assert log["metadata"]["model"] == MODEL + assert log["metadata"]["max_tokens"] == 1024 + _assert_metrics_are_valid(log["metrics"], start, end) + metrics = log["metrics"] + assert metrics["prompt_tokens"] == usage.input_tokens + assert metrics["completion_tokens"] == usage.output_tokens + assert metrics["tokens"] == usage.input_tokens + usage.output_tokens + + +class TestAnthropicIntegrationSetup: + """Tests for `AnthropicIntegration.setup()`.""" + + def test_available_patchers(self): + assert AnthropicIntegration.available_patchers() == ( + "anthropic.init.sync", + "anthropic.init.async", + ) + + def test_resolve_patchers_honors_enable_disable_filters(self): + selected = AnthropicIntegration.resolve_patchers( + enabled_patchers={"anthropic.init.sync", "anthropic.init.async"}, + disabled_patchers={"anthropic.init.async"}, + ) + + assert tuple(patcher.identifier() for patcher in selected) == ("anthropic.init.sync",) + + def test_resolve_patchers_rejects_unknown_patchers(self): + with pytest.raises(ValueError, match="Unknown patchers"): + AnthropicIntegration.resolve_patchers(enabled_patchers={"anthropic.init.unknown"}) + + def test_setup_rejects_unsupported_versions(self): + spec = make_specifier( + min_version=AnthropicIntegration.min_version, max_version=AnthropicIntegration.max_version + ) + assert version_satisfies("0.47.9", spec) is False + + def test_setup_wraps_supported_clients(self): + """`AnthropicIntegration.setup()` should wrap both sync and async client constructors.""" + unpatched_sync = anthropic.Anthropic(api_key="test-key") + unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") + assert type(unpatched_sync.messages).__module__.startswith("anthropic.") + assert type(unpatched_async.messages).__module__.startswith("anthropic.") + + AnthropicIntegration.setup() + patched_sync = anthropic.Anthropic(api_key="test-key") + patched_async = anthropic.AsyncAnthropic(api_key="test-key") + assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" + assert type(patched_async.messages).__module__ == "braintrust.integrations.anthropic.tracing" + + def test_setup_is_idempotent(self): + """Multiple `AnthropicIntegration.setup()` calls should be safe.""" + AnthropicIntegration.setup() + first_sync_init = inspect.getattr_static(anthropic.Anthropic, "__init__") + first_async_init = inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") + + AnthropicIntegration.setup() + assert first_sync_init is inspect.getattr_static(anthropic.Anthropic, "__init__") + assert first_async_init is inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") + + def test_setup_creates_spans(self): + """`AnthropicIntegration.setup()` should create spans when making API calls.""" + init_test_logger("test-auto") + with logger._internal_with_memory_background_logger() as memory_logger: + AnthropicIntegration.setup() + + client = anthropic.Anthropic() + + import braintrust + + with braintrust.start_span(name="test"): + try: + client.messages.create( + model="claude-3-5-haiku-latest", + max_tokens=100, + messages=[{"role": "user", "content": "hi"}], + ) + except Exception: + pass + + spans = memory_logger.pop() + assert len(spans) >= 1, f"Expected spans, got {spans}" + + +class TestPatchAnthropicSpans: + """VCR-based tests verifying that `AnthropicIntegration.setup()` produces spans.""" + + @pytest.mark.vcr + def test_patch_anthropic_creates_spans(self, memory_logger): + """`AnthropicIntegration.setup()` should create spans when making API calls.""" + assert not memory_logger.pop() + + AnthropicIntegration.setup() + client = anthropic.Anthropic() + response = client.messages.create( + model="claude-3-5-haiku-latest", + max_tokens=100, + messages=[{"role": "user", "content": "Say hi"}], + ) + assert response.content[0].text + + # Verify span was created + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "anthropic" + assert "claude" in span["metadata"]["model"] + assert span["input"] + + +class TestPatchAnthropicAsyncSpans: + """VCR-based tests verifying that `AnthropicIntegration.setup()` produces spans for async clients.""" + + @pytest.mark.vcr + @pytest.mark.asyncio + async def test_patch_anthropic_async_creates_spans(self, memory_logger): + """`AnthropicIntegration.setup()` should create spans for async API calls.""" + assert not memory_logger.pop() + + AnthropicIntegration.setup() + client = anthropic.AsyncAnthropic() + response = await client.messages.create( + model="claude-3-5-haiku-latest", + max_tokens=100, + messages=[{"role": "user", "content": "Say hi async"}], + ) + assert response.content[0].text + + # Verify span was created + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "anthropic" + assert "claude" in span["metadata"]["model"] + assert span["input"] diff --git a/py/src/braintrust/integrations/anthropic/tracing.py b/py/src/braintrust/integrations/anthropic/tracing.py new file mode 100644 index 00000000..de352018 --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/tracing.py @@ -0,0 +1,367 @@ +import logging +import time +import warnings +from contextlib import contextmanager + +from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens +from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span + + +log = logging.getLogger(__name__) + + +# This tracer depends on an internal anthropic method used to merge +# streamed messages together. It's a bit tricky so I'm opting to use it +# here. If it goes away, this polyfill will make it a no-op and the only +# result will be missing `output` and metrics in our spans. Our tests always +# run against the latest version of anthropic's SDK, so we'll know. +# anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py#L392 +try: + from anthropic.lib.streaming._messages import accumulate_event +except ImportError: + + def accumulate_event(event=None, current_snapshot=None, **kwargs): + warnings.warn("braintrust: missing method: anthropic.lib.streaming._messages.accumulate_event") + return current_snapshot + + +# Anthropic model parameters that we want to track as span metadata. +METADATA_PARAMS = ( + "model", + "max_tokens", + "temperature", + "top_k", + "top_p", + "stop_sequences", + "tool_choice", + "tools", + "stream", + "thinking", +) + + +class TracedAsyncAnthropic(Wrapper): + def __init__(self, client): + super().__init__(client) + self.__client = client + + @property + def messages(self): + return AsyncMessages(self.__client.messages) + + @property + def beta(self): + return AsyncBeta(self.__client.beta) + + +class AsyncMessages(Wrapper): + def __init__(self, messages): + super().__init__(messages) + self.__messages = messages + + async def create(self, *args, **kwargs): + if kwargs.get("stream", False): + return await self.__create_with_stream_true(*args, **kwargs) + else: + return await self.__create_with_stream_false(*args, **kwargs) + + async def __create_with_stream_false(self, *args, **kwargs): + span = _start_span("anthropic.messages.create", kwargs) + request_start_time = time.time() + try: + result = await self.__messages.create(*args, **kwargs) + ttft = time.time() - request_start_time + with _catch_exceptions(): + _log_message_to_span(result, span, time_to_first_token=ttft) + return result + except Exception as e: + with _catch_exceptions(): + span.log(error=e) + raise + finally: + span.end() + + async def __create_with_stream_true(self, *args, **kwargs): + span = _start_span("anthropic.messages.stream", kwargs) + request_start_time = time.time() + try: + stream = await self.__messages.create(*args, **kwargs) + except Exception as e: + with _catch_exceptions(): + span.log(error=e) + span.end() + raise + + traced_stream = TracedMessageStream(stream, span, request_start_time) + + async def async_stream(): + try: + async for msg in traced_stream: + yield msg + except Exception as e: + with _catch_exceptions(): + span.log(error=e) + raise + finally: + with _catch_exceptions(): + msg = traced_stream._get_final_traced_message() + if msg: + ttft = traced_stream._get_time_to_first_token() + _log_message_to_span(msg, span, time_to_first_token=ttft) + span.end() + + return async_stream() + + def stream(self, *args, **kwargs): + span = _start_span("anthropic.messages.stream", kwargs) + request_start_time = time.time() + stream = self.__messages.stream(*args, **kwargs) + return TracedMessageStreamManager(stream, span, request_start_time) + + +class AsyncBeta(Wrapper): + def __init__(self, beta): + super().__init__(beta) + self.__beta = beta + + @property + def messages(self): + return AsyncMessages(self.__beta.messages) + + +class TracedAnthropic(Wrapper): + def __init__(self, client): + super().__init__(client) + self.__client = client + + @property + def messages(self): + return Messages(self.__client.messages) + + @property + def beta(self): + return Beta(self.__client.beta) + + +class Messages(Wrapper): + def __init__(self, messages): + super().__init__(messages) + self.__messages = messages + + def stream(self, *args, **kwargs): + return self.__trace_stream(self.__messages.stream, *args, **kwargs) + + def create(self, *args, **kwargs): + if kwargs.get("stream"): + return self.__trace_stream(self.__messages.create, *args, **kwargs) + + span = _start_span("anthropic.messages.create", kwargs) + request_start_time = time.time() + try: + msg = self.__messages.create(*args, **kwargs) + ttft = time.time() - request_start_time + _log_message_to_span(msg, span, time_to_first_token=ttft) + return msg + except Exception as e: + span.log(error=e) + raise + finally: + span.end() + + def __trace_stream(self, stream_func, *args, **kwargs): + span = _start_span("anthropic.messages.stream", kwargs) + request_start_time = time.time() + s = stream_func(*args, **kwargs) + return TracedMessageStreamManager(s, span, request_start_time) + + +class Beta(Wrapper): + def __init__(self, beta): + super().__init__(beta) + self.__beta = beta + + @property + def messages(self): + return Messages(self.__beta.messages) + + +class TracedMessageStreamManager(Wrapper): + def __init__(self, msg_stream_mgr, span, request_start_time: float): + super().__init__(msg_stream_mgr) + self.__msg_stream_mgr = msg_stream_mgr + self.__traced_message_stream = None + self.__span = span + self.__request_start_time = request_start_time + + async def __aenter__(self): + ms = await self.__msg_stream_mgr.__aenter__() + self.__traced_message_stream = TracedMessageStream(ms, self.__span, self.__request_start_time) + return self.__traced_message_stream + + def __enter__(self): + ms = self.__msg_stream_mgr.__enter__() + self.__traced_message_stream = TracedMessageStream(ms, self.__span, self.__request_start_time) + return self.__traced_message_stream + + def __aexit__(self, exc_type, exc_value, traceback): + try: + return self.__msg_stream_mgr.__aexit__(exc_type, exc_value, traceback) + finally: + with _catch_exceptions(): + self.__close(exc_type, exc_value, traceback) + + def __exit__(self, exc_type, exc_value, traceback): + try: + return self.__msg_stream_mgr.__exit__(exc_type, exc_value, traceback) + finally: + with _catch_exceptions(): + self.__close(exc_type, exc_value, traceback) + + def __close(self, exc_type, exc_value, traceback): + with _catch_exceptions(): + tms = self.__traced_message_stream + msg = tms._get_final_traced_message() + if msg: + ttft = tms._get_time_to_first_token() + _log_message_to_span(msg, self.__span, time_to_first_token=ttft) + if exc_type: + log_exc_info_to_span(self.__span, exc_type, exc_value, traceback) + self.__span.end() + + +class TracedMessageStream(Wrapper): + """TracedMessageStream wraps both sync and async message streams.""" + + def __init__(self, msg_stream, span, request_start_time: float): + super().__init__(msg_stream) + self.__msg_stream = msg_stream + self.__span = span + self.__metrics = {} + self.__snapshot = None + self.__request_start_time = request_start_time + self.__time_to_first_token: float | None = None + + def _get_final_traced_message(self): + return self.__snapshot + + def _get_time_to_first_token(self): + return self.__time_to_first_token + + def __await__(self): + return self.__msg_stream.__await__() + + def __aiter__(self): + return self + + def __iter__(self): + return self + + async def __anext__(self): + m = await self.__msg_stream.__anext__() + with _catch_exceptions(): + self.__process_message(m) + return m + + def __next__(self): + m = next(self.__msg_stream) + with _catch_exceptions(): + self.__process_message(m) + return m + + def __process_message(self, m): + if self.__time_to_first_token is None: + self.__time_to_first_token = time.time() - self.__request_start_time + + with _catch_exceptions(): + self.__snapshot = accumulate_event(event=m, current_snapshot=self.__snapshot) + + +def _get_input_from_kwargs(kwargs): + msgs = list(kwargs.get("messages", [])) + kwargs["messages"] = msgs.copy() + + system = kwargs.get("system", None) + if system: + msgs.append({"role": "system", "content": system}) + return msgs + + +def _get_metadata_from_kwargs(kwargs): + metadata = {"provider": "anthropic"} + for k in METADATA_PARAMS: + v = kwargs.get(k, None) + if v is not None: + metadata[k] = v + return metadata + + +def _start_span(name, kwargs): + with _catch_exceptions(): + _input = _get_input_from_kwargs(kwargs) + metadata = _get_metadata_from_kwargs(kwargs) + return start_span(name=name, type="llm", metadata=metadata, input=_input) + + return NOOP_SPAN + + +def _log_message_to_span(message, span, time_to_first_token: float | None = None): + with _catch_exceptions(): + usage = getattr(message, "usage", {}) + metrics = finalize_anthropic_tokens(extract_anthropic_usage(usage)) + + if time_to_first_token is not None: + metrics["time_to_first_token"] = time_to_first_token + + output = { + k: v + for k, v in {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}.items() + if v + } or None + + span.log(output=output, metrics=metrics) + + +@contextmanager +def _catch_exceptions(): + try: + yield + except Exception as e: + log.warning("swallowing exception in tracing code", exc_info=e) + + +def _wrap_anthropic(client): + """Wrap an Anthropic object (or AsyncAnthropic) to add tracing.""" + type_name = getattr(type(client), "__name__") + if "AsyncAnthropic" in type_name: + return TracedAsyncAnthropic(client) + elif "Anthropic" in type_name: + return TracedAnthropic(client) + else: + return client + + +wrap_anthropic = _wrap_anthropic + + +def _apply_anthropic_wrapper(client): + wrapped = _wrap_anthropic(client) + client.messages = wrapped.messages + if hasattr(wrapped, "beta"): + client.beta = wrapped.beta + + +def _apply_async_anthropic_wrapper(client): + wrapped = _wrap_anthropic(client) + client.messages = wrapped.messages + if hasattr(wrapped, "beta"): + client.beta = wrapped.beta + + +def _anthropic_init_wrapper(wrapped, instance, args, kwargs): + wrapped(*args, **kwargs) + _apply_anthropic_wrapper(instance) + + +def _async_anthropic_init_wrapper(wrapped, instance, args, kwargs): + wrapped(*args, **kwargs) + _apply_async_anthropic_wrapper(instance) diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_agno.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_agno.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_agno.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_agno.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_anthropic.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic.py similarity index 60% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_anthropic.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic.py index 2c51d911..9833df9f 100644 --- a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_anthropic.py +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic.py @@ -6,12 +6,19 @@ # 1. Verify not patched initially -assert not getattr(anthropic, "__braintrust_wrapped__", False) +original_sync_module = type(anthropic.Anthropic(api_key="test-key").messages).__module__ +original_async_module = type(anthropic.AsyncAnthropic(api_key="test-key").messages).__module__ # 2. Instrument results = auto_instrument() assert results.get("anthropic") == True -assert getattr(anthropic, "__braintrust_wrapped__", False) + +patched_sync = anthropic.Anthropic(api_key="test-key") +patched_async = anthropic.AsyncAnthropic(api_key="test-key") +assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" +assert type(patched_async.messages).__module__ == "braintrust.integrations.anthropic.tracing" +assert type(patched_sync.messages).__module__ != original_sync_module +assert type(patched_async.messages).__module__ != original_async_module # 3. Idempotent results2 = auto_instrument() diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py new file mode 100644 index 00000000..e15f9361 --- /dev/null +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py @@ -0,0 +1,20 @@ +"""Test auto_instrument patch selection for Anthropic.""" + +import anthropic +from braintrust.auto import auto_instrument +from braintrust.integrations import IntegrationPatchConfig + + +results = auto_instrument( + anthropic=IntegrationPatchConfig( + enabled_patchers={"anthropic.init.sync"}, + ) +) +assert results.get("anthropic") == True + +patched_sync = anthropic.Anthropic(api_key="test-key") +unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") +assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" +assert type(unpatched_async.messages).__module__.startswith("anthropic.") + +print("SUCCESS") diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_claude_agent_sdk.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_claude_agent_sdk.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_claude_agent_sdk.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_claude_agent_sdk.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_dspy.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_dspy.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_google_genai.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_google_genai.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_google_genai.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_google_genai.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_litellm.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_litellm.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_openai.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_openai.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_openai.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_openai.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_auto_pydantic_ai.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_pydantic_ai.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_auto_pydantic_ai.py rename to py/src/braintrust/integrations/auto_test_scripts/test_auto_pydantic_ai.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_patch_litellm_aresponses.py b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_patch_litellm_aresponses.py rename to py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py diff --git a/py/src/braintrust/wrappers/auto_test_scripts/test_patch_litellm_responses.py b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py similarity index 100% rename from py/src/braintrust/wrappers/auto_test_scripts/test_patch_litellm_responses.py rename to py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py diff --git a/py/src/braintrust/integrations/base.py b/py/src/braintrust/integrations/base.py new file mode 100644 index 00000000..debe282b --- /dev/null +++ b/py/src/braintrust/integrations/base.py @@ -0,0 +1,211 @@ +"""Shared integration orchestration primitives.""" + +import importlib +import inspect +import re +from abc import ABC, abstractmethod +from collections.abc import Collection, Iterable +from dataclasses import dataclass +from typing import Any, ClassVar + +from wrapt import wrap_function_wrapper + +from .versioning import detect_module_version, make_specifier, version_satisfies + + +@dataclass(frozen=True) +class IntegrationPatchConfig: + """Per-integration patch selection for instrumentation setup.""" + + enabled_patchers: Collection[str] | None = None + disabled_patchers: Collection[str] | None = None + + +class BasePatcher(ABC): + """Base class for one concrete integration patch strategy.""" + + name: ClassVar[str] + patch_id: ClassVar[str | None] = None + version_spec: ClassVar[str | None] = None + priority: ClassVar[int] = 100 + + @classmethod + def identifier(cls) -> str: + """Return the public identifier for selecting this patcher.""" + return cls.patch_id or cls.name + + @classmethod + def applies(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return whether this patcher should run for the given module/version.""" + return version_satisfies(version, cls.version_spec) + + @classmethod + @abstractmethod + def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return whether this patcher's target has already been instrumented.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Apply instrumentation for this patcher.""" + raise NotImplementedError + + +class FunctionWrapperPatcher(BasePatcher): + """Base patcher for single-target `wrap_function_wrapper` instrumentation.""" + + target_path: ClassVar[str] + wrapper: ClassVar[Any] + + @classmethod + def resolve_root(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None: + """Return the root object from which this patcher resolves its target.""" + return target or module + + @classmethod + def resolve_target(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None: + """Return the concrete callable or descriptor that this patcher instruments.""" + root = cls.resolve_root(module, version, target=target) + if root is None: + return None + return _resolve_attr_path(root, cls.target_path) + + @classmethod + def applies(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return whether the target exists and this patcher's version gate passes.""" + return ( + super().applies(module, version, target=target) + and cls.resolve_target(module, version, target=target) is not None + ) + + @classmethod + def patch_marker_attr(cls) -> str: + """Return the sentinel attribute used to mark this target as patched.""" + suffix = re.sub(r"\W+", "_", cls.name).strip("_") + return f"__braintrust_patched_{suffix}__" + + @classmethod + def mark_patched(cls, obj: Any) -> None: + """Mark a wrapped target so future patch attempts are idempotent.""" + setattr(obj, cls.patch_marker_attr(), True) + + @classmethod + def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return whether this patcher's target has already been instrumented.""" + resolved_target = cls.resolve_target(module, version, target=target) + return bool(resolved_target is not None and getattr(resolved_target, cls.patch_marker_attr(), False)) + + @classmethod + def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Apply instrumentation for this patcher.""" + root = cls.resolve_root(module, version, target=target) + if root is None or not cls.applies(module, version, target=target): + return False + + wrap_function_wrapper(root, cls.target_path, cls.wrapper) + resolved_target = cls.resolve_target(module, version, target=target) + if resolved_target is None: + return False + + cls.mark_patched(resolved_target) + return True + + +class BaseIntegration(ABC): + """Base class for an instrumentable third-party integration.""" + + name: ClassVar[str] + import_names: ClassVar[tuple[str, ...]] + patchers: ClassVar[tuple[type[BasePatcher], ...]] = () + min_version: ClassVar[str | None] = None + max_version: ClassVar[str | None] = None + + @classmethod + def available_patchers(cls) -> tuple[str, ...]: + """Return patcher identifiers in declaration order.""" + return tuple(patcher.identifier() for patcher in cls.patchers) + + @classmethod + def resolve_patchers( + cls, + *, + enabled_patchers: Collection[str] | None = None, + disabled_patchers: Collection[str] | None = None, + ) -> tuple[type[BasePatcher], ...]: + """Return the selected patchers after validating explicit selectors.""" + patchers_by_id: dict[str, type[BasePatcher]] = {} + for patcher in cls.patchers: + patcher_id = patcher.identifier() + existing = patchers_by_id.get(patcher_id) + if existing is not None and existing is not patcher: + raise ValueError(f"Duplicate patcher identifier {patcher_id!r} for integration {cls.name!r}") + patchers_by_id[patcher_id] = patcher + + enabled = set(enabled_patchers) if enabled_patchers is not None else None + disabled = set(disabled_patchers or ()) + requested = disabled if enabled is None else enabled | disabled + unknown = requested - set(patchers_by_id) + if unknown: + available = ", ".join(sorted(patchers_by_id)) + unknown_display = ", ".join(sorted(unknown)) + raise ValueError( + f"Unknown patchers for integration {cls.name!r}: {unknown_display}. Available patchers: {available}" + ) + + return tuple( + patcher + for patcher in cls.patchers + if (enabled is None or patcher.identifier() in enabled) and patcher.identifier() not in disabled + ) + + @classmethod + def setup( + cls, + *, + target: Any | None = None, + enabled_patchers: Collection[str] | None = None, + disabled_patchers: Collection[str] | None = None, + ) -> bool: + """Apply all applicable patchers for this integration.""" + module = _import_first_available(cls.import_names) + if module is None: + return False + version = detect_module_version(module, cls.import_names) + if not version_satisfies(version, make_specifier(min_version=cls.min_version, max_version=cls.max_version)): + return False + + success = False + selected_patchers = cls.resolve_patchers( + enabled_patchers=enabled_patchers, + disabled_patchers=disabled_patchers, + ) + for patcher in sorted(selected_patchers, key=lambda patcher: patcher.priority): + if not patcher.applies(module, version, target=target): + continue + if patcher.is_patched(module, version, target=target): + success = True + continue + success = patcher.patch(module, version, target=target) or success + + return success + + +def _import_first_available(import_names: Iterable[str]) -> Any | None: + """Import and return the first available module from the given names.""" + for import_name in import_names: + try: + return importlib.import_module(import_name) + except ImportError: + continue + return None + + +def _resolve_attr_path(root: Any, path: str) -> Any | None: + current = root + for part in path.split("."): + try: + current = inspect.getattr_static(current, part) + except AttributeError: + return None + return current diff --git a/py/src/braintrust/integrations/test_versioning.py b/py/src/braintrust/integrations/test_versioning.py new file mode 100644 index 00000000..226ba86f --- /dev/null +++ b/py/src/braintrust/integrations/test_versioning.py @@ -0,0 +1,34 @@ +from braintrust.integrations.versioning import version_satisfies + + +def test_version_satisfies_handles_prereleases(): + # 1.0rc1 is a pre-release of 1.0 — it sorts before the final 1.0 release. + assert not version_satisfies("1.0rc1", "<1.0") + assert not version_satisfies("1.0rc1", ">=1.0") + + # Pre-releases of a *different* version work as plain comparisons. + assert version_satisfies("1.0rc1", "<1.1") + assert not version_satisfies("1.0rc1", ">=1.1") + + # Explicit pre-release bounds. + assert version_satisfies("1.0rc1", ">=1.0rc1") + assert not version_satisfies("1.0rc1", ">1.0rc1") + assert version_satisfies("1.0rc1", "<1.0rc2") + + +def test_version_satisfies_ignores_trailing_zeroes(): + assert version_satisfies("1.0.0", "==1.0") + assert version_satisfies("1.2.0", ">=1.2") + + +def test_version_satisfies_none_handling(): + # No spec means anything is compatible. + assert version_satisfies("1.0", None) + assert version_satisfies(None, None) + + # No version with a spec means incompatible. + assert not version_satisfies(None, ">=1.0") + + +def test_version_satisfies_invalid_version(): + assert not version_satisfies("not-a-version", ">=1.0") diff --git a/py/src/braintrust/integrations/versioning.py b/py/src/braintrust/integrations/versioning.py new file mode 100644 index 00000000..384a4aa8 --- /dev/null +++ b/py/src/braintrust/integrations/versioning.py @@ -0,0 +1,55 @@ +import importlib.metadata +from typing import Any + +from packaging.specifiers import SpecifierSet +from packaging.version import InvalidVersion, Version + + +def detect_module_version(module: Any, import_names: tuple[str, ...]) -> str | None: + candidates: list[str] = [] + + module_name = getattr(module, "__name__", None) + if isinstance(module_name, str) and module_name: + candidates.append(module_name.split(".")[0]) + + module_package = getattr(module, "__package__", None) + if isinstance(module_package, str) and module_package: + candidates.append(module_package.split(".")[0]) + + candidates.extend(name.split(".")[0] for name in import_names) + + seen: set[str] = set() + for candidate in candidates: + if candidate in seen: + continue + seen.add(candidate) + try: + return importlib.metadata.version(candidate) + except importlib.metadata.PackageNotFoundError: + continue + + version = getattr(module, "__version__", None) + return version if isinstance(version, str) else None + + +def make_specifier(*, min_version: str | None = None, max_version: str | None = None) -> SpecifierSet: + """Build a :class:`SpecifierSet` from optional min/max bounds.""" + spec = SpecifierSet(prereleases=True) + if min_version is not None: + spec &= SpecifierSet(f">={min_version}", prereleases=True) + if max_version is not None: + spec &= SpecifierSet(f"<={max_version}", prereleases=True) + return spec + + +def version_satisfies(version: str | None, spec: str | SpecifierSet | None) -> bool: + """Return True if *version* satisfies the PEP 440 *spec*.""" + if spec is None: + return True + if version is None: + return False + try: + ss = spec if isinstance(spec, SpecifierSet) else SpecifierSet(spec, prereleases=True) + return Version(version) in ss + except InvalidVersion: + return False diff --git a/py/src/braintrust/wrappers/anthropic.py b/py/src/braintrust/wrappers/anthropic.py index e14e971a..b89422dc 100644 --- a/py/src/braintrust/wrappers/anthropic.py +++ b/py/src/braintrust/wrappers/anthropic.py @@ -1,426 +1,7 @@ -import logging -import time -import warnings -from contextlib import contextmanager +from braintrust.integrations.anthropic import wrap_anthropic, wrap_anthropic_client -from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span -from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens -from wrapt import wrap_function_wrapper - -log = logging.getLogger(__name__) - - -# This tracer depends on an internal anthropic method used to merge -# streamed messages together. It's a bit tricky so I'm opting to use it -# here. If it goes away, this polyfill will make it a no-op and the only -# result will be missing `output` and metrics in our spans. Our tests always -# run against the latest version of anthropic's SDK, so we'll know. -# anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py#L392 -try: - from anthropic.lib.streaming._messages import accumulate_event -except ImportError: - - def accumulate_event(event=None, current_snapshot=None, **kwargs): - warnings.warn("braintrust: missing method: anthropic.lib.streaming._messages.accumulate_event") - return current_snapshot - - -# Anthropic model parameters that we want to track as span metadata. -METADATA_PARAMS = ( - "model", - "max_tokens", - "temperature", - "top_k", - "top_p", - "stop_sequences", - "tool_choice", - "tools", - "stream", - "thinking", -) - - -class TracedAsyncAnthropic(Wrapper): - def __init__(self, client): - super().__init__(client) - self.__client = client - - @property - def messages(self): - return AsyncMessages(self.__client.messages) - - @property - def beta(self): - return AsyncBeta(self.__client.beta) - - -class AsyncMessages(Wrapper): - def __init__(self, messages): - super().__init__(messages) - self.__messages = messages - - async def create(self, *args, **kwargs): - if kwargs.get("stream", False): - return await self.__create_with_stream_true(*args, **kwargs) - else: - return await self.__create_with_stream_false(*args, **kwargs) - - async def __create_with_stream_false(self, *args, **kwargs): - span = _start_span("anthropic.messages.create", kwargs) - request_start_time = time.time() - try: - result = await self.__messages.create(*args, **kwargs) - ttft = time.time() - request_start_time - with _catch_exceptions(): - _log_message_to_span(result, span, time_to_first_token=ttft) - return result - except Exception as e: - with _catch_exceptions(): - span.log(error=e) - raise - finally: - span.end() - - async def __create_with_stream_true(self, *args, **kwargs): - span = _start_span("anthropic.messages.stream", kwargs) - request_start_time = time.time() - try: - stream = await self.__messages.create(*args, **kwargs) - except Exception as e: - with _catch_exceptions(): - span.log(error=e) - span.end() - raise - - traced_stream = TracedMessageStream(stream, span, request_start_time) - - async def async_stream(): - try: - async for msg in traced_stream: - yield msg - except Exception as e: - with _catch_exceptions(): - span.log(error=e) - raise - finally: - with _catch_exceptions(): - msg = traced_stream._get_final_traced_message() - if msg: - ttft = traced_stream._get_time_to_first_token() - _log_message_to_span(msg, span, time_to_first_token=ttft) - span.end() - - return async_stream() - - def stream(self, *args, **kwargs): - span = _start_span("anthropic.messages.stream", kwargs) - request_start_time = time.time() - stream = self.__messages.stream(*args, **kwargs) - return TracedMessageStreamManager(stream, span, request_start_time) - - -class AsyncBeta(Wrapper): - def __init__(self, beta): - super().__init__(beta) - self.__beta = beta - - @property - def messages(self): - return AsyncMessages(self.__beta.messages) - - -class TracedAnthropic(Wrapper): - def __init__(self, client): - super().__init__(client) - self.__client = client - - @property - def messages(self): - return Messages(self.__client.messages) - - @property - def beta(self): - return Beta(self.__client.beta) - - -class Messages(Wrapper): - def __init__(self, messages): - super().__init__(messages) - self.__messages = messages - - def stream(self, *args, **kwargs): - return self.__trace_stream(self.__messages.stream, *args, **kwargs) - - def create(self, *args, **kwargs): - # If stream is True, we need to trace the stream function - if kwargs.get("stream"): - return self.__trace_stream(self.__messages.create, *args, **kwargs) - - span = _start_span("anthropic.messages.create", kwargs) - request_start_time = time.time() - try: - msg = self.__messages.create(*args, **kwargs) - ttft = time.time() - request_start_time - _log_message_to_span(msg, span, time_to_first_token=ttft) - return msg - except Exception as e: - span.log(error=e) - raise - finally: - span.end() - - def __trace_stream(self, stream_func, *args, **kwargs): - span = _start_span("anthropic.messages.stream", kwargs) - request_start_time = time.time() - s = stream_func(*args, **kwargs) - return TracedMessageStreamManager(s, span, request_start_time) - - -class Beta(Wrapper): - def __init__(self, beta): - super().__init__(beta) - self.__beta = beta - - @property - def messages(self): - return Messages(self.__beta.messages) - - -class TracedMessageStreamManager(Wrapper): - def __init__(self, msg_stream_mgr, span, request_start_time: float): - super().__init__(msg_stream_mgr) - self.__msg_stream_mgr = msg_stream_mgr - self.__traced_message_stream = None - self.__span = span - self.__request_start_time = request_start_time - - async def __aenter__(self): - ms = await self.__msg_stream_mgr.__aenter__() - self.__traced_message_stream = TracedMessageStream(ms, self.__span, self.__request_start_time) - return self.__traced_message_stream - - def __enter__(self): - ms = self.__msg_stream_mgr.__enter__() - self.__traced_message_stream = TracedMessageStream(ms, self.__span, self.__request_start_time) - return self.__traced_message_stream - - def __aexit__(self, exc_type, exc_value, traceback): - try: - return self.__msg_stream_mgr.__aexit__(exc_type, exc_value, traceback) - finally: - with _catch_exceptions(): - self.__close(exc_type, exc_value, traceback) - - def __exit__(self, exc_type, exc_value, traceback): - try: - return self.__msg_stream_mgr.__exit__(exc_type, exc_value, traceback) - finally: - with _catch_exceptions(): - self.__close(exc_type, exc_value, traceback) - - def __close(self, exc_type, exc_value, traceback): - with _catch_exceptions(): - tms = self.__traced_message_stream - msg = tms._get_final_traced_message() - if msg: - ttft = tms._get_time_to_first_token() - _log_message_to_span(msg, self.__span, time_to_first_token=ttft) - if exc_type: - log_exc_info_to_span(self.__span, exc_type, exc_value, traceback) - self.__span.end() - - -class TracedMessageStream(Wrapper): - """TracedMessageStream wraps both sync and async message streams. Obviously only one - makes sense at a time - """ - - def __init__(self, msg_stream, span, request_start_time: float): - super().__init__(msg_stream) - self.__msg_stream = msg_stream - self.__span = span - self.__metrics = {} - self.__snapshot = None - self.__request_start_time = request_start_time - self.__time_to_first_token: float | None = None - - def _get_final_traced_message(self): - return self.__snapshot - - def _get_time_to_first_token(self): - return self.__time_to_first_token - - def __await__(self): - return self.__msg_stream.__await__() - - def __aiter__(self): - return self - - def __iter__(self): - return self - - async def __anext__(self): - m = await self.__msg_stream.__anext__() - with _catch_exceptions(): - self.__process_message(m) - return m - - def __next__(self): - m = next(self.__msg_stream) - with _catch_exceptions(): - self.__process_message(m) - return m - - def __process_message(self, m): - # Track time to first token on the first message - if self.__time_to_first_token is None: - self.__time_to_first_token = time.time() - self.__request_start_time - - with _catch_exceptions(): - self.__snapshot = accumulate_event(event=m, current_snapshot=self.__snapshot) - - -def _get_input_from_kwargs(kwargs): - msgs = list(kwargs.get("messages", [])) - # save a copy of the messages because it might be a generator - # and we may mutate it. - kwargs["messages"] = msgs.copy() - - system = kwargs.get("system", None) - if system: - msgs.append({"role": "system", "content": system}) - return msgs - - -def _get_metadata_from_kwargs(kwargs): - metadata = {"provider": "anthropic"} - for k in METADATA_PARAMS: - v = kwargs.get(k, None) - if v is not None: - metadata[k] = v - return metadata - - -def _start_span(name, kwargs): - """Start a span with the given name, tagged with all of the relevant data from kwargs. kwargs is the dictionary of options - passed into anthropic.messages.create or anthropic.messages.stream. - """ - with _catch_exceptions(): - _input = _get_input_from_kwargs(kwargs) - metadata = _get_metadata_from_kwargs(kwargs) - return start_span(name=name, type="llm", metadata=metadata, input=_input) - - # if this failed, maintain the API. - return NOOP_SPAN - - -def _log_message_to_span(message, span, time_to_first_token: float | None = None): - """Log telemetry from the given anthropic.Message to the given span.""" - with _catch_exceptions(): - usage = getattr(message, "usage", {}) - metrics = finalize_anthropic_tokens(extract_anthropic_usage(usage)) - - # Add time_to_first_token if provided - if time_to_first_token is not None: - metrics["time_to_first_token"] = time_to_first_token - - # Create output dict with only truthy values for role and content - output = { - k: v - for k, v in {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}.items() - if v - } or None - - span.log(output=output, metrics=metrics) - - -@contextmanager -def _catch_exceptions(): - try: - yield - except Exception as e: - log.warning("swallowing exception in tracing code", exc_info=e) - - -def wrap_anthropic(client): - """Wrap an `Anthropic` object (or AsyncAnthropic) to add tracing. If Braintrust - is not configured, this is a no-op. If this is not an `Anthropic` object, this - function is a no-op. - """ - type_name = getattr(type(client), "__name__") - # We use 'in' because it could be AsyncAnthropicBedrock - if "AsyncAnthropic" in type_name: - return TracedAsyncAnthropic(client) - elif "Anthropic" in type_name: - return TracedAnthropic(client) - else: - # Unexpected. - return client - - -def wrap_anthropic_client(client): - return wrap_anthropic(client) - - -def _apply_anthropic_wrapper(client): - """Apply tracing wrapper to an Anthropic client instance in-place.""" - wrapped = wrap_anthropic(client) - client.messages = wrapped.messages - if hasattr(wrapped, "beta"): - client.beta = wrapped.beta - - -def _apply_async_anthropic_wrapper(client): - """Apply tracing wrapper to an AsyncAnthropic client instance in-place.""" - wrapped = wrap_anthropic(client) - client.messages = wrapped.messages - if hasattr(wrapped, "beta"): - client.beta = wrapped.beta - - -def _anthropic_init_wrapper(wrapped, instance, args, kwargs): - """Wrapper for Anthropic.__init__ that applies tracing after initialization.""" - wrapped(*args, **kwargs) - _apply_anthropic_wrapper(instance) - - -def _async_anthropic_init_wrapper(wrapped, instance, args, kwargs): - """Wrapper for AsyncAnthropic.__init__ that applies tracing after initialization.""" - wrapped(*args, **kwargs) - _apply_async_anthropic_wrapper(instance) - - -def patch_anthropic() -> bool: - """ - Patch Anthropic to add Braintrust tracing globally. - - After calling this, all new Anthropic() and AsyncAnthropic() clients - will automatically have tracing enabled. - - Returns: - True if Anthropic was patched (or already patched), False if Anthropic is not installed. - - Example: - ```python - import braintrust - braintrust.patch_anthropic() - - import anthropic - client = anthropic.Anthropic() - # All calls are now traced! - ``` - """ - try: - import anthropic - - if getattr(anthropic, "__braintrust_wrapped__", False): - return True # Already patched - - wrap_function_wrapper("anthropic", "Anthropic.__init__", _anthropic_init_wrapper) - wrap_function_wrapper("anthropic", "AsyncAnthropic.__init__", _async_anthropic_init_wrapper) - anthropic.__braintrust_wrapped__ = True - return True - - except ImportError: - return False +__all__ = [ + "wrap_anthropic", + "wrap_anthropic_client", +] diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py index 1cefc6d4..c78a10ac 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py @@ -7,9 +7,9 @@ from collections.abc import AsyncGenerator, AsyncIterable from typing import Any +from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens from braintrust.logger import start_span from braintrust.span_types import SpanTypeAttribute -from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens from braintrust.wrappers.claude_agent_sdk._constants import ( ANTHROPIC_MESSAGES_CREATE_SPAN_NAME, CLAUDE_AGENT_TASK_SPAN_NAME, diff --git a/py/src/braintrust/wrappers/test_anthropic.py b/py/src/braintrust/wrappers/test_anthropic.py index 54182597..e7884464 100644 --- a/py/src/braintrust/wrappers/test_anthropic.py +++ b/py/src/braintrust/wrappers/test_anthropic.py @@ -1,642 +1,102 @@ """ -Tests to ensure we reliably wrap the Anthropic API. +Compatibility tests for the Anthropic wrapper import path. """ -import time - -import anthropic -import pytest -from braintrust import logger -from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.anthropic import wrap_anthropic from braintrust.wrappers.test_utils import run_in_subprocess, verify_autoinstrument_script -TEST_ORG_ID = "test-org-123" -PROJECT_NAME = "test-anthropic-app" -MODEL = "claude-3-haiku-20240307" # use the cheapest model since answers dont matter - - -def _get_client(): - return anthropic.Anthropic() - - -def _get_async_client(): - return anthropic.AsyncAnthropic() - - -@pytest.fixture -def memory_logger(): - init_test_logger(PROJECT_NAME) - with logger._internal_with_memory_background_logger() as bgl: - yield bgl - - -@pytest.mark.vcr -def test_anthropic_messages_create_stream_true(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - kws = { - "model": MODEL, - "max_tokens": 300, - "messages": [{"role": "user", "content": "What is 3*4?"}], - "stream": True, - } - - start = time.time() - with client.messages.create(**kws) as out: - msgs = [m for m in out] - end = time.time() - - assert msgs # a very coarse grained check that this works - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["model"] == MODEL - assert span["metadata"]["provider"] == "anthropic" - assert span["metadata"]["max_tokens"] == 300 - assert span["metadata"]["stream"] == True - metrics = span["metrics"] - _assert_metrics_are_valid(metrics, start, end) - assert span["input"] == kws["messages"] - assert span["output"] - assert span["output"]["role"] == "assistant" - assert "12" in span["output"]["content"][0]["text"] - - -@pytest.mark.vcr -def test_anthropic_messages_model_params_inputs(memory_logger): - assert not memory_logger.pop() - client = wrap_anthropic(_get_client()) - - kw = { - "model": MODEL, - "max_tokens": 300, - "system": "just return the number", - "messages": [{"role": "user", "content": "what is 1+1?"}], - "temperature": 0.5, - "top_p": 0.5, - } - - def _with_messages_create(): - return client.messages.create(**kw) - - def _with_messages_stream(): - with client.messages.stream(**kw) as stream: - for msg in stream: - pass - return stream.get_final_message() - - for f in [_with_messages_create, _with_messages_stream]: - msg = f() - assert msg.content[0].text == "2" - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert log["output"]["role"] == "assistant" - assert "2" in log["output"]["content"][0]["text"] - assert log["metadata"]["model"] == MODEL - assert log["metadata"]["max_tokens"] == 300 - assert log["metadata"]["temperature"] == 0.5 - assert log["metadata"]["top_p"] == 0.5 - - -@pytest.mark.vcr -def test_anthropic_messages_system_prompt_inputs(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - system = "Today's date is 2024-03-26. Only return the date" - q = [{"role": "user", "content": "what is tomorrow's date? only return the date"}] - - args = { - "messages": q, - "temperature": 0, - "max_tokens": 300, - "system": system, - "model": MODEL, - } - - def _with_messages_create(): - return client.messages.create(**args) - - def _with_messages_stream(): - with client.messages.stream(**args) as stream: - for msg in stream: - pass - return stream.get_final_message() - - for f in [_with_messages_create, _with_messages_stream]: - msg = f() - assert "2024-03-27" in msg.content[0].text - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - inputs = log["input"] - assert len(inputs) == 2 - inputs_by_role = {m["role"]: m["content"] for m in inputs} - assert inputs_by_role["system"] == system - assert inputs_by_role["user"] == q[0]["content"] - - -@pytest.mark.vcr -@pytest.mark.asyncio -async def test_anthropic_messages_create_async(memory_logger): - assert not memory_logger.pop() - - params = { - "model": MODEL, - "max_tokens": 100, - "messages": [{"role": "user", "content": "what is 6+1?, just return the number"}], - } - - client = wrap_anthropic(anthropic.AsyncAnthropic()) - msg = await client.messages.create(**params) - assert "7" in msg.content[0].text - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["model"] == MODEL - assert span["metadata"]["max_tokens"] == 100 - assert span["input"] == params["messages"] - assert span["output"]["role"] == "assistant" - assert "7" in span["output"]["content"][0]["text"] - - -@pytest.mark.vcr -@pytest.mark.asyncio -async def test_anthropic_messages_create_async_stream_true(memory_logger): - assert not memory_logger.pop() - - params = { - "model": MODEL, - "max_tokens": 100, - "messages": [{"role": "user", "content": "what is 6+1?, just return the number"}], - "stream": True, - } - - client = wrap_anthropic(anthropic.AsyncAnthropic()) - stream = await client.messages.create(**params) - async for event in stream: - pass - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["model"] == MODEL - assert span["metadata"]["max_tokens"] == 100 - assert span["input"] == params["messages"] - assert span["output"]["role"] == "assistant" - assert "7" in span["output"]["content"][0]["text"] - - -@pytest.mark.vcr -@pytest.mark.asyncio -async def test_anthropic_messages_streaming_async(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_async_client()) - msgs_in = [{"role": "user", "content": "what is 1+1?, just return the number"}] - - start = time.time() - msg_out = None - - async with client.messages.stream(max_tokens=1024, messages=msgs_in, model=MODEL) as stream: - async for event in stream: - pass - msg_out = await stream.get_final_message() - assert msg_out.content[0].text == "2" - usage = msg_out.usage - end = time.time() - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert "user" in str(log["input"]) - assert "1+1" in str(log["input"]) - assert "2" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_attributes"]["type"] == "llm" - assert log["metadata"]["model"] == MODEL - assert log["metadata"]["max_tokens"] == 1024 - _assert_metrics_are_valid(log["metrics"], start, end) - metrics = log["metrics"] - assert metrics["prompt_tokens"] == usage.input_tokens - assert metrics["completion_tokens"] == usage.output_tokens - assert metrics["tokens"] == usage.input_tokens + usage.output_tokens - assert metrics["prompt_cached_tokens"] == usage.cache_read_input_tokens - assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens - assert log["metadata"]["model"] == MODEL - assert log["metadata"]["max_tokens"] == 1024 - - -@pytest.mark.vcr -def test_anthropic_client_error(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - - fake_model = "there-is-no-such-model" - msg_in = {"role": "user", "content": "who are you?"} - - try: - client.messages.create(model=fake_model, max_tokens=999, messages=[msg_in]) - except Exception: - pass - else: - raise Exception("should have raised an exception") - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert log["project_id"] == PROJECT_NAME - assert "404" in log["error"] - - -@pytest.mark.vcr -def test_anthropic_messages_stream_errors(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - msg_in = {"role": "user", "content": "what is 2+2? (just the number)"} - - try: - with client.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: - raise Exception("fake-error") - except Exception: - pass - else: - raise Exception("should have raised an exception") - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert "Exception: fake-error" in span["error"] - assert span["metrics"]["end"] > 0 - - -@pytest.mark.vcr -def test_anthropic_messages_streaming_sync(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - msg_in = {"role": "user", "content": "what is 2+2? (just the number)"} - - start = time.time() - with client.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: - msgs_out = [m for m in stream] - end = time.time() - msg_out = stream.get_final_message() - usage = msg_out.usage - # crudely check that the stream is valid - assert len(msgs_out) > 3 - assert 1 <= len([m for m in msgs_out if m.type == "text"]) - assert msgs_out[0].type == "message_start" - assert msgs_out[-1].type == "message_stop" - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert "user" in str(log["input"]) - assert "2+2" in str(log["input"]) - assert "4" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_attributes"]["type"] == "llm" - _assert_metrics_are_valid(log["metrics"], start, end) - assert log["metrics"]["prompt_tokens"] == usage.input_tokens - assert log["metrics"]["completion_tokens"] == usage.output_tokens - assert log["metrics"]["tokens"] == usage.input_tokens + usage.output_tokens - assert log["metrics"]["prompt_cached_tokens"] == usage.cache_read_input_tokens - assert log["metrics"]["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens - - -@pytest.mark.vcr -def test_anthropic_messages_sync(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - - msg_in = {"role": "user", "content": "what's 2+2?"} - - start = time.time() - msg = client.messages.create(model=MODEL, max_tokens=300, messages=[msg_in]) - end = time.time() - - text = msg.content[0].text - assert text - - # verify we generated the right spans. - logs = memory_logger.pop() - - assert len(logs) == 1 - log = logs[0] - assert "2+2" in str(log["input"]) - assert "4" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_id"] - assert log["root_span_id"] - attrs = log["span_attributes"] - assert attrs["type"] == "llm" - assert "anthropic" in attrs["name"] - metrics = log["metrics"] - _assert_metrics_are_valid(metrics, start, end) - assert log["metadata"]["model"] == MODEL - - -def _assert_metrics_are_valid(metrics, start, end): - assert metrics["tokens"] > 0 - assert metrics["prompt_tokens"] > 0 - assert metrics["completion_tokens"] > 0 - assert "time_to_first_token" in metrics - assert metrics["time_to_first_token"] >= 0 - if start and end: - assert start <= metrics["start"] <= metrics["end"] <= end - else: - assert metrics["start"] <= metrics["end"] - - -@pytest.mark.vcr -def test_anthropic_beta_messages_sync(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - msg_in = {"role": "user", "content": "what's 3+3?"} - - start = time.time() - msg = client.beta.messages.create(model=MODEL, max_tokens=300, messages=[msg_in]) - end = time.time() - - text = msg.content[0].text - assert text - assert "6" in text - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert "3+3" in str(log["input"]) - assert "6" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_id"] - assert log["root_span_id"] - attrs = log["span_attributes"] - assert attrs["type"] == "llm" - assert "anthropic" in attrs["name"] - metrics = log["metrics"] - _assert_metrics_are_valid(metrics, start, end) - assert log["metadata"]["model"] == MODEL - - -@pytest.mark.vcr -def test_anthropic_beta_messages_stream_sync(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_client()) - msg_in = {"role": "user", "content": "what is 5+5? (just the number)"} - - start = time.time() - with client.beta.messages.stream(model=MODEL, max_tokens=300, messages=[msg_in]) as stream: - msgs_out = [m for m in stream] - end = time.time() - msg_out = stream.get_final_message() - usage = msg_out.usage - - assert len(msgs_out) > 3 - assert msgs_out[0].type == "message_start" - assert msgs_out[-1].type == "message_stop" - assert "10" in msg_out.content[0].text - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert "user" in str(log["input"]) - assert "5+5" in str(log["input"]) - assert "10" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_attributes"]["type"] == "llm" - _assert_metrics_are_valid(log["metrics"], start, end) - assert log["metrics"]["prompt_tokens"] == usage.input_tokens - assert log["metrics"]["completion_tokens"] == usage.output_tokens - assert log["metrics"]["tokens"] == usage.input_tokens + usage.output_tokens - - -@pytest.mark.vcr -@pytest.mark.asyncio -async def test_anthropic_beta_messages_create_async(memory_logger): - assert not memory_logger.pop() - - params = { - "model": MODEL, - "max_tokens": 100, - "messages": [{"role": "user", "content": "what is 8+2?, just return the number"}], - } - - client = wrap_anthropic(anthropic.AsyncAnthropic()) - msg = await client.beta.messages.create(**params) - assert "10" in msg.content[0].text - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["model"] == MODEL - assert span["metadata"]["max_tokens"] == 100 - assert span["input"] == params["messages"] - assert span["output"]["role"] == "assistant" - assert "10" in span["output"]["content"][0]["text"] - - -@pytest.mark.vcr( - match_on=["method", "scheme", "host", "port", "path", "body"] -) # exclude query - varies by SDK version -@pytest.mark.asyncio -async def test_anthropic_beta_messages_streaming_async(memory_logger): - assert not memory_logger.pop() - - client = wrap_anthropic(_get_async_client()) - msgs_in = [{"role": "user", "content": "what is 9+1?, just return the number"}] - - start = time.time() - msg_out = None - - async with client.beta.messages.stream(max_tokens=1024, messages=msgs_in, model=MODEL) as stream: - async for event in stream: - pass - msg_out = await stream.get_final_message() - assert "10" in msg_out.content[0].text - usage = msg_out.usage - end = time.time() - - logs = memory_logger.pop() - assert len(logs) == 1 - log = logs[0] - assert "user" in str(log["input"]) - assert "9+1" in str(log["input"]) - assert "10" in str(log["output"]) - assert log["project_id"] == PROJECT_NAME - assert log["span_attributes"]["type"] == "llm" - assert log["metadata"]["model"] == MODEL - assert log["metadata"]["max_tokens"] == 1024 - _assert_metrics_are_valid(log["metrics"], start, end) - metrics = log["metrics"] - assert metrics["prompt_tokens"] == usage.input_tokens - assert metrics["completion_tokens"] == usage.output_tokens - assert metrics["tokens"] == usage.input_tokens + usage.output_tokens - - -class TestPatchAnthropic: - """Tests for patch_anthropic() / unpatch_anthropic().""" - - def test_patch_anthropic_sets_wrapped_flag(self): - """patch_anthropic() should set __braintrust_wrapped__ on anthropic module.""" +class TestAnthropicWrapperCompat: + def test_anthropic_wrapper_compat_exports(self): result = run_in_subprocess(""" - from braintrust.wrappers.anthropic import patch_anthropic - import anthropic + from braintrust.wrappers.anthropic import wrap_anthropic as compat_wrap + from braintrust.integrations.anthropic import wrap_anthropic as new_wrap + from braintrust.integrations.anthropic import wrap_anthropic_client - assert not hasattr(anthropic, "__braintrust_wrapped__") - patch_anthropic() - assert hasattr(anthropic, "__braintrust_wrapped__") + assert compat_wrap is new_wrap + assert callable(wrap_anthropic_client) print("SUCCESS") """) assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout - def test_patch_anthropic_wraps_new_clients(self): - """After patch_anthropic(), new Anthropic() clients should be wrapped.""" + def test_anthropic_integration_setup_wraps_supported_clients(self): result = run_in_subprocess(""" - from braintrust.wrappers.anthropic import patch_anthropic - patch_anthropic() - + from braintrust.integrations.anthropic import AnthropicIntegration import anthropic - client = anthropic.Anthropic(api_key="test-key") - # Check that messages is wrapped - messages_type = type(client.messages).__name__ - print(f"messages_type={messages_type}") + original_sync_module = type(anthropic.Anthropic(api_key="test-key").messages).__module__ + original_async_module = type(anthropic.AsyncAnthropic(api_key="test-key").messages).__module__ + AnthropicIntegration.setup() + patched_sync = anthropic.Anthropic(api_key="test-key") + patched_async = anthropic.AsyncAnthropic(api_key="test-key") + + assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" + assert type(patched_async.messages).__module__ == "braintrust.integrations.anthropic.tracing" + assert type(patched_sync.messages).__module__ != original_sync_module + assert type(patched_async.messages).__module__ != original_async_module print("SUCCESS") """) assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout - def test_patch_anthropic_idempotent(self): - """Multiple patch_anthropic() calls should be safe.""" + def test_anthropic_integration_setup_is_idempotent(self): result = run_in_subprocess(""" - from braintrust.wrappers.anthropic import patch_anthropic + import inspect + from braintrust.integrations.anthropic import AnthropicIntegration import anthropic - patch_anthropic() - first_class = anthropic.Anthropic + AnthropicIntegration.setup() + first_sync_init = inspect.getattr_static(anthropic.Anthropic, "__init__") + first_async_init = inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") - patch_anthropic() # Second call - second_class = anthropic.Anthropic + AnthropicIntegration.setup() - assert first_class is second_class + assert inspect.getattr_static(anthropic.Anthropic, "__init__") is first_sync_init + assert inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") is first_async_init print("SUCCESS") """) assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout - def test_patch_anthropic_creates_spans(self): - """patch_anthropic() should create spans when making API calls.""" + def test_anthropic_integration_setup_can_disable_specific_patchers(self): result = run_in_subprocess(""" - from braintrust.wrappers.anthropic import patch_anthropic - from braintrust.test_helpers import init_test_logger - from braintrust import logger - - # Set up memory logger - init_test_logger("test-auto") - with logger._internal_with_memory_background_logger() as memory_logger: - patch_anthropic() - - import anthropic - client = anthropic.Anthropic() + from braintrust.integrations.anthropic import AnthropicIntegration + import anthropic - # Make a call within a span context - import braintrust - with braintrust.start_span(name="test") as span: - try: - # This will fail without API key, but span should still be created - client.messages.create( - model="claude-3-5-haiku-latest", - max_tokens=100, - messages=[{"role": "user", "content": "hi"}], - ) - except Exception: - pass # Expected without API key + AnthropicIntegration.setup(disabled_patchers={"anthropic.init.async"}) + patched_sync = anthropic.Anthropic(api_key="test-key") + unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") - # Check that spans were logged - spans = memory_logger.pop() - # Should have at least the parent span - assert len(spans) >= 1, f"Expected spans, got {spans}" - print("SUCCESS") + assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" + assert type(unpatched_async.messages).__module__.startswith("anthropic.") + print("SUCCESS") """) assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout -class TestPatchAnthropicSpans: - """VCR-based tests verifying that patch_anthropic() produces spans.""" - - @pytest.mark.vcr - def test_patch_anthropic_creates_spans(self, memory_logger): - """patch_anthropic() should create spans when making API calls.""" - from braintrust.wrappers.anthropic import patch_anthropic - - assert not memory_logger.pop() - - patch_anthropic() - client = anthropic.Anthropic() - response = client.messages.create( - model="claude-3-5-haiku-latest", - max_tokens=100, - messages=[{"role": "user", "content": "Say hi"}], - ) - assert response.content[0].text - - # Verify span was created - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["provider"] == "anthropic" - assert "claude" in span["metadata"]["model"] - assert span["input"] - - -class TestPatchAnthropicAsyncSpans: - """VCR-based tests verifying that patch_anthropic() produces spans for async clients.""" - - @pytest.mark.vcr - @pytest.mark.asyncio - async def test_patch_anthropic_async_creates_spans(self, memory_logger): - """patch_anthropic() should create spans for async API calls.""" - from braintrust.wrappers.anthropic import patch_anthropic - - assert not memory_logger.pop() - - patch_anthropic() - client = anthropic.AsyncAnthropic() - response = await client.messages.create( - model="claude-3-5-haiku-latest", - max_tokens=100, - messages=[{"role": "user", "content": "Say hi async"}], - ) - assert response.content[0].text - - # Verify span was created - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - assert span["metadata"]["provider"] == "anthropic" - assert "claude" in span["metadata"]["model"] - assert span["input"] - - class TestAutoInstrumentAnthropic: """Tests for auto_instrument() with Anthropic.""" def test_auto_instrument_anthropic(self): """Test auto_instrument patches Anthropic, creates spans, and uninstrument works.""" verify_autoinstrument_script("test_auto_anthropic.py") + + def test_auto_instrument_anthropic_patch_config(self): + verify_autoinstrument_script("test_auto_anthropic_patch_config.py") + + def test_auto_instrument_rejects_non_bool_option_for_openai(self): + result = run_in_subprocess(""" + from braintrust.auto import auto_instrument + from braintrust.integrations import IntegrationPatchConfig + + try: + auto_instrument(openai=IntegrationPatchConfig()) + except TypeError as exc: + assert "must be a bool" in str(exc) + print("SUCCESS") + else: + raise AssertionError("Expected TypeError") + """) + assert result.returncode == 0, f"Failed: {result.stderr}" + assert "SUCCESS" in result.stdout diff --git a/py/src/braintrust/wrappers/test_utils.py b/py/src/braintrust/wrappers/test_utils.py index 80d9d661..f91a3502 100644 --- a/py/src/braintrust/wrappers/test_utils.py +++ b/py/src/braintrust/wrappers/test_utils.py @@ -13,7 +13,7 @@ # Source directory paths (resolved to handle installed vs source locations) _SOURCE_DIR = Path(__file__).resolve().parent -AUTO_TEST_SCRIPTS_DIR = _SOURCE_DIR / "auto_test_scripts" +AUTO_TEST_SCRIPTS_DIR = _SOURCE_DIR.parent / "integrations" / "auto_test_scripts" # Cassettes dir can be overridden via env var for subprocess tests CASSETTES_DIR = Path(os.environ.get("BRAINTRUST_CASSETTES_DIR", _SOURCE_DIR / "cassettes")) @@ -34,7 +34,7 @@ def run_in_subprocess(code: str, timeout: int = 30, env: dict[str, str] | None = def verify_autoinstrument_script(script_name: str, timeout: int = 30) -> subprocess.CompletedProcess: - """Run a test script from the auto_test_scripts directory. + """Run a test script from the integrations auto_test_scripts directory. Raises AssertionError if the script exits with non-zero code. """