From e73119effa86a62a682e04dfc4bc86720424c3f8 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Sat, 30 May 2026 15:41:07 +0600 Subject: [PATCH] fix(agent): ensure fallback clients honor Vertex AI selection --- agentflow/core/graph/agent.py | 3 +++ .../core/graph/agent_internal/execution.py | 18 ++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/agentflow/core/graph/agent.py b/agentflow/core/graph/agent.py index 5e27351..8a6ab09 100644 --- a/agentflow/core/graph/agent.py +++ b/agentflow/core/graph/agent.py @@ -252,6 +252,9 @@ class MyState(AgentState): ).lower() == "true", ) # legacy alias for provider="google" + # Persist so fallback clients (created lazily at call time) honour the + # same Vertex AI selection as the primary client. + self.use_vertex_ai = use_vertex_ai # Call parent constructor super().__init__( model=model, diff --git a/agentflow/core/graph/agent_internal/execution.py b/agentflow/core/graph/agent_internal/execution.py index 33f15ff..b48a751 100644 --- a/agentflow/core/graph/agent_internal/execution.py +++ b/agentflow/core/graph/agent_internal/execution.py @@ -158,6 +158,10 @@ def _extract_response_text(response: Any) -> str: class AgentExecutionMixin: """Execution flow, tool resolution, and provider dispatch helpers.""" + # Set by ``Agent.__init__``; declared here so the mixin can read it when + # lazily building fallback clients (the mixin never assigns it itself). + use_vertex_ai: bool + def _setup_tools(self) -> ToolNode | None: """Normalize the tool_node input and wire internal state. @@ -276,10 +280,10 @@ async def _call_llm_with_retry( # noqa: PLR0912 # Build the ordered attempt list: primary + fallbacks attempts: list[tuple[str, str, Any, str | None]] = [ - (self.model, self.provider, self.client, getattr(self, "base_url", None)), + (self.model, self.provider, self.client, self.base_url), ] for fb_model, fb_provider in fallback_models: - attempts.append((fb_model, fb_provider or self.provider, None, None)) + attempts.append((fb_model, fb_provider or self.provider, None, self.base_url)) last_exc: Exception | None = None @@ -301,14 +305,20 @@ async def _call_llm_with_retry( # noqa: PLR0912 self.model, self.provider, self.client, - getattr(self, "base_url", None), + self.base_url, ) self.model = model self.provider = provider self.base_url = base_url active_client = fallback_client if active_client is None: - active_client = self._create_client(provider, base_url) + # Lazily build the fallback client, honouring the + # agent's Vertex AI selection (only affects google). + active_client = self._create_client( + provider, + base_url, + self.use_vertex_ai, + ) self.client = active_client try: result = await self._call_llm(messages, tools, stream, **kwargs)