diff --git a/src/cohere/client_v2.py b/src/cohere/client_v2.py index 9f4f7f22c..fa28ea3a1 100644 --- a/src/cohere/client_v2.py +++ b/src/cohere/client_v2.py @@ -1,11 +1,32 @@ -from .client import Client, AsyncClient -from .v2.client import V2Client, AsyncV2Client -import typing -from .environment import ClientEnvironment import os -import httpx +import typing from concurrent.futures import ThreadPoolExecutor +import httpx +from .client import AsyncClient, Client +from .environment import ClientEnvironment +from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client + + +class _CombinedRawClient: + """Proxy that combines v1 and v2 raw clients. + + V2Client and Client both assign to self._raw_client in __init__, + causing a collision when combined in ClientV2/AsyncClientV2. + This proxy delegates to v2 first, falling back to v1 for + legacy methods like generate_stream. + """ + + def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any): + self._v1 = v1_raw_client + self._v2 = v2_raw_client + + def __getattr__(self, name: str) -> typing.Any: + try: + return getattr(self._v2, name) + except AttributeError: + return getattr(self._v1, name) + class ClientV2(V2Client, Client): # type: ignore def __init__( @@ -32,10 +53,12 @@ def __init__( thread_pool_executor=thread_pool_executor, log_warning_experimental_features=log_warning_experimental_features, ) + v1_raw = self._raw_client V2Client.__init__( self, client_wrapper=self._client_wrapper ) + self._raw_client = typing.cast(RawV2Client, _CombinedRawClient(v1_raw, self._raw_client)) class AsyncClientV2(AsyncV2Client, AsyncClient): # type: ignore @@ -63,7 +86,9 @@ def __init__( thread_pool_executor=thread_pool_executor, log_warning_experimental_features=log_warning_experimental_features, ) + v1_raw = self._raw_client AsyncV2Client.__init__( self, client_wrapper=self._client_wrapper ) + self._raw_client = typing.cast(AsyncRawV2Client, _CombinedRawClient(v1_raw, self._raw_client)) diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index e670b1360..fadd0a5eb 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -36,6 +36,12 @@ def test_chat_stream(self) -> None: self.assertTrue("content-delta" in events) self.assertTrue("content-end" in events) self.assertTrue("message-end" in events) + + def test_legacy_methods_available(self) -> None: + self.assertTrue(hasattr(co, "generate")) + self.assertTrue(callable(getattr(co, "generate"))) + self.assertTrue(hasattr(co, "generate_stream")) + self.assertTrue(callable(getattr(co, "generate_stream"))) @unittest.skip("Skip v2 test for now") def test_chat_documents(self) -> None: