Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/cohere/client_v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
from .client import Client, AsyncClient
from .v2.client import V2Client, AsyncV2Client
import typing
from .environment import ClientEnvironment
import os
import httpx
import typing
from concurrent.futures import ThreadPoolExecutor

import httpx
from .client import AsyncClient, Client
from .environment import ClientEnvironment
from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client


class _CombinedRawClient:
"""Proxy that combines v1 and v2 raw clients.
V2Client and Client both assign to self._raw_client in __init__,
causing a collision when combined in ClientV2/AsyncClientV2.
This proxy delegates to v2 first, falling back to v1 for
legacy methods like generate_stream.
"""

def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any):
self._v1 = v1_raw_client
self._v2 = v2_raw_client

def __getattr__(self, name: str) -> typing.Any:
try:
return getattr(self._v2, name)
except AttributeError:
return getattr(self._v1, name)


class ClientV2(V2Client, Client): # type: ignore
def __init__(
Expand All @@ -32,10 +53,12 @@ def __init__(
thread_pool_executor=thread_pool_executor,
log_warning_experimental_features=log_warning_experimental_features,
)
v1_raw = self._raw_client
V2Client.__init__(
self,
client_wrapper=self._client_wrapper
)
self._raw_client = typing.cast(RawV2Client, _CombinedRawClient(v1_raw, self._raw_client))


class AsyncClientV2(AsyncV2Client, AsyncClient): # type: ignore
Expand Down Expand Up @@ -63,7 +86,9 @@ def __init__(
thread_pool_executor=thread_pool_executor,
log_warning_experimental_features=log_warning_experimental_features,
)
v1_raw = self._raw_client
AsyncV2Client.__init__(
self,
client_wrapper=self._client_wrapper
)
self._raw_client = typing.cast(AsyncRawV2Client, _CombinedRawClient(v1_raw, self._raw_client))
6 changes: 6 additions & 0 deletions tests/test_client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def test_chat_stream(self) -> None:
self.assertTrue("content-delta" in events)
self.assertTrue("content-end" in events)
self.assertTrue("message-end" in events)

def test_legacy_methods_available(self) -> None:
self.assertTrue(hasattr(co, "generate"))
self.assertTrue(callable(getattr(co, "generate")))
self.assertTrue(hasattr(co, "generate_stream"))
self.assertTrue(callable(getattr(co, "generate_stream")))

@unittest.skip("Skip v2 test for now")
def test_chat_documents(self) -> None:
Expand Down