Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- [Authentication](#authentication)
- [API Keys](#api-keys)
- [Bearer Token](#bearer-token)
- [Lifecycle Management](#lifecycle-management)
- [Deployments](#deployments)
- [List Deployments](#list-deployments)
- [Get Deployment by Id](#get-deployment-by-id)
Expand Down Expand Up @@ -121,6 +122,45 @@ async_client = AsyncDial(
)
```

### Lifecycle Management

For deterministic shutdown of underlying HTTP clients, both client types and
client pools expose lifecycle APIs.

```python
from aidial_client import AsyncDial, AsyncDialClientPool, Dial, DialClientPool

# Sync client
with Dial(api_key="your_api_key", base_url="https://your-dial-instance.com") as client:
...

client = Dial(api_key="your_api_key", base_url="https://your-dial-instance.com")
client.close()

# Async client
async with AsyncDial(
api_key="your_api_key", base_url="https://your-dial-instance.com"
) as async_client:
...

async_client = AsyncDial(
api_key="your_api_key", base_url="https://your-dial-instance.com"
)
await async_client.aclose()

# Sync pool
with DialClientPool() as pool:
pooled_client = pool.create_client(
base_url="https://your-dial-instance.com", api_key="your-api-key"
)

# Async pool
async with AsyncDialClientPool() as async_pool:
pooled_async_client = async_pool.create_client(
base_url="https://your-dial-instance.com", api_key="your-api-key"
)
```

You can also pass `bearer_token` as a function without parameters, that returns a `string`:

```python
Expand Down Expand Up @@ -928,7 +968,7 @@ second_client = client_pool.create_client(
#### Asynchronous Client Pool

```python
from dial_client import (
from aidial_client import (
AsyncDialClientPool,
)

Expand Down
31 changes: 30 additions & 1 deletion aidial_client/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from pathlib import PurePosixPath
from typing import Dict, Generic, Optional, TypeVar, Union
from types import TracebackType
from typing import Dict, Generic, Optional, Type, TypeVar, Union
from urllib.parse import urljoin

import openai
Expand Down Expand Up @@ -167,6 +168,20 @@ def my_appdata_home(self) -> Optional[PurePosixPath]:
def auth_headers(self) -> Dict[str, str]:
return self._http_client.auth_headers()

def close(self) -> None:
self._http_client.internal_http_client.close()

def __enter__(self) -> "Dial":
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()


class AsyncDial(BaseDialClient[AsyncHTTPClient, AsyncAuthValue]):

Expand Down Expand Up @@ -257,3 +272,17 @@ async def my_appdata_home(self) -> Optional[PurePosixPath]:

async def auth_headers(self) -> Dict[str, str]:
return await self._http_client.auth_headers()

async def aclose(self) -> None:
await self._http_client.internal_http_client.aclose()

async def __aenter__(self) -> "AsyncDial":
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.aclose()
31 changes: 30 additions & 1 deletion aidial_client/_client_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Union
from types import TracebackType
from typing import Optional, Type, Union

import httpx

Expand Down Expand Up @@ -46,6 +47,20 @@ def create_client(
),
)

def close(self) -> None:
self._internal_http_client.close()

def __enter__(self) -> "DialClientPool":
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()


class AsyncDialClientPool:
def __init__(
Expand Down Expand Up @@ -80,3 +95,17 @@ def create_client(
internal_http_client=self._internal_http_client,
),
)

async def aclose(self) -> None:
await self._internal_http_client.aclose()

async def __aenter__(self) -> "AsyncDialClientPool":
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.aclose()
94 changes: 94 additions & 0 deletions tests/test_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from unittest.mock import patch

import pytest

from aidial_client import AsyncDialClientPool, Dial, DialClientPool
from aidial_client._client import AsyncDial


def test_dial_close():
client = Dial(api_key="dummy", base_url="http://dial.core")

with patch.object(
client._http_client.internal_http_client, "close"
) as close_mock:
client.close()

close_mock.assert_called_once()


def test_dial_context_manager():
client = Dial(api_key="dummy", base_url="http://dial.core")

with patch.object(
client._http_client.internal_http_client, "close"
) as close_mock:
with client as managed_client:
assert managed_client is client

close_mock.assert_called_once()


@pytest.mark.asyncio
async def test_async_dial_aclose():
client = AsyncDial(api_key="dummy", base_url="http://dial.core")

with patch.object(
client._http_client.internal_http_client, "aclose"
) as aclose_mock:
await client.aclose()

aclose_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_dial_context_manager():
client = AsyncDial(api_key="dummy", base_url="http://dial.core")

with patch.object(
client._http_client.internal_http_client, "aclose"
) as aclose_mock:
async with client as managed_client:
assert managed_client is client

aclose_mock.assert_awaited_once()


def test_dial_client_pool_close():
pool = DialClientPool()

with patch.object(pool._internal_http_client, "close") as close_mock:
pool.close()

close_mock.assert_called_once()


def test_dial_client_pool_context_manager():
pool = DialClientPool()

with patch.object(pool._internal_http_client, "close") as close_mock:
with pool as managed_pool:
assert managed_pool is pool

close_mock.assert_called_once()


@pytest.mark.asyncio
async def test_async_dial_client_pool_aclose():
pool = AsyncDialClientPool()

with patch.object(pool._internal_http_client, "aclose") as aclose_mock:
await pool.aclose()

aclose_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_dial_client_pool_context_manager():
pool = AsyncDialClientPool()

with patch.object(pool._internal_http_client, "aclose") as aclose_mock:
async with pool as managed_pool:
assert managed_pool is pool

aclose_mock.assert_awaited_once()