diff --git a/README.md b/README.md
index 98ffb61..5fdd5ad 100644
--- a/README.md
+++ b/README.md
@@ -492,6 +492,16 @@ result = await async_client.files.download(
 )
 ```
 
+For large async downloads, use `stream_download()` to process bytes as they arrive without buffering the full response in memory:
+
+```python
+async with async_client.files.stream_download(
+    url=await async_client.my_files_home() / "relative_folder/my-file.txt"
+) as result:
+    async for bytes_chunk in result:
+        ...
+```
+
 As a result, you will receive an object of type `FileDownloadResponse`, that you can iterate by byte chunks:
 
 ```python
diff --git a/aidial_client/_http_client/_async.py b/aidial_client/_http_client/_async.py
index 5dcf37c..f82c82f 100644
--- a/aidial_client/_http_client/_async.py
+++ b/aidial_client/_http_client/_async.py
@@ -120,6 +120,59 @@ async def request(
 
         return process_block_response(cast_to=cast_to, response=response)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        *,
+        options: FinalRequestOptions,
+        on_http_error: Optional[
+            Callable[[httpx.HTTPStatusError], Optional[DialException]]
+        ] = None,
+    ) -> AsyncIterator[httpx.Response]:
+        """Send a request and yield the open, unread streaming response.
+
+        The response is always closed when the context manager exits.
+
+        Args:
+            options: Request description used to build the HTTP request.
+            on_http_error: Optional hook mapping an HTTP status error to a
+                custom exception; a ``None`` result falls back to the
+                default conversion of the error response.
+
+        Raises:
+            DialException: On timeout, transport failure, or an error
+                status that the hook does not translate.
+        """
+        auth_headers = await self.auth_headers()
+        request = self._build_request(options, auth_headers)
+        try:
+            response = await self._internal_http_client.send(
+                request, stream=True
+            )
+        except httpx.TimeoutException as err:
+            raise DialException(
+                message="Request timed out",
+                status_code=HTTPStatus.REQUEST_TIMEOUT,
+            ) from err
+        except httpx.HTTPError as err:
+            raise DialException(message=f"Request failed: {err}") from err
+
+        try:
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as err:
+                # A streamed response has no body loaded yet; read it so the
+                # error handlers can inspect it without ResponseNotRead.
+                await err.response.aread()
+                custom_error = on_http_error(err) if on_http_error else None
+                raised_error = custom_error or self._make_dial_error_from_response(
+                    err.response
+                )
+                raise raised_error from err
+            yield response
+        finally:
+            await response.aclose()
+
     @asynccontextmanager
     async def stream_sse(
         self,
diff --git a/aidial_client/resources/files.py b/aidial_client/resources/files.py
index 47421aa..befc5ba 100644
--- a/aidial_client/resources/files.py
+++ b/aidial_client/resources/files.py
@@ -1,5 +1,6 @@
+from contextlib import asynccontextmanager
 from pathlib import PurePosixPath
-from typing import Literal, Optional, Union
+from typing import AsyncIterator, Literal, Optional, Union
 from urllib.parse import urljoin
 
 import httpx
@@ -209,6 +210,39 @@ async def download(
             response=response, filename=storage_resource.filename
         )
 
+    @asynccontextmanager
+    async def stream_download(
+        self,
+        url: Union[str, PurePosixPath],
+        etag_if_match: Optional[str] = None,
+    ) -> AsyncIterator[FileDownloadResponse]:
+        """Stream a file download without buffering it in memory.
+
+        Yields a ``FileDownloadResponse`` wrapping the open HTTP response;
+        the response is closed when the ``async with`` block exits.
+
+        Raises:
+            InvalidDialURLError: If ``url`` points to a directory.
+        """
+        storage_resource = self.get_storage_resource(str(url))
+        if storage_resource.filename is None:
+            raise InvalidDialURLError("URL points to a directory, not a file")
+        async with self.http_client.stream(
+            options=FinalRequestOptions(
+                method="GET",
+                url=urljoin(API_PREFIX, storage_resource.api_path),
+                headers=remove_none(
+                    {
+                        "If-Match": etag_if_match,
+                    }
+                ),
+            ),
+            on_http_error=_files_error_processor,
+        ) as response:
+            yield FileDownloadResponse(
+                response=response, filename=storage_resource.filename
+            )
+
     async def delete(
         self,
         url: Union[str, PurePosixPath],
diff --git a/tests/resources/files/test_download.py b/tests/resources/files/test_download.py
new file mode 100644
index 0000000..34d1e8d
--- /dev/null
+++ b/tests/resources/files/test_download.py
@@ -0,0 +1,61 @@
+from typing import Any, Dict, List, cast
+from unittest.mock import AsyncMock
+
+import httpx
+import pytest
+
+from aidial_client._client import AsyncDial
+from aidial_client._exception import InvalidDialURLError
+from tests.client_mock import MockStreamIterator
+
+
+@pytest.mark.asyncio
+async def test_stream_download_async_streams_and_closes_response():
+    captured_requests: List[httpx.Request] = []
+    captured_kwargs: List[Dict[str, Any]] = []
+    captured_responses: List[httpx.Response] = []
+    client = AsyncDial(api_key="dummy", base_url="http://dial.core")
+    client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))
+
+    async def send_mock(
+        request: httpx.Request, *, stream: bool = False, **kwargs: Any
+    ) -> httpx.Response:
+        captured_requests.append(request)
+        captured_kwargs.append({"stream": stream, **kwargs})
+        response = httpx.Response(
+            status_code=200,
+            request=request,
+            stream=MockStreamIterator(mock_chunks=[b"hello ", b"world"]),
+        )
+        captured_responses.append(response)
+        return response
+
+    client._http_client._internal_http_client.send = cast(Any, send_mock)
+
+    async with client.files.stream_download(
+        url=await client.my_files_home() / "folder/file.txt"
+    ) as response:
+        assert response.filename == "file.txt"
+        assert b"".join([chunk async for chunk in response]) == b"hello world"
+
+    assert captured_requests[0].url.path == "/v1/files/test-bucket/folder/file.txt"
+    assert captured_kwargs == [{"stream": True}]
+    assert captured_responses[0].is_closed is True
+
+
+@pytest.mark.asyncio
+async def test_stream_download_async_rejects_directory_url():
+    client = AsyncDial(api_key="dummy", base_url="http://dial.core")
+    client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))
+    send_mock = AsyncMock()
+    client._http_client._internal_http_client.send = cast(Any, send_mock)
+
+    with pytest.raises(
+        InvalidDialURLError, match="URL points to a directory, not a file"
+    ):
+        async with client.files.stream_download(
+            url="files/test-bucket/folder/"
+        ):
+            pass
+
+    send_mock.assert_not_called()