From 65a3be9225b5744bbaac187dfc0b438b87efb406 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Mon, 23 Mar 2026 17:22:17 -0700 Subject: [PATCH 1/4] ref(adk): migrate ADK wrapper to integrations API Convert the Google ADK (Agent Development Kit) instrumentation from the legacy wrappers pattern to the new integrations API introduced in #118. - Extract tracing helpers into integrations/adk/tracing.py - Create ADK patchers (agent, runner, flow, mcp_tool) in patchers.py - Add AdkIntegration class with proper registration - Move tests and cassettes from wrappers/adk/ to integrations/adk/ - Move auto_instrument test script to integrations/auto_test_scripts/ - Slim down wrappers/adk/__init__.py to delegate to integrations - Update noxfile.py with ADK integration test sessions - Update integrations base and registry for ADK support --- .agents/skills/sdk-integrations/SKILL.md | 235 +++--- .../skills/sdk-wrapper-migrations/SKILL.md | 178 +++++ py/noxfile.py | 13 +- py/src/braintrust/auto.py | 22 +- py/src/braintrust/integrations/__init__.py | 3 +- .../braintrust/integrations/adk/__init__.py | 63 ++ ...st_adk_agent_metadata_with_attachment.yaml | 0 ...adk_binary_data_attachment_conversion.yaml | 0 .../test_adk_braintrust_integration.yaml | 0 .../cassettes/test_adk_captures_metrics.yaml | 0 .../test_adk_complex_nested_schema.yaml | 0 .../test_adk_input_schema_serialization.yaml | 0 .../test_adk_max_tokens_captures_content.yaml | 0 .../test_adk_response_json_schema_dict.yaml | 0 .../test_adk_structured_output_pydantic.yaml | 0 .../integrations/adk/integration.py | 30 + .../braintrust/integrations/adk/patchers.py | 167 +++++ .../adk/test_adk.py | 55 +- .../adk/test_adk_mcp_tool.py | 56 +- py/src/braintrust/integrations/adk/tracing.py | 569 +++++++++++++++ .../auto_test_scripts/test_auto_adk.py | 16 + py/src/braintrust/integrations/base.py | 108 ++- py/src/braintrust/wrappers/adk/__init__.py | 684 +----------------- .../braintrust/wrappers/adk/test_auto_adk.py | 23 - 24 files changed, 1318 insertions(+), 904 deletions(-) create mode 100644 .agents/skills/sdk-wrapper-migrations/SKILL.md create mode 100644 py/src/braintrust/integrations/adk/__init__.py rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_agent_metadata_with_attachment.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_binary_data_attachment_conversion.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_braintrust_integration.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_captures_metrics.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_complex_nested_schema.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_input_schema_serialization.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_max_tokens_captures_content.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_response_json_schema_dict.yaml (100%) rename py/src/braintrust/{wrappers => integrations/adk}/cassettes/test_adk_structured_output_pydantic.yaml (100%) create mode 100644 py/src/braintrust/integrations/adk/integration.py create mode 100644 py/src/braintrust/integrations/adk/patchers.py rename py/src/braintrust/{wrappers => integrations}/adk/test_adk.py (96%) rename py/src/braintrust/{wrappers => integrations}/adk/test_adk_mcp_tool.py (85%) create mode 100644 py/src/braintrust/integrations/adk/tracing.py create mode 100644 py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py delete mode 100644 py/src/braintrust/wrappers/adk/test_auto_adk.py diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md index b7b3aaa4..7e2fa498 100644 --- a/.agents/skills/sdk-integrations/SKILL.md +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -1,204 +1,185 @@ --- name: sdk-integrations -description: Create or update a Braintrust Python SDK integration using the integrations API. Use when asked to add an integration, update an existing integration, add or update patchers, update auto_instrument, add integration tests, or work in py/src/braintrust/integrations/. +description: Create or update Braintrust Python SDK integrations built on the integrations API. Use for work in `py/src/braintrust/integrations/`, including new providers, patchers, tracing, `auto_instrument()` updates, integration exports, and integration tests. --- # SDK Integrations -SDK integrations define how Braintrust discovers a provider, patches it safely, and keeps provider-specific tracing local to that integration. Read the existing integration closest to your task before writing a new one. If there is no closer example, `py/src/braintrust/integrations/anthropic/` is a useful reference implementation. +Use this skill for integrations API work under `py/src/braintrust/integrations/`. -## Workflow - -1. Read the shared integration primitives and the closest provider example. -2. Choose the task shape: new provider, existing provider update, or `auto_instrument()` update. -3. Implement the smallest integration, patcher, tracing, and export changes needed. -4. Add or update VCR-backed integration tests and only re-record cassettes when behavior changed intentionally. -5. Run the narrowest provider session first, then expand to shared validation only if the change touched shared code. +Start from the nearest existing provider instead of designing from scratch: -## Commands +- ADK (`py/src/braintrust/integrations/adk/`) is the best reference for direct method patching, `target_module`, `CompositeFunctionWrapperPatcher`, and public `wrap_*()` helpers. +- Anthropic (`py/src/braintrust/integrations/anthropic/`) is the best reference for constructor patching with `FunctionWrapperPatcher`. -```bash -cd py && nox -s "test_(latest)" -cd py && nox -s "test_(latest)" -- -k "test_name" -cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" -cd py && make test-core -cd py && make lint -``` +## Workflow -## Creating or Updating an Integration +1. Read the shared primitives and the nearest provider example. +2. Decide whether the task is a new provider, an existing provider update, or an `auto_instrument()` change. +3. Change only the affected integration, patchers, tracing, exports, and tests. +4. Update tests and cassettes only where behavior changed intentionally. +5. Run the narrowest provider session first, then expand only if shared code changed. -### 1. Read the nearest existing implementation +## Read First -Always inspect these first: +Always read: - `py/src/braintrust/integrations/base.py` -- `py/src/braintrust/integrations/runtime.py` - `py/src/braintrust/integrations/versioning.py` -- `py/src/braintrust/integrations/config.py` - -Relevant example implementation: - -- `py/src/braintrust/integrations/anthropic/` - -Read these additional files only when the task needs them: - -- changing `auto_instrument()`: `py/src/braintrust/auto.py` and `py/src/braintrust/auto_test_scripts/test_auto_anthropic_patch_config.py` -- adding or updating VCR tests: `py/src/braintrust/conftest.py` and `py/src/braintrust/integrations/anthropic/test_anthropic.py` - -Then choose the path that matches the task: -- new provider: create `py/src/braintrust/integrations//` -- existing provider: read the provider package first and change only the affected patchers, tracing, tests, or exports -- `auto_instrument()` only: keep the integration package unchanged unless the option shape or patcher surface also changed +Read when relevant: -### 2. Create or extend the integration module +- `py/src/braintrust/auto.py` for `auto_instrument()` work +- `py/src/braintrust/conftest.py` for VCR behavior +- `py/src/braintrust/integrations/adk/test_adk.py` for integration test patterns +- `py/src/braintrust/integrations/auto_test_scripts/` for subprocess auto-instrument tests -For a new provider, create a package under `py/src/braintrust/integrations//`. +## Package Layout -For an existing provider, keep the module layout unless the current structure is actively causing problems. +Create new providers under `py/src/braintrust/integrations//`. Keep the existing layout for provider updates unless the current structure is the problem. Typical files: -- `__init__.py`: public exports for the integration type and any public helpers -- `integration.py`: the `BaseIntegration` subclass, patcher registration, and high-level orchestration -- `patchers.py`: one patcher per patch target, with version gating and existence checks close to the patch -- `tracing.py`: provider-specific span creation, metadata extraction, stream handling, and output normalization -- `test_.py`: integration tests for `wrap(...)`, `setup()`, sync/async behavior, streaming, and error handling -- `cassettes/`: recorded provider traffic for VCR-backed integration tests when the provider uses HTTP +- `__init__.py`: export the integration class, `setup_()`, and public `wrap_*()` helpers +- `integration.py`: define the `BaseIntegration` subclass and register patchers +- `patchers.py`: define patchers and `wrap_*()` helpers +- `tracing.py`: keep provider-specific tracing, stream handling, and normalization +- `test_.py`: keep provider behavior tests next to the integration +- `cassettes/`: keep VCR recordings next to the integration tests when the provider uses HTTP -### 3. Define the integration class +## Integration Rules -Implement a `BaseIntegration` subclass in `integration.py`. - -Set: +Keep `integration.py` thin. Set: - `name` - `import_names` -- `min_version` and `max_version` only when needed - `patchers` +- `min_version` and `max_version` only when needed -Keep the class focused on orchestration. Provider-specific tracing logic should stay in `tracing.py`. +Keep provider behavior in the provider package, not in shared integration code. Put span creation, metadata extraction, stream aggregation, error logging, and output normalization in `tracing.py`. -### 4. Add one patcher per coherent patch target +Preserve provider behavior. Do not let tracing-only code break the provider call. -Put patchers in `patchers.py`. +## Patcher Rules -Use `FunctionWrapperPatcher` when patching a single import path with `wrapt.wrap_function_wrapper`. Good examples: +Create one patcher per coherent patch target. If targets are unrelated, split them. -- constructor patchers like `ProviderClient.__init__` -- single API surfaces like `client.responses.create` -- one sync and one async constructor patcher instead of one patcher doing both +Use `FunctionWrapperPatcher` for one import path or one constructor/method surface, for example: -Keep patchers narrow. If you need to patch multiple unrelated targets, create multiple patchers rather than one large patcher. +- `ProviderClient.__init__` +- `client.responses.create` -Patchers are responsible for: +Use `CompositeFunctionWrapperPatcher` when several closely related targets should appear as one patcher, for example: -- stable patcher ids via `name` -- optional version gating -- existence checks -- idempotence through the base patcher marker +- sync and async variants of the same method +- the same function patched across multiple modules -### 5. Keep tracing provider-local +Set `target_module` when the patch target lives outside the module named by `import_names`, especially for optional or deep submodules. Failed `target_module` imports should cause the patcher to skip cleanly through `applies()`. -Put span creation, metadata extraction, stream aggregation, error logging, and output normalization in `tracing.py`. +Expose manual wrapping helpers through `wrap_target()`: -This layer should: +```python +def wrap_agent(Agent: Any) -> Any: + return AgentRunAsyncPatcher.wrap_target(Agent) +``` -- preserve provider behavior -- support sync, async, and streaming paths as needed -- avoid raising from tracing-only code when that would break the provider call +Use lower `priority` values only when ordering matters, such as context propagation before tracing. -If the provider has complex streaming internals, keep that logic local instead of forcing it into shared abstractions. +Patchers must provide: -### 6. Wire public exports +- stable `name` values +- version gating only when needed +- existence checks +- idempotence through the base patcher marker -Update public exports only as needed: +Use `IntegrationPatchConfig` only when users need patcher-level selection. Let `BaseIntegration.resolve_patchers()` reject unknown patcher ids instead of silently ignoring them. -- `py/src/braintrust/integrations/__init__.py` -- `py/src/braintrust/__init__.py` +## Patching Patterns -### 7. Update auto_instrument only if this integration should be auto-patched +Use constructor patching when the goal is to instrument future clients created by the provider SDK. Patch the constructor, then attach traced surfaces after the real constructor runs. -If the provider belongs in `braintrust.auto.auto_instrument()`, add a branch in `py/src/braintrust/auto.py`. +Use direct method patching with `target_module` when the provider exposes a flatter API and there is no useful constructor patch point. -Match the current pattern: +Keep public `wrap_*()` helpers in `patchers.py` and export them from the integration package. -- plain `bool` options for simple on/off integrations -- `IntegrationPatchConfig` only when users need patcher-level selection +## Versioning -## Tests +Prefer feature detection first and version checks second. -Keep integration tests with the integration package. +Use: -Provider behavior tests should use `@pytest.mark.vcr` whenever the provider uses network calls. Avoid mocks and fakes. +- `detect_module_version(...)` +- `version_in_range(...)` +- `version_matches_spec(...)` -Cover: +Do not add `packaging` just for integration routing. -- direct `wrap(...)` behavior -- `setup()` patching new clients -- sync behavior -- async behavior -- streaming behavior -- idempotence -- failure/error logging -- patcher selection if using `IntegrationPatchConfig` +## `auto_instrument()` -Preferred locations: +Update `py/src/braintrust/auto.py` only if the integration should be auto-patched. -- provider behavior tests: `py/src/braintrust/integrations//test_.py` -- version helper tests: `py/src/braintrust/integrations/test_versioning.py` -- auto-instrument subprocess tests: `py/src/braintrust/auto_test_scripts/` +Match the existing option shape: -If the provider uses VCR, keep cassettes next to the integration test file under `py/src/braintrust/integrations//cassettes/`. +- use plain `bool` for simple on/off integrations that do not use the integrations API +- use `InstrumentOption` for integrations API providers that support `IntegrationPatchConfig` -Only re-record cassettes when the behavior change is intentional. +For integrations API providers, use `_normalize_instrument_option()` and `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: -Use mocks or fakes only for cases that are hard to drive through recorded provider traffic, such as narrowly scoped error injection, local version-routing logic, or patcher existence checks. +```python +enabled, config = _normalize_instrument_option("provider", provider) +if enabled: + results["provider"] = _instrument_integration(ProviderIntegration, patch_config=config) +``` -## Patterns +Add the integration import near the other integration imports in `auto.py`. -### Constructor patching +## Tests -If instrumenting future clients created by the SDK is the goal, patch constructors and attach traced surfaces after the real constructor runs. Anthropic is an example of this pattern. +Keep integration tests in the provider package. -### Patcher selection +Use `@pytest.mark.vcr` for real provider network behavior. Prefer recorded provider traffic over mocks or fakes. Use mocks or fakes only for cases that are hard to drive through recordings, such as: -Use `IntegrationPatchConfig` only when users benefit from enabling or disabling specific patchers. Validate unknown patcher ids through `BaseIntegration.resolve_patchers()` instead of silently ignoring them. +- narrow error injection +- local version-routing logic +- patcher existence checks -### Versioning +Cover the surfaces that changed: -Prefer feature detection first and version checks second. +- direct `wrap(...)` behavior +- `setup()` patching new clients +- sync behavior +- async behavior +- streaming behavior +- idempotence +- failure and error logging +- patcher selection when using `IntegrationPatchConfig` -Use: +Keep VCR cassettes in `py/src/braintrust/integrations//cassettes/`. Re-record them only for intentional behavior changes. -- `detect_module_version(...)` -- `version_in_range(...)` -- `version_matches_spec(...)` +## Commands -Do not add `packaging` just for integration routing. +```bash +cd py && nox -s "test_(latest)" +cd py && nox -s "test_(latest)" -- -k "test_name" +cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" +cd py && make test-core +cd py && make lint +``` ## Validation - Run the narrowest provider session first. -- Run `cd py && make test-core` if you changed shared integration code. +- Run `cd py && make test-core` if shared integration code changed. - Run `cd py && make lint` before handing off broader integration changes. -- If you changed `auto_instrument()`, run the relevant subprocess auto-instrument tests. - -## Done When - -- the provider package contains only the integration, patcher, tracing, export, and test changes required by the task -- provider behavior tests use VCR unless recorded traffic cannot cover the behavior -- cassette changes are present only when provider behavior changed intentionally -- the narrowest affected provider session passes -- `cd py && make test-core` has been run if shared integration code changed -- `cd py && make lint` has been run before handoff +- Run the relevant auto-instrument subprocess tests if `auto_instrument()` changed. -## Common Pitfalls +## Pitfalls -- Leaving provider behavior in `BaseIntegration` instead of the provider package. -- Combining multiple unrelated patch targets into one patcher. +- Moving provider-specific behavior into shared integration code. +- Combining unrelated targets into one patcher. - Forgetting async or streaming coverage. -- Defaulting to mocks or fakes when the provider flow can be covered with VCR. -- Moving tests but not moving their cassettes. - Adding patcher selection without tests for enabled and disabled cases. -- Editing `auto_instrument()` in a way that implies a registry exists when it does not. +- Re-recording cassettes when behavior did not intentionally change. +- Using `_normalize_bool_option()` for an integrations API provider. +- Adding a custom `_instrument_*` helper where `_instrument_integration()` already fits. +- Forgetting `target_module` for deep or optional submodule patch targets. diff --git a/.agents/skills/sdk-wrapper-migrations/SKILL.md b/.agents/skills/sdk-wrapper-migrations/SKILL.md new file mode 100644 index 00000000..f7c1852a --- /dev/null +++ b/.agents/skills/sdk-wrapper-migrations/SKILL.md @@ -0,0 +1,178 @@ +--- +name: sdk-wrapper-migrations +description: Migrate Braintrust Python SDK legacy wrapper implementations to the integrations API. Use when moving a provider from `py/src/braintrust/wrappers/` into `py/src/braintrust/integrations/`, preserving backward compatibility while relocating tracing, patchers, tests, cassettes, auto-instrument hooks, and test sessions. +--- + +# SDK Wrapper Migrations + +Use this skill when a provider already exists under `py/src/braintrust/wrappers/` and needs to be migrated to the integrations API. + +Use current repo examples, not old commit history: + +- `py/src/braintrust/integrations/adk/` for full integration package structure, test placement, auto-instrument coverage, and wrapper delegation +- `py/src/braintrust/integrations/anthropic/` for constructor patching and a minimal compatibility wrapper + +The target end state is: + +- provider logic lives in `py/src/braintrust/integrations//` +- tests and cassettes live with the integration +- `auto_instrument()` uses the integration +- the legacy wrapper becomes a thin compatibility layer + +## Read First + +Always read: + +- the existing legacy wrapper under `py/src/braintrust/wrappers//` +- `py/src/braintrust/integrations/anthropic/__init__.py` +- `py/src/braintrust/integrations/anthropic/integration.py` +- `py/src/braintrust/integrations/anthropic/patchers.py` +- `py/src/braintrust/integrations/anthropic/tracing.py` +- `py/src/braintrust/integrations/base.py` +- `py/src/braintrust/auto.py` +- `py/noxfile.py` + +Read when relevant: + +- `py/src/braintrust/integrations/auto_test_scripts/` +- the provider's existing wrapper tests and cassettes + +## Workflow + +1. Inventory the wrapper's public API, patch targets, tests, and cassettes. +2. Create an integration package that preserves the wrapper's behavior and public helper surface. +3. Move provider-specific tracing and patching into the integration package. +4. Move tests, cassettes, and auto-instrument subprocess coverage next to the integration. +5. Wire the integration into exports, `auto.py`, and `py/noxfile.py`. +6. Replace the wrapper with a thin re-export layer. +7. Run the narrowest provider session first, then expand if shared code changed. + +## Migration Checklist + +### 1. Preserve the public surface + +Before moving code, list the public names exposed by the wrapper: + +- setup functions +- `wrap_*()` helpers +- deprecated aliases that still need to work +- `__all__` + +The integration package should own that public surface after the migration. The wrapper should only delegate to it. + +### 2. Create the integration package + +Create `py/src/braintrust/integrations//` with the same split used by ADK: + +- `__init__.py`: public API, setup entry point, deprecated aliases if needed +- `integration.py`: `BaseIntegration` subclass and patcher registration +- `patchers.py`: one patcher per coherent patch target, plus public `wrap_*()` helpers +- `tracing.py`: provider-specific tracing, stream handling, normalization, and helper code +- `test_.py`: provider behavior tests +- `cassettes/`: VCR recordings when the provider uses HTTP + +Keep provider-specific behavior out of shared modules unless the provider truly needs a shared change. + +### 3. Move tracing and patching out of the wrapper + +Extract wrapper internals into: + +- `tracing.py` for spans, metadata extraction, stream aggregation, and output normalization +- `patchers.py` for patcher classes and `wrap_*()` helpers +- `integration.py` for the orchestration layer only + +Prefer one patcher per coherent patch target. Use composite patchers only when several related targets should be user-visible as one patcher. + +### 4. Preserve setup behavior + +The new integration package should preserve the wrapper's setup semantics: + +- keep the same setup function names where possible +- keep deprecated aliases that users may still import +- keep logger initialization or other setup-time side effects aligned with prior behavior + +The integration package is the new source of truth. Do not leave setup logic duplicated in the wrapper. + +### 5. Move tests and cassettes + +Move provider tests from `py/src/braintrust/wrappers/` into the integration package. + +Move or rename: + +- provider behavior tests to `py/src/braintrust/integrations//` +- cassettes to `py/src/braintrust/integrations//cassettes/` +- auto-instrument subprocess tests to `py/src/braintrust/integrations/auto_test_scripts/` + +Update imports and cassette paths during the move. Preserve coverage for: + +- direct `wrap_*()` behavior +- setup-time patching +- sync paths +- async paths +- streaming paths +- idempotence +- failure and logging behavior + +### 6. Wire repo-level integration points + +Update the minimum shared surfaces required by the migration: + +- `py/src/braintrust/integrations/__init__.py` +- `py/src/braintrust/auto.py` if the provider participates in `auto_instrument()` +- `py/noxfile.py` so provider sessions run against the integration tests + +Only change shared integration primitives when the provider actually needs it. + +### 7. Reduce the wrapper to compatibility imports + +After the integration package is working, replace the legacy wrapper implementation with a thin `__init__.py` that re-exports the migrated surface from `braintrust.integrations.`. + +Keep `__all__` aligned with the pre-migration public API. Do not leave business logic, tracing helpers, or patchers behind in the wrapper package. + +## Current Examples + +Use ADK as the main structural reference: + +- tracing moved into `py/src/braintrust/integrations/adk/tracing.py` +- patchers moved into `py/src/braintrust/integrations/adk/patchers.py` +- orchestration moved into `py/src/braintrust/integrations/adk/integration.py` +- public exports live in `py/src/braintrust/integrations/adk/__init__.py` +- wrapper tests and cassettes moved under `py/src/braintrust/integrations/adk/` +- auto-instrument subprocess coverage moved to `py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py` +- `py/src/braintrust/wrappers/adk/__init__.py` became a thin compatibility layer + +Use Anthropic as the compact constructor-patching reference: + +- `py/src/braintrust/integrations/anthropic/integration.py` registers sync and async constructor patchers +- `py/src/braintrust/integrations/anthropic/patchers.py` keeps one patcher per constructor target +- `py/src/braintrust/wrappers/anthropic.py` is a minimal compatibility re-export + +Match those patterns unless the provider has a clear reason to differ. + +## Commands + +```bash +cd py && nox -s "test_(latest)" +cd py && nox -s "test_(latest)" -- -k "test_name" +cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" +cd py && make test-core +cd py && make lint +``` + +## Validation + +- Run the narrowest provider session first. +- Run `cd py && make test-core` if shared integration code changed. +- Run `cd py && make lint` before handoff when the migration touches shared files. +- Run the relevant auto-instrument subprocess tests if `auto.py` changed. +- Verify the old wrapper import path still works through compatibility re-exports. + +## Pitfalls + +- Copying wrapper code into the integration package without restructuring it around `integration.py`, `patchers.py`, and `tracing.py`. +- Leaving real logic behind in the wrapper after the migration. +- Breaking deprecated aliases or `__all__` exports that users still import. +- Moving tests without moving their cassettes or auto-instrument scripts. +- Forgetting to update `py/noxfile.py` to point at the new integration test paths. +- Changing shared integration code more broadly than the provider requires. +- Re-recording cassettes when behavior did not intentionally change. diff --git a/py/noxfile.py b/py/noxfile.py index bff911db..4c99cb00 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -41,8 +41,6 @@ def _pinned_python_version(): SRC_DIR = "braintrust" WRAPPER_DIR = "braintrust/wrappers" INTEGRATION_DIR = "braintrust/integrations" -INTEGRATION_AUTO_TEST_DIR = "braintrust/integrations/auto_test_scripts" -ANTHROPIC_INTEGRATION_DIR = "braintrust/integrations/anthropic" CONTRIB_DIR = "braintrust/contrib" DEVSERVER_DIR = "braintrust/devserver" @@ -198,8 +196,8 @@ def test_google_adk(session, version): """Test Google ADK integration.""" _install_test_deps(session) _install(session, "google-adk", version) - _run_tests(session, f"{WRAPPER_DIR}/adk/test_adk.py") - _run_tests(session, f"{WRAPPER_DIR}/adk/test_adk_mcp_tool.py") + _run_tests(session, f"{INTEGRATION_DIR}/adk/test_adk.py") + _run_tests(session, f"{INTEGRATION_DIR}/adk/test_adk_mcp_tool.py") _run_core_tests(session) @@ -407,7 +405,12 @@ def _run_core_tests(session): _run_tests( session, SRC_DIR, - ignore_paths=[WRAPPER_DIR, INTEGRATION_AUTO_TEST_DIR, ANTHROPIC_INTEGRATION_DIR, CONTRIB_DIR, DEVSERVER_DIR], + ignore_paths=[ + WRAPPER_DIR, + INTEGRATION_DIR, + CONTRIB_DIR, + DEVSERVER_DIR, + ], ) diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 6c15b653..13ea636d 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -9,7 +9,7 @@ import logging from contextlib import contextmanager -from braintrust.integrations import AnthropicIntegration, IntegrationPatchConfig +from braintrust.integrations import ADKIntegration, AnthropicIntegration, IntegrationPatchConfig __all__ = ["auto_instrument"] @@ -39,7 +39,7 @@ def auto_instrument( agno: bool = True, claude_agent_sdk: bool = True, dspy: bool = True, - adk: bool = True, + adk: InstrumentOption = True, ) -> dict[str, bool]: """ Auto-instrument supported AI/ML libraries for Braintrust tracing. @@ -109,14 +109,14 @@ def auto_instrument( results = {} openai_enabled = _normalize_bool_option("openai", openai) - anthropic_enabled, anthropic_config = _normalize_anthropic_option(anthropic) + anthropic_enabled, anthropic_config = _normalize_instrument_option("anthropic", anthropic) litellm_enabled = _normalize_bool_option("litellm", litellm) pydantic_ai_enabled = _normalize_bool_option("pydantic_ai", pydantic_ai) google_genai_enabled = _normalize_bool_option("google_genai", google_genai) agno_enabled = _normalize_bool_option("agno", agno) claude_agent_sdk_enabled = _normalize_bool_option("claude_agent_sdk", claude_agent_sdk) dspy_enabled = _normalize_bool_option("dspy", dspy) - adk_enabled = _normalize_bool_option("adk", adk) + adk_enabled, adk_config = _normalize_instrument_option("adk", adk) if openai_enabled: results["openai"] = _instrument_openai() @@ -135,7 +135,7 @@ def auto_instrument( if dspy_enabled: results["dspy"] = _instrument_dspy() if adk_enabled: - results["adk"] = _instrument_adk() + results["adk"] = _instrument_integration(ADKIntegration, patch_config=adk_config) return results @@ -164,7 +164,7 @@ def _normalize_bool_option(name: str, option: bool) -> bool: raise TypeError(f"auto_instrument option {name!r} must be a bool, got {type(option).__name__}") -def _normalize_anthropic_option(option: InstrumentOption) -> tuple[bool, IntegrationPatchConfig | None]: +def _normalize_instrument_option(name: str, option: InstrumentOption) -> tuple[bool, IntegrationPatchConfig | None]: if isinstance(option, bool): return option, None @@ -172,7 +172,7 @@ def _normalize_anthropic_option(option: InstrumentOption) -> tuple[bool, Integra return True, option raise TypeError( - f"auto_instrument option 'anthropic' must be a bool or IntegrationPatchConfig, got {type(option).__name__}" + f"auto_instrument option {name} must be a bool or IntegrationPatchConfig, got {type(option).__name__}" ) @@ -222,11 +222,3 @@ def _instrument_dspy() -> bool: return patch_dspy() return False - - -def _instrument_adk() -> bool: - with _try_patch(): - from braintrust.wrappers.adk import setup_adk - - return setup_adk() - return False diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index 1dddbd91..72aab3da 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -1,5 +1,6 @@ +from .adk import ADKIntegration from .anthropic import AnthropicIntegration from .base import IntegrationPatchConfig -__all__ = ["AnthropicIntegration", "IntegrationPatchConfig"] +__all__ = ["ADKIntegration", "AnthropicIntegration", "IntegrationPatchConfig"] diff --git a/py/src/braintrust/integrations/adk/__init__.py b/py/src/braintrust/integrations/adk/__init__.py new file mode 100644 index 00000000..cddb0f7f --- /dev/null +++ b/py/src/braintrust/integrations/adk/__init__.py @@ -0,0 +1,63 @@ +"""Braintrust integration for Google ADK.""" + +import logging +import warnings + +from braintrust.logger import NOOP_SPAN, current_span, init_logger + +from .integration import ADKIntegration +from .patchers import ( + wrap_agent, + wrap_flow, + wrap_mcp_tool, + wrap_runner, +) + + +logger = logging.getLogger(__name__) + +__all__ = [ + "ADKIntegration", + "_create_thread_wrapper", + "setup_adk", + "setup_braintrust", + "wrap_agent", + "wrap_runner", + "wrap_flow", + "wrap_mcp_tool", +] + + +def setup_braintrust(*args, **kwargs): + warnings.warn("setup_braintrust is deprecated, use setup_adk instead", DeprecationWarning, stacklevel=2) + return setup_adk(*args, **kwargs) + + +def setup_adk( + api_key: str | None = None, + project_id: str | None = None, + project_name: str | None = None, + SpanProcessor: type | None = None, +) -> bool: + """ + Setup Braintrust integration with Google ADK. Will automatically patch Google ADK agents, runners, flows, and MCP tools for automatic tracing. + + If you prefer manual patching take a look at `wrap_agent`, `wrap_runner`, `wrap_flow`, and `wrap_mcp_tool`. + + Args: + api_key (Optional[str]): Braintrust API key. + project_id (Optional[str]): Braintrust project ID. + project_name (Optional[str]): Braintrust project name. + SpanProcessor (Optional[type]): Deprecated parameter. + + Returns: + bool: True if setup was successful, False otherwise. + """ + if SpanProcessor is not None: + warnings.warn("SpanProcessor parameter is deprecated and will be ignored", DeprecationWarning, stacklevel=2) + + span = current_span() + if span == NOOP_SPAN: + init_logger(project=project_name, api_key=api_key, project_id=project_id) + + return ADKIntegration.setup() diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_agent_metadata_with_attachment.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_agent_metadata_with_attachment.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_agent_metadata_with_attachment.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_agent_metadata_with_attachment.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_binary_data_attachment_conversion.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_binary_data_attachment_conversion.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_binary_data_attachment_conversion.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_binary_data_attachment_conversion.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_braintrust_integration.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_braintrust_integration.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_braintrust_integration.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_braintrust_integration.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_captures_metrics.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_captures_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_captures_metrics.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_captures_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_complex_nested_schema.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_complex_nested_schema.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_complex_nested_schema.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_complex_nested_schema.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_input_schema_serialization.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_input_schema_serialization.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_input_schema_serialization.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_input_schema_serialization.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_max_tokens_captures_content.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_max_tokens_captures_content.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_max_tokens_captures_content.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_max_tokens_captures_content.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_response_json_schema_dict.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_response_json_schema_dict.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_response_json_schema_dict.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_response_json_schema_dict.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_adk_structured_output_pydantic.yaml b/py/src/braintrust/integrations/adk/cassettes/test_adk_structured_output_pydantic.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_adk_structured_output_pydantic.yaml rename to py/src/braintrust/integrations/adk/cassettes/test_adk_structured_output_pydantic.yaml diff --git a/py/src/braintrust/integrations/adk/integration.py b/py/src/braintrust/integrations/adk/integration.py new file mode 100644 index 00000000..78bcb4c1 --- /dev/null +++ b/py/src/braintrust/integrations/adk/integration.py @@ -0,0 +1,30 @@ +"""ADK integration — orchestration class and setup entry-point.""" + +import logging + +from braintrust.integrations.base import BaseIntegration + +from .patchers import ( + AgentRunAsyncPatcher, + FlowRunAsyncPatcher, + McpToolPatcher, + RunnerRunSyncPatcher, + ThreadBridgePatcher, +) + + +logger = logging.getLogger(__name__) + + +class ADKIntegration(BaseIntegration): + """Braintrust instrumentation for Google ADK (Agent Development Kit).""" + + name = "adk" + import_names = ("google.adk",) + patchers = ( + ThreadBridgePatcher, + AgentRunAsyncPatcher, + RunnerRunSyncPatcher, + FlowRunAsyncPatcher, + McpToolPatcher, + ) diff --git a/py/src/braintrust/integrations/adk/patchers.py b/py/src/braintrust/integrations/adk/patchers.py new file mode 100644 index 00000000..4c140ca3 --- /dev/null +++ b/py/src/braintrust/integrations/adk/patchers.py @@ -0,0 +1,167 @@ +"""ADK patchers — one patcher per coherent patch target.""" + +from typing import Any, ClassVar + +from braintrust.integrations.base import CompositeFunctionWrapperPatcher, FunctionWrapperPatcher + +from .tracing import ( + _agent_run_async_wrapper, + _create_thread_wrapper, + _flow_call_llm_async_wrapper, + _flow_run_async_wrapper, + _mcp_tool_run_async_wrapper_async, + _runner_run_async_wrapper, + _runner_run_wrapper, +) + + +# --------------------------------------------------------------------------- +# Agent patcher +# --------------------------------------------------------------------------- + + +class AgentRunAsyncPatcher(FunctionWrapperPatcher): + """Patch ``BaseAgent.run_async`` for tracing.""" + + name = "adk.agent.run_async" + target_module = "google.adk.agents" + target_path = "BaseAgent.run_async" + wrapper = _agent_run_async_wrapper + + +# --------------------------------------------------------------------------- +# Runner patchers (sync + async) +# --------------------------------------------------------------------------- + + +class _RunnerRunSyncSubPatcher(FunctionWrapperPatcher): + """Patch ``Runner.run`` (sync generator).""" + + name = "adk.runner.run.sync" + target_module = "google.adk.runners" + target_path = "Runner.run" + wrapper = _runner_run_wrapper + + +class _RunnerRunAsyncSubPatcher(FunctionWrapperPatcher): + """Patch ``Runner.run_async`` (async generator).""" + + name = "adk.runner.run.async" + target_module = "google.adk.runners" + target_path = "Runner.run_async" + wrapper = _runner_run_async_wrapper + + +class RunnerRunSyncPatcher(CompositeFunctionWrapperPatcher): + """Patch ``Runner.run`` (sync) and ``Runner.run_async`` for tracing.""" + + name = "adk.runner.run" + sub_patchers = (_RunnerRunSyncSubPatcher, _RunnerRunAsyncSubPatcher) + + +# --------------------------------------------------------------------------- +# Flow patchers +# --------------------------------------------------------------------------- + + +class _FlowRunAsyncSubPatcher(FunctionWrapperPatcher): + """Patch ``BaseLlmFlow.run_async``.""" + + name = "adk.flow.run_async.run" + target_module = "google.adk.flows.llm_flows.base_llm_flow" + target_path = "BaseLlmFlow.run_async" + wrapper = _flow_run_async_wrapper + + +class _FlowCallLlmAsyncSubPatcher(FunctionWrapperPatcher): + """Patch ``BaseLlmFlow._call_llm_async``.""" + + name = "adk.flow.run_async.call_llm" + target_module = "google.adk.flows.llm_flows.base_llm_flow" + target_path = "BaseLlmFlow._call_llm_async" + wrapper = _flow_call_llm_async_wrapper + + +class FlowRunAsyncPatcher(CompositeFunctionWrapperPatcher): + """Patch ``BaseLlmFlow.run_async`` and ``_call_llm_async`` for tracing.""" + + name = "adk.flow.run_async" + sub_patchers = (_FlowRunAsyncSubPatcher, _FlowCallLlmAsyncSubPatcher) + + +# --------------------------------------------------------------------------- +# Thread-bridge patchers +# --------------------------------------------------------------------------- + + +class _ThreadBridgePlatformSubPatcher(FunctionWrapperPatcher): + """Patch ``google.adk.platform.thread.create_thread`` for context propagation.""" + + name = "adk.thread_bridge.platform" + target_module = "google.adk.platform.thread" + target_path = "create_thread" + wrapper = _create_thread_wrapper + + +class _ThreadBridgeRunnersSubPatcher(FunctionWrapperPatcher): + """Patch ``google.adk.runners.create_thread`` for context propagation.""" + + name = "adk.thread_bridge.runners" + target_module = "google.adk.runners" + target_path = "create_thread" + wrapper = _create_thread_wrapper + + +class ThreadBridgePatcher(CompositeFunctionWrapperPatcher): + """Patch ``create_thread`` in ADK platform and runners for context propagation.""" + + name = "adk.thread_bridge" + priority: ClassVar[int] = 50 # run before other patchers so context propagates + sub_patchers = (_ThreadBridgePlatformSubPatcher, _ThreadBridgeRunnersSubPatcher) + + +# --------------------------------------------------------------------------- +# MCP tool patcher +# --------------------------------------------------------------------------- + + +class McpToolPatcher(FunctionWrapperPatcher): + """Patch ``McpTool.run_async`` for tracing (optional – MCP may not be installed).""" + + name = "adk.mcp_tool" + target_module = "google.adk.tools.mcp_tool.mcp_tool" + target_path = "McpTool.run_async" + wrapper = _mcp_tool_run_async_wrapper_async + + +# --------------------------------------------------------------------------- +# Public wrap_*() helpers — thin wrappers around patcher.wrap_target() +# --------------------------------------------------------------------------- + + +def wrap_agent(Agent: Any) -> Any: + """Manually patch an agent class for tracing.""" + return AgentRunAsyncPatcher.wrap_target(Agent) + + +def wrap_runner(Runner: Any) -> Any: + """Manually patch a runner class for tracing.""" + return RunnerRunSyncPatcher.wrap_target(Runner) + + +def wrap_flow(Flow: Any) -> Any: + """Manually patch a flow class for tracing.""" + return FlowRunAsyncPatcher.wrap_target(Flow) + + +def wrap_mcp_tool(McpTool: Any) -> Any: + """Manually patch an MCP tool class for tracing. + + Creates Braintrust spans for each MCP tool call, capturing: + - Tool name + - Input arguments + - Output results + - Execution time + - Errors if they occur + """ + return McpToolPatcher.wrap_target(McpTool) diff --git a/py/src/braintrust/wrappers/adk/test_adk.py b/py/src/braintrust/integrations/adk/test_adk.py similarity index 96% rename from py/src/braintrust/wrappers/adk/test_adk.py rename to py/src/braintrust/integrations/adk/test_adk.py index 4462d89d..be6f97e6 100644 --- a/py/src/braintrust/wrappers/adk/test_adk.py +++ b/py/src/braintrust/integrations/adk/test_adk.py @@ -5,9 +5,10 @@ import pytest from braintrust import logger from braintrust.bt_json import bt_safe_deep_copy +from braintrust.integrations.adk import setup_adk +from braintrust.integrations.adk.tracing import _create_thread_wrapper from braintrust.logger import Attachment from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.adk import _wrap_create_thread, setup_adk from google.adk import Agent @@ -36,7 +37,7 @@ def before_record_request(request): return { "record_mode": record_mode, - "cassette_library_dir": str(Path(__file__).parent.parent / "cassettes"), + "cassette_library_dir": str(Path(__file__).parent / "cassettes"), "filter_headers": [ "authorization", "x-goog-api-key", @@ -110,22 +111,20 @@ async def generate_content_async(self, llm_request: LlmRequest, stream: bool = F assert thread_root == parent_span.root_span_id -def test_wrap_create_thread_exception_does_not_double_invoke_target(): +def test_create_thread_wrapper_exception_does_not_double_invoke_target(): """Regression test: target exceptions must not cause a second invocation.""" call_count = 0 def create_thread(target, *args, **kwargs): return target(*args, **kwargs) - wrapped_create_thread = _wrap_create_thread(create_thread) - def target(): nonlocal call_count call_count += 1 raise RuntimeError("boom") with pytest.raises(RuntimeError, match="boom"): - wrapped_create_thread(target) + _create_thread_wrapper(create_thread, None, (target,), {}) assert call_count == 1 @@ -290,8 +289,8 @@ async def test_adk_max_tokens_captures_content(memory_logger): def test_serialize_content_with_binary_data(): """Test that _serialize_content converts binary data to Attachment references.""" + from braintrust.integrations.adk.tracing import _serialize_content, _serialize_part from braintrust.logger import Attachment - from braintrust.wrappers.adk import _serialize_content, _serialize_part # Create a minimal PNG image (1x1 red pixel) minimal_png = ( @@ -364,7 +363,7 @@ def __init__(self, parts, role): def test_serialize_part_with_file_data(): """Test that _serialize_part handles file_data (file references) correctly.""" - from braintrust.wrappers.adk import _serialize_part + from braintrust.integrations.adk.tracing import _serialize_part class MockFileData: def __init__(self, file_uri, mime_type): @@ -387,7 +386,7 @@ def __init__(self, file_data=None, text=None): def test_serialize_part_with_dict(): """Test that _serialize_part handles dict input correctly.""" - from braintrust.wrappers.adk import _serialize_part + from braintrust.integrations.adk.tracing import _serialize_part # Test that dicts pass through unchanged dict_part = {"text": "Hello", "custom": "field"} @@ -397,7 +396,7 @@ def test_serialize_part_with_dict(): def test_serialize_content_with_none(): """Test that _serialize_content handles None correctly.""" - from braintrust.wrappers.adk import _serialize_content + from braintrust.integrations.adk.tracing import _serialize_content result = _serialize_content(None) assert result is None, "None should serialize to None" @@ -574,7 +573,7 @@ async def test_adk_captures_metrics(memory_logger): def test_determine_llm_call_type_direct_response(): """Test that _determine_llm_call_type returns 'direct_response' when tools are available but not used.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type # Request with tools available llm_request = { @@ -603,7 +602,7 @@ def test_determine_llm_call_type_direct_response(): def test_determine_llm_call_type_tool_selection(): """Test that _determine_llm_call_type returns 'tool_selection' when LLM calls a tool.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type # Request with tools available llm_request = { @@ -633,7 +632,7 @@ def test_determine_llm_call_type_tool_selection(): def test_determine_llm_call_type_tool_selection_snake_case(): """Test that _determine_llm_call_type handles snake_case function_call.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type llm_request = { "config": {"tools": [{"function_declarations": [{"name": "search"}]}]}, @@ -654,7 +653,7 @@ def test_determine_llm_call_type_tool_selection_snake_case(): def test_determine_llm_call_type_response_generation(): """Test that _determine_llm_call_type returns 'response_generation' after tool execution.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type # Request with function_response in history llm_request = { @@ -680,7 +679,7 @@ def test_determine_llm_call_type_response_generation(): def test_determine_llm_call_type_no_tools(): """Test that _determine_llm_call_type returns 'direct_response' when no tools configured.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type llm_request = { "config": {}, @@ -697,7 +696,7 @@ def test_determine_llm_call_type_no_tools(): def test_determine_llm_call_type_no_response(): """Test that _determine_llm_call_type handles missing model_response gracefully.""" - from braintrust.wrappers.adk import _determine_llm_call_type + from braintrust.integrations.adk.tracing import _determine_llm_call_type llm_request = { "config": {"tools": [{"function_declarations": [{"name": "tool1"}]}]}, @@ -725,7 +724,7 @@ async def test_llm_call_span_wraps_child_spans(memory_logger): from unittest.mock import MagicMock from braintrust import current_span, start_span - from braintrust.wrappers.adk import wrap_flow + from braintrust.integrations.adk import wrap_flow # Clear any existing logs memory_logger.pop() @@ -817,9 +816,9 @@ async def test_async_context_preservation_across_yields(): that occur when async generators yield control and resume in different contexts. """ import asyncio + from contextlib import aclosing from braintrust import start_span - from braintrust.wrappers.adk import aclosing # Initialize logger init_test_logger("test-context") @@ -1136,7 +1135,7 @@ class Person(BaseModel): @pytest.mark.asyncio async def test_serialize_config_handles_all_schema_fields(): """Test that _serialize_config handles all 4 schema fields.""" - from braintrust.wrappers.adk import _serialize_config + from braintrust.integrations.adk.tracing import _serialize_config class TestSchema(BaseModel): value: str = Field(description="Test value") @@ -1170,7 +1169,7 @@ class TestSchema(BaseModel): @pytest.mark.asyncio async def test_serialize_config_handles_non_pydantic(): """Test that _serialize_config handles non-Pydantic values gracefully.""" - from braintrust.wrappers.adk import _serialize_config + from braintrust.integrations.adk.tracing import _serialize_config # Test with non-Pydantic values config = {"response_schema": "not a pydantic model", "other_field": {"key": "value"}} @@ -1186,7 +1185,7 @@ async def test_serialize_config_handles_non_pydantic(): @pytest.mark.asyncio async def test_serialize_pydantic_schema_direct(): """Test _serialize_pydantic_schema directly with various inputs.""" - from braintrust.wrappers.adk import _serialize_pydantic_schema + from braintrust.integrations.adk.tracing import _serialize_pydantic_schema class SimpleSchema(BaseModel): name: str = Field(description="A name") @@ -1219,7 +1218,7 @@ class NotPydantic: @pytest.mark.asyncio async def test_bt_safe_deep_copy_never_raises(): """Test that bt_safe_deep_copy never raises exceptions.""" - from braintrust.wrappers.adk import bt_safe_deep_copy + from braintrust.bt_json import bt_safe_deep_copy class BrokenModel: def model_dump(self): @@ -1377,7 +1376,7 @@ async def test_adk_response_json_schema_dict(memory_logger): @pytest.mark.asyncio async def test_serialize_config_preserves_none(): """Test that _serialize_config returns None when config is None (not empty dict).""" - from braintrust.wrappers.adk import _serialize_config + from braintrust.integrations.adk.tracing import _serialize_config # None should be preserved as None, not converted to {} result = _serialize_config(None) @@ -1494,3 +1493,13 @@ async def test_adk_bytes_and_attachment_in_structure(): assert "binary_data" in result assert "nested" in result assert "more_bytes" in result["nested"] + + +class TestAutoInstrumentADK: + """Tests for auto_instrument() with Google ADK.""" + + def test_auto_instrument_adk(self): + """Test auto_instrument patches ADK classes and is idempotent.""" + from braintrust.wrappers.test_utils import verify_autoinstrument_script + + verify_autoinstrument_script("test_auto_adk.py") diff --git a/py/src/braintrust/wrappers/adk/test_adk_mcp_tool.py b/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py similarity index 85% rename from py/src/braintrust/wrappers/adk/test_adk_mcp_tool.py rename to py/src/braintrust/integrations/adk/test_adk_mcp_tool.py index 5894c5b6..c58ec190 100644 --- a/py/src/braintrust/wrappers/adk/test_adk_mcp_tool.py +++ b/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from braintrust.wrappers.adk import setup_adk, wrap_mcp_tool +from braintrust.integrations.adk import setup_adk, wrap_mcp_tool @pytest.mark.asyncio @@ -18,16 +18,17 @@ async def run_async(self, *, args, tool_context): # Wrap the class wrapped_class = wrap_mcp_tool(MockMcpTool) - # Verify it's marked as patched - assert hasattr(wrapped_class, "_braintrust_patched") - assert wrapped_class._braintrust_patched is True # pylint: disable=no-member + # Verify it's marked as patched via the patcher marker + from braintrust.integrations.adk.patchers import McpToolPatcher + + assert getattr(wrapped_class, McpToolPatcher.patch_marker_attr(), False) @pytest.mark.asyncio async def test_mcp_tool_execution_creates_span(): """Test that MCP tool execution creates proper trace spans.""" - with patch("braintrust.wrappers.adk.start_span") as mock_start_span: + with patch("braintrust.integrations.adk.tracing.start_span") as mock_start_span: # Setup mock span mock_span = MagicMock() mock_span.__enter__ = MagicMock(return_value=mock_span) @@ -69,7 +70,7 @@ async def test_mcp_tool_span_captures_tool_info(): """Test that MCP tool spans capture tool name, args, and results.""" from braintrust.span_types import SpanTypeAttribute - with patch("braintrust.wrappers.adk.start_span") as mock_start_span: + with patch("braintrust.integrations.adk.tracing.start_span") as mock_start_span: mock_span = MagicMock() mock_span.__enter__ = MagicMock(return_value=mock_span) mock_span.__exit__ = MagicMock(return_value=False) @@ -120,7 +121,7 @@ async def run_async(self, *, args, tool_context): @pytest.mark.asyncio async def test_mcp_tool_error_handling(): """Test that MCP tool errors are captured in spans.""" - with patch("braintrust.wrappers.adk.start_span") as mock_start_span: + with patch("braintrust.integrations.adk.tracing.start_span") as mock_start_span: mock_span = MagicMock() mock_span.__enter__ = MagicMock(return_value=mock_span) mock_span.__exit__ = MagicMock(return_value=False) @@ -152,35 +153,28 @@ async def run_async(self, *, args, tool_context): @pytest.mark.asyncio async def test_setup_adk_patches_mcp_tool(): - """Test that setup_adk automatically patches McpTool.""" - MockMcpTool = MagicMock() + """Test that setup_adk automatically patches McpTool via ADKIntegration.""" + result = setup_adk(project_name="test") + assert result is True - with patch("braintrust.wrappers.adk.init_logger"): - with patch("braintrust.wrappers.adk.wrap_mcp_tool") as mock_wrap: - with patch("google.adk.tools.mcp_tool.mcp_tool") as mock_mcp_module: - mock_mcp_module.McpTool = MockMcpTool - result = setup_adk(project_name="test") + # Verify McpTool got patched (if available) + try: + from braintrust.integrations.adk.patchers import McpToolPatcher - assert result is True - mock_wrap.assert_called_once_with(MockMcpTool) + assert McpToolPatcher.is_patched(None, None), "McpTool should be patched" + except ImportError: + pass # MCP is optional @pytest.mark.asyncio async def test_setup_adk_graceful_fallback_when_mcp_unavailable(): """Test that setup_adk gracefully handles MCP not being installed.""" - with patch("braintrust.wrappers.adk.init_logger"): - # This test is tricky - we need MCP import to fail but not break other imports - # The actual behavior is tested in integration: when MCP is not available, - # it gets ImportError from the google.adk.tools.mcp_tool module itself - # For this test, we just verify setup_adk succeeds even when MCP module raises ImportError - - result = setup_adk(project_name="test") - - # Should succeed - MCP is optional - assert result is True + # setup_adk delegates to ADKIntegration.setup() which handles ImportError + # in the McpToolPatcher gracefully + result = setup_adk(project_name="test") - # When MCP is not available, MCP import fails but setup_adk continues - # This is the actual graceful fallback in action + # Should succeed - MCP is optional + assert result is True @pytest.mark.asyncio @@ -194,7 +188,7 @@ async def test_mcp_tool_async_context_preservation(): """ import contextvars - from braintrust.wrappers.adk import wrap_mcp_tool + from braintrust.integrations.adk import wrap_mcp_tool # Track context switches context_var = contextvars.ContextVar("test_context", default=None) @@ -252,7 +246,7 @@ async def test_mcp_tool_nested_async_generators(): 3. MCP tool execution happens deep in the stack 4. All generators yield and resume, potentially in different contexts """ - from braintrust.wrappers.adk import wrap_mcp_tool + from braintrust.integrations.adk import wrap_mcp_tool class MockMcpTool: def __init__(self): @@ -309,9 +303,9 @@ async def test_real_context_loss_with_braintrust_spans(): suppressing in the aclosing.__aexit__ method. """ import asyncio + from contextlib import aclosing from braintrust import init_logger - from braintrust.wrappers.adk import aclosing # Initialize a test logger logger = init_logger(project="test-context-loss") diff --git a/py/src/braintrust/integrations/adk/tracing.py b/py/src/braintrust/integrations/adk/tracing.py new file mode 100644 index 00000000..5f02e6f1 --- /dev/null +++ b/py/src/braintrust/integrations/adk/tracing.py @@ -0,0 +1,569 @@ +"""ADK-specific span creation, metadata extraction, stream handling, and output normalization.""" + +import contextvars +import inspect +import logging +import time +from collections.abc import Iterable +from contextlib import aclosing +from typing import Any, cast + +from braintrust.bt_json import bt_safe_deep_copy +from braintrust.logger import Attachment, start_span +from braintrust.span_types import SpanTypeAttribute + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Serialization helpers +# --------------------------------------------------------------------------- + + +def _serialize_content(content: Any) -> Any: + """Serialize Google ADK Content/Part objects, converting binary data to Attachments.""" + if content is None: + return None + + # Handle Content objects with parts + if hasattr(content, "parts") and content.parts: + serialized_parts = [] + for part in content.parts: + serialized_parts.append(_serialize_part(part)) + + result = {"parts": serialized_parts} + if hasattr(content, "role"): + result["role"] = content.role + return result + + # Handle single Part + return _serialize_part(content) + + +def _serialize_part(part: Any) -> Any: + """Serialize a single Part object, handling binary data.""" + if part is None: + return None + + # If it's already a dict, return as-is + if isinstance(part, dict): + return part + + # Handle Part objects with inline_data (binary data like images) + if hasattr(part, "inline_data") and part.inline_data: + inline_data = part.inline_data + if hasattr(inline_data, "data") and hasattr(inline_data, "mime_type"): + data = inline_data.data + mime_type = inline_data.mime_type + + # Convert bytes to Attachment + if isinstance(data, bytes): + extension = mime_type.split("/")[1] if "/" in mime_type else "bin" + filename = f"file.{extension}" + attachment = Attachment(data=data, filename=filename, content_type=mime_type) + + # Return in image_url format - SDK will replace with AttachmentReference + return {"image_url": {"url": attachment}} + + # Handle Part objects with file_data (file references) + if hasattr(part, "file_data") and part.file_data: + file_data = part.file_data + result = {"file_data": {}} + if hasattr(file_data, "file_uri"): + result["file_data"]["file_uri"] = file_data.file_uri + if hasattr(file_data, "mime_type"): + result["file_data"]["mime_type"] = file_data.mime_type + return result + + # Handle text parts + if hasattr(part, "text") and part.text is not None: + result = {"text": part.text} + if hasattr(part, "thought") and part.thought: + result["thought"] = part.thought + return result + + # Try standard serialization methods + return bt_safe_deep_copy(part) + + +def _serialize_pydantic_schema(schema_class: Any) -> dict[str, Any]: + """ + Serialize a Pydantic model class to its full JSON schema. + + Returns the complete schema including descriptions, constraints, and nested definitions + so engineers can see exactly what structured output schema was used. + """ + try: + from pydantic import BaseModel + + if inspect.isclass(schema_class) and issubclass(schema_class, BaseModel): + # Return the full JSON schema - includes all field info, descriptions, constraints, etc. + return schema_class.model_json_schema() + except (ImportError, AttributeError, TypeError): + pass + # If not a Pydantic model, return class name + return {"__class__": schema_class.__name__ if inspect.isclass(schema_class) else str(type(schema_class).__name__)} + + +def _serialize_config(config: Any) -> dict[str, Any] | Any: + """ + Serialize a config object, specifically handling schema fields that may contain Pydantic classes. + + Google ADK uses these fields for schemas: + - response_schema, response_json_schema (in GenerateContentConfig for LLM requests) + - input_schema, output_schema (in agent config) + """ + if config is None: + return None + if not config: + return config + + # Extract schema fields BEFORE calling bt_safe_deep_copy (which converts Pydantic classes to dicts) + schema_fields = ["response_schema", "response_json_schema", "input_schema", "output_schema"] + serialized_schemas: dict[str, Any] = {} + + for field in schema_fields: + schema_value = None + + # Try to get the field value + if hasattr(config, field): + schema_value = getattr(config, field) + elif isinstance(config, dict) and field in config: + schema_value = config[field] + + # If it's a Pydantic class, serialize it + if schema_value is not None and inspect.isclass(schema_value): + try: + from pydantic import BaseModel + + if issubclass(schema_value, BaseModel): + serialized_schemas[field] = _serialize_pydantic_schema(schema_value) + except (TypeError, ImportError): + pass + + # Serialize the config + config_dict = bt_safe_deep_copy(config) + if not isinstance(config_dict, dict): + return config_dict # type: ignore + + # Replace schema fields with serialized versions + config_dict.update(serialized_schemas) + + return config_dict + + +def _omit(obj: Any, keys: Iterable[str]): + return {k: v for k, v in obj.items() if k not in keys} + + +def _extract_metrics(response: Any) -> dict[str, float] | None: + """Extract token usage metrics from Google GenAI response.""" + if not response: + return None + + usage_metadata = getattr(response, "usage_metadata", None) + if not usage_metadata: + return None + + metrics: dict[str, float] = {} + + # Core token counts + if hasattr(usage_metadata, "prompt_token_count") and usage_metadata.prompt_token_count is not None: + metrics["prompt_tokens"] = float(usage_metadata.prompt_token_count) + + if hasattr(usage_metadata, "candidates_token_count") and usage_metadata.candidates_token_count is not None: + metrics["completion_tokens"] = float(usage_metadata.candidates_token_count) + + if hasattr(usage_metadata, "total_token_count") and usage_metadata.total_token_count is not None: + metrics["tokens"] = float(usage_metadata.total_token_count) + + # Cached token metrics + if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None: + metrics["prompt_cached_tokens"] = float(usage_metadata.cached_content_token_count) + + # Reasoning token metrics (thoughts_token_count) + if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None: + metrics["completion_reasoning_tokens"] = float(usage_metadata.thoughts_token_count) + + return metrics if metrics else None + + +def _extract_model_name(response: Any, llm_request: Any, instance: Any) -> str | None: + """Extract model name from Google GenAI response, request, or flow instance.""" + # Try to get from response first + if response: + model_version = getattr(response, "model_version", None) + if model_version: + return model_version + + # Try to get from llm_request + if llm_request: + if hasattr(llm_request, "model") and llm_request.model: + return str(llm_request.model) + + # Try to get from instance (flow's llm) + if instance: + if hasattr(instance, "llm"): + llm = instance.llm + if hasattr(llm, "model") and llm.model: + return str(llm.model) + + # Try to get model from instance directly + if hasattr(instance, "model") and instance.model: + return str(instance.model) + + return None + + +def _determine_llm_call_type(llm_request: Any, model_response: Any = None) -> str: + """ + Determine the type of LLM call based on the request and response content. + + Returns: + - "tool_selection" if the LLM selected a tool to call in its response + - "response_generation" if the LLM is generating a response after tool execution + - "direct_response" if there are no tools involved or tools available but not used + """ + try: + # Convert to dict if it's a model object + request_dict = cast(dict[str, Any], bt_safe_deep_copy(llm_request)) + + # Check the conversation history for function responses + contents = request_dict.get("contents", []) + has_function_response = False + + for content in contents: + if isinstance(content, dict): + parts = content.get("parts", []) + for part in parts: + if isinstance(part, dict): + if "function_response" in part and part["function_response"] is not None: + has_function_response = True + + # Check if the response contains function calls + response_has_function_call = False + if model_response: + # Check if it's an Event object with get_function_calls method (ADK Event) + if hasattr(model_response, "get_function_calls"): + try: + function_calls = model_response.get_function_calls() + if function_calls and len(function_calls) > 0: + response_has_function_call = True + except Exception: + pass + + # Fallback: Check the response dict structure + if not response_has_function_call: + response_dict = bt_safe_deep_copy(model_response) + if isinstance(response_dict, dict): + # Try multiple possible response structures + # 1. Standard: response.content.parts + content = response_dict.get("content", {}) + if isinstance(content, dict): + parts = content.get("parts", []) + if isinstance(parts, list): + for part in parts: + if isinstance(part, dict): + if ("function_call" in part and part["function_call"] is not None) or ( + "functionCall" in part and part["functionCall"] is not None + ): + response_has_function_call = True + break + + # 2. Alternative: response has parts directly (for some event types) + if not response_has_function_call and "parts" in response_dict: + parts = response_dict.get("parts", []) + if isinstance(parts, list): + for part in parts: + if isinstance(part, dict): + if ("function_call" in part and part["function_call"] is not None) or ( + "functionCall" in part and part["functionCall"] is not None + ): + response_has_function_call = True + break + + # Determine the call type + if has_function_response: + return "response_generation" + elif response_has_function_call: + return "tool_selection" + else: + return "direct_response" + + except Exception: + return "unknown" + + +# --------------------------------------------------------------------------- +# Thread-bridge helper (wrapt-style wrapper) +# --------------------------------------------------------------------------- + + +def _create_thread_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + """wrapt wrapper for ``create_thread`` that copies context into new threads.""" + ctx = contextvars.copy_context() + + # ``create_thread(target, ...)`` — target may be positional or keyword. + if args: + target = args[0] + rest_args = args[1:] + else: + target = kwargs.pop("target") + rest_args = args + + def _run_in_context(*target_args: Any, **target_kwargs: Any) -> Any: + return ctx.run(target, *target_args, **target_kwargs) + + return wrapped(_run_in_context, *rest_args, **kwargs) + + +# --------------------------------------------------------------------------- +# wrapt wrapper functions (used by patchers) +# --------------------------------------------------------------------------- + + +async def _agent_run_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + parent_context = args[0] if len(args) > 0 else kwargs.get("parent_context") + + async def _trace(): + with start_span( + name=f"agent_run [{instance.name}]", + type=SpanTypeAttribute.TASK, + metadata=bt_safe_deep_copy({"parent_context": parent_context, **_omit(kwargs, ["parent_context"])}), + ) as agent_span: + last_event = None + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + if event.is_final_response(): + last_event = event + yield event + if last_event: + agent_span.log(output=last_event) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + +async def _flow_run_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") + + async def _trace(): + with start_span( + name="call_llm", + type=SpanTypeAttribute.TASK, + metadata=bt_safe_deep_copy( + { + "invocation_context": invocation_context, + **_omit(kwargs, ["invocation_context"]), + } + ), + ) as llm_span: + last_event = None + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + last_event = event + yield event + if last_event: + llm_span.log(output=last_event) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + +async def _flow_call_llm_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") + llm_request = args[1] if len(args) > 1 else kwargs.get("llm_request") + model_response_event = args[2] if len(args) > 2 else kwargs.get("model_response_event") + + async def _trace(): + # Extract and serialize contents BEFORE converting to dict + # This is critical because bt_safe_deep_copy converts bytes to string representations + serialized_contents = None + if llm_request and hasattr(llm_request, "contents"): + contents = llm_request.contents + if contents: + serialized_contents = ( + [_serialize_content(c) for c in contents] + if isinstance(contents, list) + else _serialize_content(contents) + ) + + # Now convert the whole request to dict + serialized_request = bt_safe_deep_copy(llm_request) + + # Replace contents with our serialized version that has Attachments + if serialized_contents is not None and isinstance(serialized_request, dict): + serialized_request["contents"] = serialized_contents + + # Handle config specifically to serialize Pydantic schema classes + if isinstance(serialized_request, dict) and "config" in serialized_request: + serialized_request["config"] = _serialize_config(serialized_request["config"]) + + # Extract model name from request or instance + model_name = _extract_model_name(None, llm_request, instance) + + # Create span BEFORE execution so child spans (like mcp_tool) have proper parent + # Start with generic name - we'll update it after we see the response + with start_span( + name="llm_call", + type=SpanTypeAttribute.LLM, + input=serialized_request, + metadata=bt_safe_deep_copy( + { + "invocation_context": invocation_context, + "model_response_event": model_response_event, + "flow_class": instance.__class__.__name__, + "model": model_name, + **_omit(kwargs, ["invocation_context", "model_response_event", "flow_class", "llm_call_type"]), + } + ), + ) as llm_span: + # Execute the LLM call and yield events while span is active + last_event = None + event_with_content = None + start_time = time.time() + first_token_time = None + + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + # Record time to first token + if first_token_time is None: + first_token_time = time.time() + + last_event = event + if hasattr(event, "content") and event.content is not None: + event_with_content = event + yield event + + # After execution, update span with correct call type and output + if last_event: + # We need to check if we should merge content from an earlier event + # Convert to dict to inspect/modify, but let span.log() handle final serialization + output_dict = bt_safe_deep_copy(last_event) + if event_with_content and isinstance(output_dict, dict): + if "content" not in output_dict or output_dict.get("content") is None: + content = ( + bt_safe_deep_copy(event_with_content.content) + if hasattr(event_with_content, "content") + else None + ) + if content: + output_dict["content"] = content + + # Extract metrics from response + metrics = _extract_metrics(last_event) + + # Add time to first token if we captured it + if first_token_time is not None: + if metrics is None: + metrics = {} + metrics["time_to_first_token"] = first_token_time - start_time + + # Determine the actual call type based on the response + call_type = _determine_llm_call_type(llm_request, last_event) + + # Update span name with the specific call type now that we know it + llm_span.set_attributes( + name=f"llm_call [{call_type}]", + span_attributes={"llm_call_type": call_type}, + ) + + # Log output and metrics (span.log will handle serialization) + llm_span.log(output=output_dict, metrics=metrics) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + +def _runner_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + user_id = kwargs.get("user_id") + session_id = kwargs.get("session_id") + new_message = kwargs.get("new_message") + + # Serialize new_message before any dict conversion to handle binary data + serialized_message = _serialize_content(new_message) if new_message else None + + def _trace(): + with start_span( + name=f"invocation [{instance.app_name}]", + type=SpanTypeAttribute.TASK, + input={"new_message": serialized_message}, + metadata=bt_safe_deep_copy( + { + "user_id": user_id, + "session_id": session_id, + **_omit(kwargs, ["user_id", "session_id", "new_message"]), + } + ), + ) as runner_span: + last_event = None + for event in wrapped(*args, **kwargs): + if event.is_final_response(): + last_event = event + yield event + if last_event: + runner_span.log(output=last_event) + + yield from _trace() + + +async def _runner_run_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + user_id = kwargs.get("user_id") + session_id = kwargs.get("session_id") + new_message = kwargs.get("new_message") + state_delta = kwargs.get("state_delta") + + # Serialize new_message before any dict conversion to handle binary data + serialized_message = _serialize_content(new_message) if new_message else None + + async def _trace(): + with start_span( + name=f"invocation [{instance.app_name}]", + type=SpanTypeAttribute.TASK, + input={"new_message": serialized_message}, + metadata=bt_safe_deep_copy( + { + "user_id": user_id, + "session_id": session_id, + "state_delta": state_delta, + **_omit(kwargs, ["user_id", "session_id", "new_message", "state_delta"]), + } + ), + ) as runner_span: + last_event = None + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + if event.is_final_response(): + last_event = event + yield event + if last_event: + runner_span.log(output=last_event) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + +async def _mcp_tool_run_async_wrapper_async(wrapped: Any, instance: Any, args: Any, kwargs: Any): + # Extract tool information + tool_name = instance.name + tool_args = kwargs.get("args", {}) + + with start_span( + name=f"mcp_tool [{tool_name}]", + type=SpanTypeAttribute.TOOL, + input={"tool_name": tool_name, "arguments": tool_args}, + metadata=_omit(kwargs, ["args"]), + ) as tool_span: + try: + result = await wrapped(*args, **kwargs) + tool_span.log(output=result) + return result + except Exception as e: + # Log error to span but re-raise for ADK to handle + tool_span.log(error=str(e)) + raise diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py new file mode 100644 index 00000000..6a4d3650 --- /dev/null +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py @@ -0,0 +1,16 @@ +"""Test auto_instrument for Google ADK.""" + +from braintrust.auto import auto_instrument + + +# 1. Instrument +results = auto_instrument() +assert results.get("adk") == True, "auto_instrument should return True for adk" + +# 2. Idempotent +results2 = auto_instrument() +assert results2.get("adk") == True, "auto_instrument should still return True on second call" + +# 3. Verify classes are patched using patcher markers + +print("SUCCESS") diff --git a/py/src/braintrust/integrations/base.py b/py/src/braintrust/integrations/base.py index debe282b..12588036 100644 --- a/py/src/braintrust/integrations/base.py +++ b/py/src/braintrust/integrations/base.py @@ -53,15 +53,34 @@ def patch(cls, module: Any | None, version: str | None, *, target: Any | None = class FunctionWrapperPatcher(BasePatcher): - """Base patcher for single-target `wrap_function_wrapper` instrumentation.""" + """Base patcher for single-target `wrap_function_wrapper` instrumentation. + + Set ``target_module`` to an import path when the patch target lives in a + different module than the one provided by the integration (e.g. a deep + submodule that may or may not be installed). The module is imported lazily + when the patcher is evaluated. + """ target_path: ClassVar[str] wrapper: ClassVar[Any] + target_module: ClassVar[str | None] = None @classmethod def resolve_root(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None: - """Return the root object from which this patcher resolves its target.""" - return target or module + """Return the root object from which this patcher resolves its target. + + When ``target_module`` is set, the patcher imports that module and uses + it as root instead of the integration-level module. If the import + fails, ``None`` is returned so that ``applies()`` returns ``False``. + """ + if target is not None: + return target + if cls.target_module is not None: + try: + return importlib.import_module(cls.target_module) + except ImportError: + return None + return module @classmethod def resolve_target(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None: @@ -111,6 +130,89 @@ def patch(cls, module: Any | None, version: str | None, *, target: Any | None = cls.mark_patched(resolved_target) return True + @classmethod + def wrap_target(cls, target: Any) -> Any: + """Patch *target* directly for tracing (idempotent). + + Unlike ``patch()``, which resolves the full ``target_path`` from a + module root, this method wraps the **leaf** attribute of + ``target_path`` directly on *target*. This is useful for manual + wrapping of a specific class or object (e.g. ``wrap_agent(MyAgent)``). + + The patch marker is set on *target* itself so that callers can check + ``getattr(target, patcher.patch_marker_attr(), False)`` to detect + whether the patch has already been applied. + + Returns *target* unchanged if the leaf attribute does not exist on + *target* or the patch has already been applied. Returns *target* + for convenient chaining. + """ + marker = cls.patch_marker_attr() + if getattr(target, marker, False): + return target + attr = cls.target_path.rsplit(".", 1)[-1] + if _resolve_attr_path(target, attr) is None: + return target + wrap_function_wrapper(target, attr, cls.wrapper) + cls.mark_patched(target) + return target + + +class CompositeFunctionWrapperPatcher(BasePatcher): + """Patcher that applies multiple ``FunctionWrapperPatcher`` sub-patchers as one unit. + + Use this when several closely related targets should be patched together + under a single patcher name — for example, patching both the sync and async + variants of the same method on one class. + + Subclasses declare ``sub_patchers`` as a tuple of ``FunctionWrapperPatcher`` + classes. The composite delegates ``applies``, ``is_patched``, and ``patch`` + to the sub-patchers, and the composite is considered patched when **all** + applicable sub-patchers have been applied. + """ + + sub_patchers: ClassVar[tuple[type[FunctionWrapperPatcher], ...]] + + @classmethod + def applies(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return ``True`` if the version gate passes and at least one sub-patcher applies.""" + if not super().applies(module, version, target=target): + return False + return any(sub.applies(module, version, target=target) for sub in cls.sub_patchers) + + @classmethod + def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Return ``True`` when every applicable sub-patcher has been applied.""" + applicable = [sub for sub in cls.sub_patchers if sub.applies(module, version, target=target)] + if not applicable: + return False + return all(sub.is_patched(module, version, target=target) for sub in applicable) + + @classmethod + def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: + """Apply all applicable sub-patchers.""" + success = False + for sub in cls.sub_patchers: + if not sub.applies(module, version, target=target): + continue + if sub.is_patched(module, version, target=target): + success = True + continue + success = sub.patch(module, version, target=target) or success + return success + + @classmethod + def wrap_target(cls, target: Any) -> Any: + """Patch *target* directly for tracing (idempotent). + + Delegates to each sub-patcher's ``wrap_target``, which individually + skips sub-patchers whose leaf attribute does not exist on *target*. + Returns *target* for convenient chaining. + """ + for sub in cls.sub_patchers: + sub.wrap_target(target) + return target + class BaseIntegration(ABC): """Base class for an instrumentable third-party integration.""" diff --git a/py/src/braintrust/wrappers/adk/__init__.py b/py/src/braintrust/wrappers/adk/__init__.py index 6c6b8a14..e625e19c 100644 --- a/py/src/braintrust/wrappers/adk/__init__.py +++ b/py/src/braintrust/wrappers/adk/__init__.py @@ -1,679 +1,11 @@ -import contextvars -import inspect -import logging -import time -from collections.abc import Iterable -from contextlib import aclosing -from typing import Any, cast +from braintrust.integrations.adk import ( # noqa: F401 + setup_adk, + setup_braintrust, + wrap_agent, + wrap_flow, + wrap_mcp_tool, + wrap_runner, +) -from braintrust.bt_json import bt_safe_deep_copy -from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - - -logger = logging.getLogger(__name__) __all__ = ["setup_braintrust", "setup_adk", "wrap_agent", "wrap_runner", "wrap_flow", "wrap_mcp_tool"] - - -def setup_braintrust(*args, **kwargs): - logger.warning("setup_braintrust is deprecated, use setup_adk instead") - return setup_adk(*args, **kwargs) - - -def setup_adk( - api_key: str | None = None, - project_id: str | None = None, - project_name: str | None = None, - SpanProcessor: type | None = None, -) -> bool: - """ - Setup Braintrust integration with Google ADK. Will automatically patch Google ADK agents, runners, flows, and MCP tools for automatic tracing. - - If you prefer manual patching take a look at `wrap_agent`, `wrap_runner`, `wrap_flow`, and `wrap_mcp_tool`. - - Args: - api_key (Optional[str]): Braintrust API key. - project_id (Optional[str]): Braintrust project ID. - project_name (Optional[str]): Braintrust project name. - SpanProcessor (Optional[type]): Deprecated parameter. - - Returns: - bool: True if setup was successful, False otherwise. - """ - if SpanProcessor is not None: - logging.warning("SpanProcessor parameter is deprecated and will be ignored") - - span = current_span() - if span == NOOP_SPAN: - init_logger(project=project_name, api_key=api_key, project_id=project_id) - - try: - from google.adk import agents, runners - from google.adk.flows.llm_flows import base_llm_flow - - agents.BaseAgent = wrap_agent(agents.BaseAgent) - runners.Runner = wrap_runner(runners.Runner) - base_llm_flow.BaseLlmFlow = wrap_flow(base_llm_flow.BaseLlmFlow) - - try: - from google.adk.platform import thread as adk_thread - - adk_thread.create_thread = _wrap_create_thread(adk_thread.create_thread) - runners.create_thread = _wrap_create_thread(runners.create_thread) - logger.debug("ADK thread bridge patching successful") - except Exception as e: - logger.warning(f"Failed to patch ADK thread bridge: {e}") - - # Try to patch McpTool if available (MCP is optional) - try: - from google.adk.tools.mcp_tool import mcp_tool - - mcp_tool.McpTool = wrap_mcp_tool(mcp_tool.McpTool) - logger.debug("McpTool patching successful") - except ImportError: - # MCP is optional - gracefully skip if not installed - logger.debug("McpTool not available, skipping MCP instrumentation") - except Exception as e: - # Log but don't fail - MCP patching is optional - logger.warning(f"Failed to patch McpTool: {e}") - - return True - except ImportError as e: - logger.error(f"Failed to import Google ADK agents: {e}") - logger.error("Google ADK is not installed. Please install it with: pip install google-adk") - return False - - -def _wrap_create_thread(create_thread): - if _is_patched(create_thread): - return create_thread - - def _wrapped_create_thread(target: Any, *args: Any, **kwargs: Any): - ctx = contextvars.copy_context() - - def _run_in_context(*target_args: Any, **target_kwargs: Any): - return ctx.run(target, *target_args, **target_kwargs) - - return create_thread(_run_in_context, *args, **kwargs) - - _wrapped_create_thread._braintrust_patched = True - return _wrapped_create_thread - - -def wrap_agent(Agent: Any) -> Any: - if _is_patched(Agent): - return Agent - - async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - parent_context = args[0] if len(args) > 0 else kwargs.get("parent_context") - - async def _trace(): - with start_span( - name=f"agent_run [{instance.name}]", - type=SpanTypeAttribute.TASK, - metadata=bt_safe_deep_copy({"parent_context": parent_context, **_omit(kwargs, ["parent_context"])}), - ) as agent_span: - last_event = None - async with aclosing(wrapped(*args, **kwargs)) as agen: - async for event in agen: - if event.is_final_response(): - last_event = event - yield event - if last_event: - agent_span.log(output=last_event) - - async with aclosing(_trace()) as agen: - async for event in agen: - yield event - - wrap_function_wrapper(Agent, "run_async", agent_run_wrapper) - Agent._braintrust_patched = True - return Agent - - -def wrap_flow(Flow: Any): - if _is_patched(Flow): - return Flow - - async def trace_flow(wrapped: Any, instance: Any, args: Any, kwargs: Any): - invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") - - async def _trace(): - with start_span( - name=f"call_llm", - type=SpanTypeAttribute.TASK, - metadata=bt_safe_deep_copy( - { - "invocation_context": invocation_context, - **_omit(kwargs, ["invocation_context"]), - } - ), - ) as llm_span: - last_event = None - async with aclosing(wrapped(*args, **kwargs)) as agen: - async for event in agen: - last_event = event - yield event - if last_event: - llm_span.log(output=last_event) - - async with aclosing(_trace()) as agen: - async for event in agen: - yield event - - wrap_function_wrapper(Flow, "run_async", trace_flow) - - async def trace_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") - llm_request = args[1] if len(args) > 1 else kwargs.get("llm_request") - model_response_event = args[2] if len(args) > 2 else kwargs.get("model_response_event") - - async def _trace(): - # Extract and serialize contents BEFORE converting to dict - # This is critical because bt_safe_deep_copy converts bytes to string representations - serialized_contents = None - if llm_request and hasattr(llm_request, "contents"): - contents = llm_request.contents - if contents: - serialized_contents = ( - [_serialize_content(c) for c in contents] - if isinstance(contents, list) - else _serialize_content(contents) - ) - - # Now convert the whole request to dict - serialized_request = bt_safe_deep_copy(llm_request) - - # Replace contents with our serialized version that has Attachments - if serialized_contents is not None and isinstance(serialized_request, dict): - serialized_request["contents"] = serialized_contents - - # Handle config specifically to serialize Pydantic schema classes - if isinstance(serialized_request, dict) and "config" in serialized_request: - serialized_request["config"] = _serialize_config(serialized_request["config"]) - - # Extract model name from request or instance - model_name = _extract_model_name(None, llm_request, instance) - - # Create span BEFORE execution so child spans (like mcp_tool) have proper parent - # Start with generic name - we'll update it after we see the response - with start_span( - name="llm_call", - type=SpanTypeAttribute.LLM, - input=serialized_request, - metadata=bt_safe_deep_copy( - { - "invocation_context": invocation_context, - "model_response_event": model_response_event, - "flow_class": instance.__class__.__name__, - "model": model_name, - **_omit(kwargs, ["invocation_context", "model_response_event", "flow_class", "llm_call_type"]), - } - ), - ) as llm_span: - # Execute the LLM call and yield events while span is active - last_event = None - event_with_content = None - start_time = time.time() - first_token_time = None - - async with aclosing(wrapped(*args, **kwargs)) as agen: - async for event in agen: - # Record time to first token - if first_token_time is None: - first_token_time = time.time() - - last_event = event - if hasattr(event, "content") and event.content is not None: - event_with_content = event - yield event - - # After execution, update span with correct call type and output - if last_event: - # We need to check if we should merge content from an earlier event - # Convert to dict to inspect/modify, but let span.log() handle final serialization - output_dict = bt_safe_deep_copy(last_event) - if event_with_content and isinstance(output_dict, dict): - if "content" not in output_dict or output_dict.get("content") is None: - content = ( - bt_safe_deep_copy(event_with_content.content) - if hasattr(event_with_content, "content") - else None - ) - if content: - output_dict["content"] = content - - # Extract metrics from response - metrics = _extract_metrics(last_event) - - # Add time to first token if we captured it - if first_token_time is not None: - if metrics is None: - metrics = {} - metrics["time_to_first_token"] = first_token_time - start_time - - # Determine the actual call type based on the response - call_type = _determine_llm_call_type(llm_request, last_event) - - # Update span name with the specific call type now that we know it - llm_span.set_attributes( - name=f"llm_call [{call_type}]", - span_attributes={"llm_call_type": call_type}, - ) - - # Log output and metrics (span.log will handle serialization) - llm_span.log(output=output_dict, metrics=metrics) - - async with aclosing(_trace()) as agen: - async for event in agen: - yield event - - wrap_function_wrapper(Flow, "_call_llm_async", trace_run_sync_wrapper) - Flow._braintrust_patched = True - return Flow - - -def wrap_runner(Runner: Any): - if _is_patched(Runner): - return Runner - - def trace_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - user_id = kwargs.get("user_id") - session_id = kwargs.get("session_id") - new_message = kwargs.get("new_message") - - # Serialize new_message before any dict conversion to handle binary data - serialized_message = _serialize_content(new_message) if new_message else None - - def _trace(): - with start_span( - name=f"invocation [{instance.app_name}]", - type=SpanTypeAttribute.TASK, - input={"new_message": serialized_message}, - metadata=bt_safe_deep_copy( - { - "user_id": user_id, - "session_id": session_id, - **_omit(kwargs, ["user_id", "session_id", "new_message"]), - } - ), - ) as runner_span: - last_event = None - for event in wrapped(*args, **kwargs): - if event.is_final_response(): - last_event = event - yield event - if last_event: - runner_span.log(output=last_event) - - yield from _trace() - - wrap_function_wrapper(Runner, "run", trace_run_sync_wrapper) - - async def trace_run_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - user_id = kwargs.get("user_id") - session_id = kwargs.get("session_id") - new_message = kwargs.get("new_message") - state_delta = kwargs.get("state_delta") - - # Serialize new_message before any dict conversion to handle binary data - serialized_message = _serialize_content(new_message) if new_message else None - - async def _trace(): - with start_span( - name=f"invocation [{instance.app_name}]", - type=SpanTypeAttribute.TASK, - input={"new_message": serialized_message}, - metadata=bt_safe_deep_copy( - { - "user_id": user_id, - "session_id": session_id, - "state_delta": state_delta, - **_omit(kwargs, ["user_id", "session_id", "new_message", "state_delta"]), - } - ), - ) as runner_span: - last_event = None - async with aclosing(wrapped(*args, **kwargs)) as agen: - async for event in agen: - if event.is_final_response(): - last_event = event - yield event - if last_event: - runner_span.log(output=last_event) - - async with aclosing(_trace()) as agen: - async for event in agen: - yield event - - wrap_function_wrapper(Runner, "run_async", trace_run_async_wrapper) - Runner._braintrust_patched = True - return Runner - - -def wrap_mcp_tool(McpTool: Any) -> Any: - """ - Wrap McpTool to trace MCP tool invocations. - - Creates Braintrust spans for each MCP tool call, capturing: - - Tool name - - Input arguments - - Output results - - Execution time - - Errors if they occur - - Args: - McpTool: The McpTool class to wrap - - Returns: - The wrapped McpTool class - """ - if _is_patched(McpTool): - return McpTool - - async def tool_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - # Extract tool information - tool_name = instance.name - tool_args = kwargs.get("args", {}) - - with start_span( - name=f"mcp_tool [{tool_name}]", - type=SpanTypeAttribute.TOOL, - input={"tool_name": tool_name, "arguments": tool_args}, - metadata=_omit(kwargs, ["args"]), - ) as tool_span: - try: - result = await wrapped(*args, **kwargs) - tool_span.log(output=result) - return result - except Exception as e: - # Log error to span but re-raise for ADK to handle - tool_span.log(error=str(e)) - raise - - wrap_function_wrapper(McpTool, "run_async", tool_run_wrapper) - McpTool._braintrust_patched = True - return McpTool - - -def _determine_llm_call_type(llm_request: Any, model_response: Any = None) -> str: - """ - Determine the type of LLM call based on the request and response content. - - Returns: - - "tool_selection" if the LLM selected a tool to call in its response - - "response_generation" if the LLM is generating a response after tool execution - - "direct_response" if there are no tools involved or tools available but not used - """ - try: - # Convert to dict if it's a model object - request_dict = cast(dict[str, Any], bt_safe_deep_copy(llm_request)) - - # Check if there are tools in the config - has_tools = bool(request_dict.get("config", {}).get("tools")) - - # Check the conversation history for function responses - contents = request_dict.get("contents", []) - has_function_response = False - - for content in contents: - if isinstance(content, dict): - parts = content.get("parts", []) - for part in parts: - if isinstance(part, dict): - if "function_response" in part and part["function_response"] is not None: - has_function_response = True - - # Check if the response contains function calls - response_has_function_call = False - if model_response: - # Check if it's an Event object with get_function_calls method (ADK Event) - if hasattr(model_response, "get_function_calls"): - try: - function_calls = model_response.get_function_calls() - if function_calls and len(function_calls) > 0: - response_has_function_call = True - except Exception: - pass - - # Fallback: Check the response dict structure - if not response_has_function_call: - response_dict = bt_safe_deep_copy(model_response) - if isinstance(response_dict, dict): - # Try multiple possible response structures - # 1. Standard: response.content.parts - content = response_dict.get("content", {}) - if isinstance(content, dict): - parts = content.get("parts", []) - if isinstance(parts, list): - for part in parts: - if isinstance(part, dict): - if ("function_call" in part and part["function_call"] is not None) or ( - "functionCall" in part and part["functionCall"] is not None - ): - response_has_function_call = True - break - - # 2. Alternative: response has parts directly (for some event types) - if not response_has_function_call and "parts" in response_dict: - parts = response_dict.get("parts", []) - if isinstance(parts, list): - for part in parts: - if isinstance(part, dict): - if ("function_call" in part and part["function_call"] is not None) or ( - "functionCall" in part and part["functionCall"] is not None - ): - response_has_function_call = True - break - - # Determine the call type - if has_function_response: - return "response_generation" - elif response_has_function_call: - return "tool_selection" - else: - return "direct_response" - - except Exception: - return "unknown" - - -def _is_patched(obj: Any): - return getattr(obj, "_braintrust_patched", False) - - -def _serialize_content(content: Any) -> Any: - """Serialize Google ADK Content/Part objects, converting binary data to Attachments.""" - if content is None: - return None - - # Handle Content objects with parts - if hasattr(content, "parts") and content.parts: - serialized_parts = [] - for part in content.parts: - serialized_parts.append(_serialize_part(part)) - - result = {"parts": serialized_parts} - if hasattr(content, "role"): - result["role"] = content.role - return result - - # Handle single Part - return _serialize_part(content) - - -def _serialize_part(part: Any) -> Any: - """Serialize a single Part object, handling binary data.""" - if part is None: - return None - - # If it's already a dict, return as-is - if isinstance(part, dict): - return part - - # Handle Part objects with inline_data (binary data like images) - if hasattr(part, "inline_data") and part.inline_data: - inline_data = part.inline_data - if hasattr(inline_data, "data") and hasattr(inline_data, "mime_type"): - data = inline_data.data - mime_type = inline_data.mime_type - - # Convert bytes to Attachment - if isinstance(data, bytes): - extension = mime_type.split("/")[1] if "/" in mime_type else "bin" - filename = f"file.{extension}" - attachment = Attachment(data=data, filename=filename, content_type=mime_type) - - # Return in image_url format - SDK will replace with AttachmentReference - return {"image_url": {"url": attachment}} - - # Handle Part objects with file_data (file references) - if hasattr(part, "file_data") and part.file_data: - file_data = part.file_data - result = {"file_data": {}} - if hasattr(file_data, "file_uri"): - result["file_data"]["file_uri"] = file_data.file_uri - if hasattr(file_data, "mime_type"): - result["file_data"]["mime_type"] = file_data.mime_type - return result - - # Handle text parts - if hasattr(part, "text") and part.text is not None: - result = {"text": part.text} - if hasattr(part, "thought") and part.thought: - result["thought"] = part.thought - return result - - # Try standard serialization methods - return bt_safe_deep_copy(part) - - -def _serialize_pydantic_schema(schema_class: Any) -> dict[str, Any]: - """ - Serialize a Pydantic model class to its full JSON schema. - - Returns the complete schema including descriptions, constraints, and nested definitions - so engineers can see exactly what structured output schema was used. - """ - try: - from pydantic import BaseModel - - if inspect.isclass(schema_class) and issubclass(schema_class, BaseModel): - # Return the full JSON schema - includes all field info, descriptions, constraints, etc. - return schema_class.model_json_schema() - except (ImportError, AttributeError, TypeError): - pass - # If not a Pydantic model, return class name - return {"__class__": schema_class.__name__ if inspect.isclass(schema_class) else str(type(schema_class).__name__)} - - -def _serialize_config(config: Any) -> dict[str, Any] | Any: - """ - Serialize a config object, specifically handling schema fields that may contain Pydantic classes. - - Google ADK uses these fields for schemas: - - response_schema, response_json_schema (in GenerateContentConfig for LLM requests) - - input_schema, output_schema (in agent config) - """ - if config is None: - return None - if not config: - return config - - # Extract schema fields BEFORE calling bt_safe_deep_copy (which converts Pydantic classes to dicts) - schema_fields = ["response_schema", "response_json_schema", "input_schema", "output_schema"] - serialized_schemas: dict[str, Any] = {} - - for field in schema_fields: - schema_value = None - - # Try to get the field value - if hasattr(config, field): - schema_value = getattr(config, field) - elif isinstance(config, dict) and field in config: - schema_value = config[field] - - # If it's a Pydantic class, serialize it - if schema_value is not None and inspect.isclass(schema_value): - try: - from pydantic import BaseModel - - if issubclass(schema_value, BaseModel): - serialized_schemas[field] = _serialize_pydantic_schema(schema_value) - except (TypeError, ImportError): - pass - - # Serialize the config - config_dict = bt_safe_deep_copy(config) - if not isinstance(config_dict, dict): - return config_dict # type: ignore - - # Replace schema fields with serialized versions - config_dict.update(serialized_schemas) - - return config_dict - - -def _omit(obj: Any, keys: Iterable[str]): - return {k: v for k, v in obj.items() if k not in keys} - - -def _extract_metrics(response: Any) -> dict[str, float] | None: - """Extract token usage metrics from Google GenAI response.""" - if not response: - return None - - usage_metadata = getattr(response, "usage_metadata", None) - if not usage_metadata: - return None - - metrics: dict[str, float] = {} - - # Core token counts - if hasattr(usage_metadata, "prompt_token_count") and usage_metadata.prompt_token_count is not None: - metrics["prompt_tokens"] = float(usage_metadata.prompt_token_count) - - if hasattr(usage_metadata, "candidates_token_count") and usage_metadata.candidates_token_count is not None: - metrics["completion_tokens"] = float(usage_metadata.candidates_token_count) - - if hasattr(usage_metadata, "total_token_count") and usage_metadata.total_token_count is not None: - metrics["tokens"] = float(usage_metadata.total_token_count) - - # Cached token metrics - if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None: - metrics["prompt_cached_tokens"] = float(usage_metadata.cached_content_token_count) - - # Reasoning token metrics (thoughts_token_count) - if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None: - metrics["completion_reasoning_tokens"] = float(usage_metadata.thoughts_token_count) - - return metrics if metrics else None - - -def _extract_model_name(response: Any, llm_request: Any, instance: Any) -> str | None: - """Extract model name from Google GenAI response, request, or flow instance.""" - # Try to get from response first - if response: - model_version = getattr(response, "model_version", None) - if model_version: - return model_version - - # Try to get from llm_request - if llm_request: - if hasattr(llm_request, "model") and llm_request.model: - return str(llm_request.model) - - # Try to get from instance (flow's llm) - if instance: - if hasattr(instance, "llm"): - llm = instance.llm - if hasattr(llm, "model") and llm.model: - return str(llm.model) - - # Try to get model from instance directly - if hasattr(instance, "model") and instance.model: - return str(instance.model) - - return None diff --git a/py/src/braintrust/wrappers/adk/test_auto_adk.py b/py/src/braintrust/wrappers/adk/test_auto_adk.py deleted file mode 100644 index 493bd1fb..00000000 --- a/py/src/braintrust/wrappers/adk/test_auto_adk.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Test auto_instrument for Google ADK.""" - -from braintrust.auto import auto_instrument - - -# 1. Instrument -results = auto_instrument() -assert results.get("adk") == True, "auto_instrument should return True for adk" - -# 2. Idempotent -results2 = auto_instrument() -assert results2.get("adk") == True, "auto_instrument should still return True on second call" - -# 3. Verify classes are patched -from google.adk import agents, runners -from google.adk.flows.llm_flows import base_llm_flow - - -assert getattr(agents.BaseAgent, "_braintrust_patched", False), "BaseAgent should be patched" -assert getattr(runners.Runner, "_braintrust_patched", False), "Runner should be patched" -assert getattr(base_llm_flow.BaseLlmFlow, "_braintrust_patched", False), "BaseLlmFlow should be patched" - -print("SUCCESS") From e0fb5cf2661e4bef43b3085bfbbb5b2a9643f12e Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 24 Mar 2026 09:16:09 -0700 Subject: [PATCH 2/4] change versioning logic --- .agents/skills/sdk-integrations/SKILL.md | 15 +++++---------- py/src/braintrust/integrations/adk/__init__.py | 1 - py/src/braintrust/integrations/adk/integration.py | 1 + py/src/braintrust/integrations/test_versioning.py | 5 +++-- py/src/braintrust/integrations/versioning.py | 10 ++++++++-- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md index 7e2fa498..c463aff6 100644 --- a/.agents/skills/sdk-integrations/SKILL.md +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -109,21 +109,16 @@ Prefer feature detection first and version checks second. Use: - `detect_module_version(...)` -- `version_in_range(...)` -- `version_matches_spec(...)` - -Do not add `packaging` just for integration routing. +- `version_satisfies(...)` +- `make_specifier(...)` ## `auto_instrument()` Update `py/src/braintrust/auto.py` only if the integration should be auto-patched. -Match the existing option shape: - -- use plain `bool` for simple on/off integrations that do not use the integrations API -- use `InstrumentOption` for integrations API providers that support `IntegrationPatchConfig` +Use `InstrumentOption` (i.e. `bool | IntegrationPatchConfig`) for all integrations, including those that do not yet use the integrations API. This keeps the signature uniform and avoids a breaking change when the integration is later migrated. -For integrations API providers, use `_normalize_instrument_option()` and `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: +Use `_normalize_instrument_option()` and `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: ```python enabled, config = _normalize_instrument_option("provider", provider) @@ -180,6 +175,6 @@ cd py && make lint - Forgetting async or streaming coverage. - Adding patcher selection without tests for enabled and disabled cases. - Re-recording cassettes when behavior did not intentionally change. -- Using `_normalize_bool_option()` for an integrations API provider. +- Using `_normalize_bool_option()` instead of `_normalize_instrument_option()` — all integrations should accept `InstrumentOption`. - Adding a custom `_instrument_*` helper where `_instrument_integration()` already fits. - Forgetting `target_module` for deep or optional submodule patch targets. diff --git a/py/src/braintrust/integrations/adk/__init__.py b/py/src/braintrust/integrations/adk/__init__.py index cddb0f7f..bec4d711 100644 --- a/py/src/braintrust/integrations/adk/__init__.py +++ b/py/src/braintrust/integrations/adk/__init__.py @@ -18,7 +18,6 @@ __all__ = [ "ADKIntegration", - "_create_thread_wrapper", "setup_adk", "setup_braintrust", "wrap_agent", diff --git a/py/src/braintrust/integrations/adk/integration.py b/py/src/braintrust/integrations/adk/integration.py index 78bcb4c1..a775cc42 100644 --- a/py/src/braintrust/integrations/adk/integration.py +++ b/py/src/braintrust/integrations/adk/integration.py @@ -21,6 +21,7 @@ class ADKIntegration(BaseIntegration): name = "adk" import_names = ("google.adk",) + min_version = "1.14.1" patchers = ( ThreadBridgePatcher, AgentRunAsyncPatcher, diff --git a/py/src/braintrust/integrations/test_versioning.py b/py/src/braintrust/integrations/test_versioning.py index 226ba86f..39584ad2 100644 --- a/py/src/braintrust/integrations/test_versioning.py +++ b/py/src/braintrust/integrations/test_versioning.py @@ -26,8 +26,9 @@ def test_version_satisfies_none_handling(): assert version_satisfies("1.0", None) assert version_satisfies(None, None) - # No version with a spec means incompatible. - assert not version_satisfies(None, ">=1.0") + # No version with a spec — optimistically allow so patching still proceeds + # when version detection fails. + assert version_satisfies(None, ">=1.0") def test_version_satisfies_invalid_version(): diff --git a/py/src/braintrust/integrations/versioning.py b/py/src/braintrust/integrations/versioning.py index 384a4aa8..dc7f7dca 100644 --- a/py/src/braintrust/integrations/versioning.py +++ b/py/src/braintrust/integrations/versioning.py @@ -43,11 +43,17 @@ def make_specifier(*, min_version: str | None = None, max_version: str | None = def version_satisfies(version: str | None, spec: str | SpecifierSet | None) -> bool: - """Return True if *version* satisfies the PEP 440 *spec*.""" + """Return True if *version* satisfies the PEP 440 *spec*. + + When *version* is ``None`` (i.e. we could not detect the installed + version), we optimistically return ``True`` so that patching still + proceeds. A failed detection should not silently disable + instrumentation. + """ if spec is None: return True if version is None: - return False + return True try: ss = spec if isinstance(spec, SpecifierSet) else SpecifierSet(spec, prereleases=True) return Version(version) in ss From d735951497db1044ed2c72ff7c261ebeb6dec85d Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 24 Mar 2026 10:01:07 -0700 Subject: [PATCH 3/4] remove patcher config for now --- .agents/skills/sdk-integrations/SKILL.md | 15 ++-- py/src/braintrust/__init__.py | 1 - py/src/braintrust/auto.py | 70 +++++-------------- py/src/braintrust/integrations/__init__.py | 3 +- .../integrations/anthropic/test_anthropic.py | 12 ---- .../test_auto_anthropic_patch_config.py | 20 ------ py/src/braintrust/integrations/base.py | 44 ++---------- py/src/braintrust/wrappers/test_anthropic.py | 35 ---------- 8 files changed, 28 insertions(+), 172 deletions(-) delete mode 100644 py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md index c463aff6..e487fc1b 100644 --- a/.agents/skills/sdk-integrations/SKILL.md +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -92,7 +92,7 @@ Patchers must provide: - existence checks - idempotence through the base patcher marker -Use `IntegrationPatchConfig` only when users need patcher-level selection. Let `BaseIntegration.resolve_patchers()` reject unknown patcher ids instead of silently ignoring them. +Let `BaseIntegration.resolve_patchers()` reject duplicate patcher ids instead of silently ignoring them. ## Patching Patterns @@ -116,14 +116,11 @@ Use: Update `py/src/braintrust/auto.py` only if the integration should be auto-patched. -Use `InstrumentOption` (i.e. `bool | IntegrationPatchConfig`) for all integrations, including those that do not yet use the integrations API. This keeps the signature uniform and avoids a breaking change when the integration is later migrated. - -Use `_normalize_instrument_option()` and `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: +All `auto_instrument()` parameters are plain `bool` flags. Use `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: ```python -enabled, config = _normalize_instrument_option("provider", provider) -if enabled: - results["provider"] = _instrument_integration(ProviderIntegration, patch_config=config) +if provider: + results["provider"] = _instrument_integration(ProviderIntegration) ``` Add the integration import near the other integration imports in `auto.py`. @@ -147,7 +144,7 @@ Cover the surfaces that changed: - streaming behavior - idempotence - failure and error logging -- patcher selection when using `IntegrationPatchConfig` +- patcher resolution and duplicate detection Keep VCR cassettes in `py/src/braintrust/integrations//cassettes/`. Re-record them only for intentional behavior changes. @@ -173,8 +170,6 @@ cd py && make lint - Moving provider-specific behavior into shared integration code. - Combining unrelated targets into one patcher. - Forgetting async or streaming coverage. -- Adding patcher selection without tests for enabled and disabled cases. - Re-recording cassettes when behavior did not intentionally change. -- Using `_normalize_bool_option()` instead of `_normalize_instrument_option()` — all integrations should accept `InstrumentOption`. - Adding a custom `_instrument_*` helper where `_instrument_integration()` already fits. - Forgetting `target_module` for deep or optional submodule patch targets. diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 7a0115ae..c961ac72 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -63,7 +63,6 @@ def is_equal(expected, output): from .audit import * from .auto import ( - IntegrationPatchConfig, # noqa: F401 # type: ignore[reportUnusedImport] auto_instrument, # noqa: F401 # type: ignore[reportUnusedImport] ) from .framework import * diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 13ea636d..7cd51870 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -4,18 +4,15 @@ Provides one-line instrumentation for supported libraries. """ -from __future__ import annotations - import logging from contextlib import contextmanager -from braintrust.integrations import ADKIntegration, AnthropicIntegration, IntegrationPatchConfig +from braintrust.integrations import ADKIntegration, AnthropicIntegration __all__ = ["auto_instrument"] logger = logging.getLogger(__name__) -InstrumentOption = bool | IntegrationPatchConfig @contextmanager @@ -32,14 +29,14 @@ def _try_patch(): def auto_instrument( *, openai: bool = True, - anthropic: InstrumentOption = True, + anthropic: bool = True, litellm: bool = True, pydantic_ai: bool = True, google_genai: bool = True, agno: bool = True, claude_agent_sdk: bool = True, dspy: bool = True, - adk: InstrumentOption = True, + adk: bool = True, ) -> dict[str, bool]: """ Auto-instrument supported AI/ML libraries for Braintrust tracing. @@ -52,8 +49,7 @@ def auto_instrument( Args: openai: Enable OpenAI instrumentation (default: True) - anthropic: Enable Anthropic instrumentation (default: True), or pass an - IntegrationPatchConfig to select Anthropic patchers explicitly. + anthropic: Enable Anthropic instrumentation (default: True) litellm: Enable LiteLLM instrumentation (default: True) pydantic_ai: Enable Pydantic AI instrumentation (default: True) google_genai: Enable Google GenAI instrumentation (default: True) @@ -108,34 +104,24 @@ def auto_instrument( """ results = {} - openai_enabled = _normalize_bool_option("openai", openai) - anthropic_enabled, anthropic_config = _normalize_instrument_option("anthropic", anthropic) - litellm_enabled = _normalize_bool_option("litellm", litellm) - pydantic_ai_enabled = _normalize_bool_option("pydantic_ai", pydantic_ai) - google_genai_enabled = _normalize_bool_option("google_genai", google_genai) - agno_enabled = _normalize_bool_option("agno", agno) - claude_agent_sdk_enabled = _normalize_bool_option("claude_agent_sdk", claude_agent_sdk) - dspy_enabled = _normalize_bool_option("dspy", dspy) - adk_enabled, adk_config = _normalize_instrument_option("adk", adk) - - if openai_enabled: + if openai: results["openai"] = _instrument_openai() - if anthropic_enabled: - results["anthropic"] = _instrument_integration(AnthropicIntegration, patch_config=anthropic_config) - if litellm_enabled: + if anthropic: + results["anthropic"] = _instrument_integration(AnthropicIntegration) + if litellm: results["litellm"] = _instrument_litellm() - if pydantic_ai_enabled: + if pydantic_ai: results["pydantic_ai"] = _instrument_pydantic_ai() - if google_genai_enabled: + if google_genai: results["google_genai"] = _instrument_google_genai() - if agno_enabled: + if agno: results["agno"] = _instrument_agno() - if claude_agent_sdk_enabled: + if claude_agent_sdk: results["claude_agent_sdk"] = _instrument_claude_agent_sdk() - if dspy_enabled: + if dspy: results["dspy"] = _instrument_dspy() - if adk_enabled: - results["adk"] = _instrument_integration(ADKIntegration, patch_config=adk_config) + if adk: + results["adk"] = _instrument_integration(ADKIntegration) return results @@ -148,34 +134,12 @@ def _instrument_openai() -> bool: return False -def _instrument_integration(integration, *, patch_config: IntegrationPatchConfig | None = None) -> bool: +def _instrument_integration(integration) -> bool: with _try_patch(): - return integration.setup( - enabled_patchers=patch_config.enabled_patchers if patch_config is not None else None, - disabled_patchers=patch_config.disabled_patchers if patch_config is not None else None, - ) + return integration.setup() return False -def _normalize_bool_option(name: str, option: bool) -> bool: - if isinstance(option, bool): - return option - - raise TypeError(f"auto_instrument option {name!r} must be a bool, got {type(option).__name__}") - - -def _normalize_instrument_option(name: str, option: InstrumentOption) -> tuple[bool, IntegrationPatchConfig | None]: - if isinstance(option, bool): - return option, None - - if isinstance(option, IntegrationPatchConfig): - return True, option - - raise TypeError( - f"auto_instrument option {name} must be a bool or IntegrationPatchConfig, got {type(option).__name__}" - ) - - def _instrument_litellm() -> bool: with _try_patch(): from braintrust.wrappers.litellm import patch_litellm diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index 72aab3da..d8c51617 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -1,6 +1,5 @@ from .adk import ADKIntegration from .anthropic import AnthropicIntegration -from .base import IntegrationPatchConfig -__all__ = ["ADKIntegration", "AnthropicIntegration", "IntegrationPatchConfig"] +__all__ = ["ADKIntegration", "AnthropicIntegration"] diff --git a/py/src/braintrust/integrations/anthropic/test_anthropic.py b/py/src/braintrust/integrations/anthropic/test_anthropic.py index 570de9b1..7e395f55 100644 --- a/py/src/braintrust/integrations/anthropic/test_anthropic.py +++ b/py/src/braintrust/integrations/anthropic/test_anthropic.py @@ -512,18 +512,6 @@ def test_available_patchers(self): "anthropic.init.async", ) - def test_resolve_patchers_honors_enable_disable_filters(self): - selected = AnthropicIntegration.resolve_patchers( - enabled_patchers={"anthropic.init.sync", "anthropic.init.async"}, - disabled_patchers={"anthropic.init.async"}, - ) - - assert tuple(patcher.identifier() for patcher in selected) == ("anthropic.init.sync",) - - def test_resolve_patchers_rejects_unknown_patchers(self): - with pytest.raises(ValueError, match="Unknown patchers"): - AnthropicIntegration.resolve_patchers(enabled_patchers={"anthropic.init.unknown"}) - def test_setup_rejects_unsupported_versions(self): spec = make_specifier( min_version=AnthropicIntegration.min_version, max_version=AnthropicIntegration.max_version diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py deleted file mode 100644 index e15f9361..00000000 --- a/py/src/braintrust/integrations/auto_test_scripts/test_auto_anthropic_patch_config.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Test auto_instrument patch selection for Anthropic.""" - -import anthropic -from braintrust.auto import auto_instrument -from braintrust.integrations import IntegrationPatchConfig - - -results = auto_instrument( - anthropic=IntegrationPatchConfig( - enabled_patchers={"anthropic.init.sync"}, - ) -) -assert results.get("anthropic") == True - -patched_sync = anthropic.Anthropic(api_key="test-key") -unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") -assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" -assert type(unpatched_async.messages).__module__.startswith("anthropic.") - -print("SUCCESS") diff --git a/py/src/braintrust/integrations/base.py b/py/src/braintrust/integrations/base.py index 12588036..e3deaabb 100644 --- a/py/src/braintrust/integrations/base.py +++ b/py/src/braintrust/integrations/base.py @@ -4,8 +4,7 @@ import inspect import re from abc import ABC, abstractmethod -from collections.abc import Collection, Iterable -from dataclasses import dataclass +from collections.abc import Iterable from typing import Any, ClassVar from wrapt import wrap_function_wrapper @@ -13,14 +12,6 @@ from .versioning import detect_module_version, make_specifier, version_satisfies -@dataclass(frozen=True) -class IntegrationPatchConfig: - """Per-integration patch selection for instrumentation setup.""" - - enabled_patchers: Collection[str] | None = None - disabled_patchers: Collection[str] | None = None - - class BasePatcher(ABC): """Base class for one concrete integration patch strategy.""" @@ -229,13 +220,8 @@ def available_patchers(cls) -> tuple[str, ...]: return tuple(patcher.identifier() for patcher in cls.patchers) @classmethod - def resolve_patchers( - cls, - *, - enabled_patchers: Collection[str] | None = None, - disabled_patchers: Collection[str] | None = None, - ) -> tuple[type[BasePatcher], ...]: - """Return the selected patchers after validating explicit selectors.""" + def resolve_patchers(cls) -> tuple[type[BasePatcher], ...]: + """Return all patchers after validating there are no duplicate identifiers.""" patchers_by_id: dict[str, type[BasePatcher]] = {} for patcher in cls.patchers: patcher_id = patcher.identifier() @@ -244,30 +230,13 @@ def resolve_patchers( raise ValueError(f"Duplicate patcher identifier {patcher_id!r} for integration {cls.name!r}") patchers_by_id[patcher_id] = patcher - enabled = set(enabled_patchers) if enabled_patchers is not None else None - disabled = set(disabled_patchers or ()) - requested = disabled if enabled is None else enabled | disabled - unknown = requested - set(patchers_by_id) - if unknown: - available = ", ".join(sorted(patchers_by_id)) - unknown_display = ", ".join(sorted(unknown)) - raise ValueError( - f"Unknown patchers for integration {cls.name!r}: {unknown_display}. Available patchers: {available}" - ) - - return tuple( - patcher - for patcher in cls.patchers - if (enabled is None or patcher.identifier() in enabled) and patcher.identifier() not in disabled - ) + return cls.patchers @classmethod def setup( cls, *, target: Any | None = None, - enabled_patchers: Collection[str] | None = None, - disabled_patchers: Collection[str] | None = None, ) -> bool: """Apply all applicable patchers for this integration.""" module = _import_first_available(cls.import_names) @@ -278,10 +247,7 @@ def setup( return False success = False - selected_patchers = cls.resolve_patchers( - enabled_patchers=enabled_patchers, - disabled_patchers=disabled_patchers, - ) + selected_patchers = cls.resolve_patchers() for patcher in sorted(selected_patchers, key=lambda patcher: patcher.priority): if not patcher.applies(module, version, target=target): continue diff --git a/py/src/braintrust/wrappers/test_anthropic.py b/py/src/braintrust/wrappers/test_anthropic.py index e7884464..ea4af183 100644 --- a/py/src/braintrust/wrappers/test_anthropic.py +++ b/py/src/braintrust/wrappers/test_anthropic.py @@ -58,22 +58,6 @@ def test_anthropic_integration_setup_is_idempotent(self): assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout - def test_anthropic_integration_setup_can_disable_specific_patchers(self): - result = run_in_subprocess(""" - from braintrust.integrations.anthropic import AnthropicIntegration - import anthropic - - AnthropicIntegration.setup(disabled_patchers={"anthropic.init.async"}) - patched_sync = anthropic.Anthropic(api_key="test-key") - unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") - - assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" - assert type(unpatched_async.messages).__module__.startswith("anthropic.") - print("SUCCESS") - """) - assert result.returncode == 0, f"Failed: {result.stderr}" - assert "SUCCESS" in result.stdout - class TestAutoInstrumentAnthropic: """Tests for auto_instrument() with Anthropic.""" @@ -81,22 +65,3 @@ class TestAutoInstrumentAnthropic: def test_auto_instrument_anthropic(self): """Test auto_instrument patches Anthropic, creates spans, and uninstrument works.""" verify_autoinstrument_script("test_auto_anthropic.py") - - def test_auto_instrument_anthropic_patch_config(self): - verify_autoinstrument_script("test_auto_anthropic_patch_config.py") - - def test_auto_instrument_rejects_non_bool_option_for_openai(self): - result = run_in_subprocess(""" - from braintrust.auto import auto_instrument - from braintrust.integrations import IntegrationPatchConfig - - try: - auto_instrument(openai=IntegrationPatchConfig()) - except TypeError as exc: - assert "must be a bool" in str(exc) - print("SUCCESS") - else: - raise AssertionError("Expected TypeError") - """) - assert result.returncode == 0, f"Failed: {result.stderr}" - assert "SUCCESS" in result.stdout From d4b4d65372dfd26b4337905982be31e81be7c088 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 24 Mar 2026 10:18:17 -0700 Subject: [PATCH 4/4] fix more merge conflicts --- .../cassettes/test_setup_creates_spans.yaml | 110 ++++++++++++++++++ .../integrations/anthropic/test_anthropic.py | 79 +++---------- .../integrations/anthropic/tracing.py | 18 ++- 3 files changed, 144 insertions(+), 63 deletions(-) create mode 100644 py/src/braintrust/integrations/anthropic/cassettes/test_setup_creates_spans.yaml diff --git a/py/src/braintrust/integrations/anthropic/cassettes/test_setup_creates_spans.yaml b/py/src/braintrust/integrations/anthropic/cassettes/test_setup_creates_spans.yaml new file mode 100644 index 00000000..ed30637c --- /dev/null +++ b/py/src/braintrust/integrations/anthropic/cassettes/test_setup_creates_spans.yaml @@ -0,0 +1,110 @@ +interactions: +- request: + body: '{"max_tokens":100,"messages":[{"role":"user","content":"hi"}],"model":"claude-3-haiku-20240307"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '96' + Content-Type: + - application/json + Host: + - api.anthropic.com + User-Agent: + - Anthropic/Python 0.86.0 + X-Stainless-Arch: + - arm64 + X-Stainless-Async: + - 'false' + X-Stainless-Lang: + - python + X-Stainless-OS: + - MacOS + X-Stainless-Package-Version: + - 0.86.0 + X-Stainless-Runtime: + - CPython + X-Stainless-Runtime-Version: + - 3.13.3 + anthropic-version: + - '2023-06-01' + x-stainless-read-timeout: + - '600' + x-stainless-retry-count: + - '0' + x-stainless-timeout: + - '600' + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: '{"model":"claude-3-haiku-20240307","id":"msg_0117iU2tMYP3e6LP1NXdHcpt","type":"message","role":"assistant","content":[{"type":"text","text":"Hello! + How can I assist you today?"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":8,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":12,"service_tier":"standard","inference_geo":"not_available"}}' + headers: + CF-RAY: + - 9e17520e5d2226b0-SJC + Connection: + - keep-alive + Content-Security-Policy: + - default-src 'none'; frame-ancestors 'none' + Content-Type: + - application/json + Date: + - Tue, 24 Mar 2026 17:15:54 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Robots-Tag: + - none + anthropic-organization-id: + - 27796668-7351-40ac-acc4-024aee8995a5 + anthropic-ratelimit-input-tokens-limit: + - '8000000' + anthropic-ratelimit-input-tokens-remaining: + - '8000000' + anthropic-ratelimit-input-tokens-reset: + - '2026-03-24T17:15:54Z' + anthropic-ratelimit-output-tokens-limit: + - '1500000' + anthropic-ratelimit-output-tokens-remaining: + - '1500000' + anthropic-ratelimit-output-tokens-reset: + - '2026-03-24T17:15:54Z' + anthropic-ratelimit-requests-limit: + - '10000' + anthropic-ratelimit-requests-remaining: + - '9999' + anthropic-ratelimit-requests-reset: + - '2026-03-24T17:15:54Z' + anthropic-ratelimit-tokens-limit: + - '9500000' + anthropic-ratelimit-tokens-remaining: + - '9500000' + anthropic-ratelimit-tokens-reset: + - '2026-03-24T17:15:54Z' + cf-cache-status: + - DYNAMIC + content-length: + - '468' + request-id: + - req_011CZNHNsVf8oVrJ25Aaqocy + server-timing: + - x-originResponse;dur=379 + set-cookie: + - _cfuvid=wNnIJEFrdvZ9PvwuqaVIcC.f38mcpIT580b2SVCmGeg-1774372553.9806187-1.0.1.1-xSLK0QdT6_irKBkykjkQKy4uFCjzt9RoIb5SrFAfiJk; + HttpOnly; SameSite=None; Secure; Path=/; Domain=api.anthropic.com + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Accept-Encoding + x-envoy-upstream-service-time: + - '378' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/anthropic/test_anthropic.py b/py/src/braintrust/integrations/anthropic/test_anthropic.py index d698cecb..bcbee1bd 100644 --- a/py/src/braintrust/integrations/anthropic/test_anthropic.py +++ b/py/src/braintrust/integrations/anthropic/test_anthropic.py @@ -2,7 +2,6 @@ Tests to ensure we reliably wrap the Anthropic API. """ -import inspect import time import unittest.mock from pathlib import Path @@ -11,7 +10,6 @@ import pytest from braintrust import logger from braintrust.integrations.anthropic import AnthropicIntegration, wrap_anthropic -from braintrust.integrations.versioning import make_specifier, version_satisfies from braintrust.test_helpers import init_test_logger @@ -494,66 +492,23 @@ async def test_anthropic_beta_messages_streaming_async(memory_logger): assert metrics["tokens"] == usage.input_tokens + usage.output_tokens -class TestAnthropicIntegrationSetup: - """Tests for `AnthropicIntegration.setup()`.""" - - def test_available_patchers(self): - assert AnthropicIntegration.available_patchers() == ( - "anthropic.init.sync", - "anthropic.init.async", - ) - - def test_setup_rejects_unsupported_versions(self): - spec = make_specifier( - min_version=AnthropicIntegration.min_version, max_version=AnthropicIntegration.max_version - ) - assert version_satisfies("0.47.9", spec) is False - - def test_setup_wraps_supported_clients(self): - """`AnthropicIntegration.setup()` should wrap both sync and async client constructors.""" - unpatched_sync = anthropic.Anthropic(api_key="test-key") - unpatched_async = anthropic.AsyncAnthropic(api_key="test-key") - assert type(unpatched_sync.messages).__module__.startswith("anthropic.") - assert type(unpatched_async.messages).__module__.startswith("anthropic.") - - AnthropicIntegration.setup() - patched_sync = anthropic.Anthropic(api_key="test-key") - patched_async = anthropic.AsyncAnthropic(api_key="test-key") - assert type(patched_sync.messages).__module__ == "braintrust.integrations.anthropic.tracing" - assert type(patched_async.messages).__module__ == "braintrust.integrations.anthropic.tracing" - - def test_setup_is_idempotent(self): - """Multiple `AnthropicIntegration.setup()` calls should be safe.""" - AnthropicIntegration.setup() - first_sync_init = inspect.getattr_static(anthropic.Anthropic, "__init__") - first_async_init = inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") - - AnthropicIntegration.setup() - assert first_sync_init is inspect.getattr_static(anthropic.Anthropic, "__init__") - assert first_async_init is inspect.getattr_static(anthropic.AsyncAnthropic, "__init__") - - def test_setup_creates_spans(self): - """`AnthropicIntegration.setup()` should create spans when making API calls.""" - init_test_logger("test-auto") - with logger._internal_with_memory_background_logger() as memory_logger: - AnthropicIntegration.setup() - - client = anthropic.Anthropic() - - import braintrust - - with braintrust.start_span(name="test"): - try: - client.messages.create( - model="claude-3-5-haiku-latest", - max_tokens=100, - messages=[{"role": "user", "content": "hi"}], - ) - except Exception: - pass - - spans = memory_logger.pop() - assert len(spans) >= 1, f"Expected spans, got {spans}" +@pytest.mark.vcr +def test_setup_creates_spans(memory_logger): + """`AnthropicIntegration.setup()` should create spans when making API calls.""" + AnthropicIntegration.setup() + + client = anthropic.Anthropic() + client.messages.create( + model=MODEL, + max_tokens=100, + messages=[{"role": "user", "content": "hi"}], + ) + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == MODEL + assert span["metadata"]["provider"] == "anthropic" def _make_batch_requests(): diff --git a/py/src/braintrust/integrations/anthropic/tracing.py b/py/src/braintrust/integrations/anthropic/tracing.py index 8f0e09b7..9f5e737e 100644 --- a/py/src/braintrust/integrations/anthropic/tracing.py +++ b/py/src/braintrust/integrations/anthropic/tracing.py @@ -467,8 +467,18 @@ def _catch_exceptions(): log.warning("swallowing exception in tracing code", exc_info=e) +_BRAINTRUST_TRACED = "__braintrust_traced__" + + def _wrap_anthropic(client): - """Wrap an Anthropic object (or AsyncAnthropic) to add tracing.""" + """Wrap an Anthropic object (or AsyncAnthropic) to add tracing. + + If the client is already traced (e.g. via ``AnthropicIntegration.setup()``), + it is returned unchanged to avoid double-wrapping. + """ + if getattr(client, _BRAINTRUST_TRACED, False): + return client + type_name = getattr(type(client), "__name__") if "AsyncAnthropic" in type_name: return TracedAsyncAnthropic(client) @@ -482,17 +492,23 @@ def _wrap_anthropic(client): def _apply_anthropic_wrapper(client): + if getattr(client, _BRAINTRUST_TRACED, False): + return wrapped = _wrap_anthropic(client) client.messages = wrapped.messages if hasattr(wrapped, "beta"): client.beta = wrapped.beta + setattr(client, _BRAINTRUST_TRACED, True) def _apply_async_anthropic_wrapper(client): + if getattr(client, _BRAINTRUST_TRACED, False): + return wrapped = _wrap_anthropic(client) client.messages = wrapped.messages if hasattr(wrapped, "beta"): client.beta = wrapped.beta + setattr(client, _BRAINTRUST_TRACED, True) def _anthropic_init_wrapper(wrapped, instance, args, kwargs):