From 59f93b8fb8850a43f914e5d2786ecb989bd42174 Mon Sep 17 00:00:00 2001 From: Danylo_Kriachkov Date: Fri, 12 Jun 2026 13:22:26 +0300 Subject: [PATCH] feat: add lifecycle management support for synchronous and asynchronous clients --- README.md | 42 +++++++++++++++- aidial_client/_client.py | 31 +++++++++++- aidial_client/_client_pool.py | 31 +++++++++++- tests/test_lifecycle.py | 94 +++++++++++++++++++++++++++++++++++ 4 files changed, 195 insertions(+), 3 deletions(-) create mode 100644 tests/test_lifecycle.py diff --git a/README.md b/README.md index 98ffb61..d172b99 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - [Authentication](#authentication) - [API Keys](#api-keys) - [Bearer Token](#bearer-token) + - [Lifecycle Management](#lifecycle-management) - [Deployments](#deployments) - [List Deployments](#list-deployments) - [Get Deployment by Id](#get-deployment-by-id) @@ -121,6 +122,45 @@ async_client = AsyncDial( ) ``` +### Lifecycle Management + +For deterministic shutdown of underlying HTTP clients, both client types and +client pools expose lifecycle APIs. + +```python +from aidial_client import AsyncDial, AsyncDialClientPool, Dial, DialClientPool + +# Sync client +with Dial(api_key="your_api_key", base_url="https://your-dial-instance.com") as client: + ... + +client = Dial(api_key="your_api_key", base_url="https://your-dial-instance.com") +client.close() + +# Async client +async with AsyncDial( + api_key="your_api_key", base_url="https://your-dial-instance.com" +) as async_client: + ... + +async_client = AsyncDial( + api_key="your_api_key", base_url="https://your-dial-instance.com" +) +await async_client.aclose() + +# Sync pool +with DialClientPool() as pool: + pooled_client = pool.create_client( + base_url="https://your-dial-instance.com", api_key="your-api-key" + ) + +# Async pool +async with AsyncDialClientPool() as async_pool: + pooled_async_client = async_pool.create_client( + base_url="https://your-dial-instance.com", api_key="your-api-key" + ) +``` + You can also pass `bearer_token` as a function without parameters, that returns a `string`: ```python @@ -928,7 +968,7 @@ second_client = client_pool.create_client( #### Asynchronous Client Pool ```python -from dial_client import ( +from aidial_client import ( AsyncDialClientPool, ) diff --git a/aidial_client/_client.py b/aidial_client/_client.py index 5c4907d..9617f79 100644 --- a/aidial_client/_client.py +++ b/aidial_client/_client.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from pathlib import PurePosixPath -from typing import Dict, Generic, Optional, TypeVar, Union +from types import TracebackType +from typing import Dict, Generic, Optional, Type, TypeVar, Union from urllib.parse import urljoin import openai @@ -167,6 +168,20 @@ def my_appdata_home(self) -> Optional[PurePosixPath]: def auth_headers(self) -> Dict[str, str]: return self._http_client.auth_headers() + def close(self) -> None: + self._http_client.internal_http_client.close() + + def __enter__(self) -> "Dial": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + class AsyncDial(BaseDialClient[AsyncHTTPClient, AsyncAuthValue]): @@ -257,3 +272,17 @@ async def my_appdata_home(self) -> Optional[PurePosixPath]: async def auth_headers(self) -> Dict[str, str]: return await self._http_client.auth_headers() + + async def aclose(self) -> None: + await self._http_client.internal_http_client.aclose() + + async def __aenter__(self) -> "AsyncDial": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.aclose() diff --git a/aidial_client/_client_pool.py b/aidial_client/_client_pool.py index c2b9837..6420d92 100644 --- a/aidial_client/_client_pool.py +++ b/aidial_client/_client_pool.py @@ -1,4 +1,5 @@ -from typing import Optional, Union +from types import TracebackType +from typing import Optional, Type, Union import httpx @@ -46,6 +47,20 @@ def create_client( ), ) + def close(self) -> None: + self._internal_http_client.close() + + def __enter__(self) -> "DialClientPool": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + class AsyncDialClientPool: def __init__( @@ -80,3 +95,17 @@ def create_client( internal_http_client=self._internal_http_client, ), ) + + async def aclose(self) -> None: + await self._internal_http_client.aclose() + + async def __aenter__(self) -> "AsyncDialClientPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.aclose() diff --git a/tests/test_lifecycle.py b/tests/test_lifecycle.py new file mode 100644 index 0000000..059ac42 --- /dev/null +++ b/tests/test_lifecycle.py @@ -0,0 +1,94 @@ +from unittest.mock import patch + +import pytest + +from aidial_client import AsyncDialClientPool, Dial, DialClientPool +from aidial_client._client import AsyncDial + + +def test_dial_close(): + client = Dial(api_key="dummy", base_url="http://dial.core") + + with patch.object( + client._http_client.internal_http_client, "close" + ) as close_mock: + client.close() + + close_mock.assert_called_once() + + +def test_dial_context_manager(): + client = Dial(api_key="dummy", base_url="http://dial.core") + + with patch.object( + client._http_client.internal_http_client, "close" + ) as close_mock: + with client as managed_client: + assert managed_client is client + + close_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_dial_aclose(): + client = AsyncDial(api_key="dummy", base_url="http://dial.core") + + with patch.object( + client._http_client.internal_http_client, "aclose" + ) as aclose_mock: + await client.aclose() + + aclose_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_dial_context_manager(): + client = AsyncDial(api_key="dummy", base_url="http://dial.core") + + with patch.object( + client._http_client.internal_http_client, "aclose" + ) as aclose_mock: + async with client as managed_client: + assert managed_client is client + + aclose_mock.assert_awaited_once() + + +def test_dial_client_pool_close(): + pool = DialClientPool() + + with patch.object(pool._internal_http_client, "close") as close_mock: + pool.close() + + close_mock.assert_called_once() + + +def test_dial_client_pool_context_manager(): + pool = DialClientPool() + + with patch.object(pool._internal_http_client, "close") as close_mock: + with pool as managed_pool: + assert managed_pool is pool + + close_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_dial_client_pool_aclose(): + pool = AsyncDialClientPool() + + with patch.object(pool._internal_http_client, "aclose") as aclose_mock: + await pool.aclose() + + aclose_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_dial_client_pool_context_manager(): + pool = AsyncDialClientPool() + + with patch.object(pool._internal_http_client, "aclose") as aclose_mock: + async with pool as managed_pool: + assert managed_pool is pool + + aclose_mock.assert_awaited_once()