Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,16 @@ result = await async_client.files.download(
)
```

For large async downloads, use `stream_download()` to process bytes as they arrive without buffering the full response in memory:

```python
async with async_client.files.stream_download(
url=await async_client.my_files_home() / "relative_folder/my-file.txt"
) as result:
async for bytes_chunk in result:
...
```

As a result, you will receive an object of type `FileDownloadResponse`, which you can iterate over in byte chunks:

```python
Expand Down
36 changes: 36 additions & 0 deletions aidial_client/_http_client/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,42 @@ async def request(

return process_block_response(cast_to=cast_to, response=response)

@asynccontextmanager
async def stream(
    self,
    *,
    options: FinalRequestOptions,
    on_http_error: Optional[
        Callable[[httpx.HTTPStatusError], Optional[DialException]]
    ] = None,
) -> AsyncIterator[httpx.Response]:
    """Send a request in streaming mode and yield the open response.

    The request is sent with ``stream=True``, so the body is not loaded
    up front; the caller consumes it inside the ``async with`` block. The
    response is always closed when the block exits.

    Args:
        options: Request description passed to ``_build_request``.
        on_http_error: Optional hook that receives the
            ``httpx.HTTPStatusError`` and may return a ``DialException`` to
            raise instead of the default one. A falsy return value falls
            back to ``_make_dial_error_from_response``.

    Yields:
        The ``httpx.Response`` whose status check has passed.

    Raises:
        DialException: If sending times out (with status
            ``HTTPStatus.REQUEST_TIMEOUT``), if sending fails with any other
            ``httpx.HTTPError``, or if ``raise_for_status()`` reports an
            error status.
    """
    auth_headers = await self.auth_headers()
    request = self._build_request(options, auth_headers)
    try:
        response = await self._internal_http_client.send(
            request, stream=True
        )
    except httpx.TimeoutException as err:
        raise DialException(
            message="Request timed out",
            status_code=HTTPStatus.REQUEST_TIMEOUT,
        ) from err
    except httpx.HTTPError as err:
        raise DialException(message=f"Request failed: {err}") from err

    try:
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            # In streaming mode the body has not been fetched yet. Load it
            # before handing the response to the error processors so that
            # any access to its content cannot fail with
            # httpx.ResponseNotRead and mask the real HTTP error.
            await err.response.aread()
            custom_error = on_http_error(err) if on_http_error else None
            raised_error = custom_error or self._make_dial_error_from_response(
                err.response
            )
            raise raised_error from err
        yield response
    finally:
        # Runs on success, on error and on early exit from the caller's
        # block, so the connection is always released.
        await response.aclose()

@asynccontextmanager
async def stream_sse(
self,
Expand Down
28 changes: 27 additions & 1 deletion aidial_client/resources/files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import PurePosixPath
from typing import Literal, Optional, Union
from contextlib import asynccontextmanager
from typing import AsyncIterator, Literal, Optional, Union
from urllib.parse import urljoin

import httpx
Expand Down Expand Up @@ -209,6 +210,31 @@ async def download(
response=response, filename=storage_resource.filename
)

@asynccontextmanager
async def stream_download(
    self,
    url: Union[str, PurePosixPath],
    etag_if_match: Optional[str] = None,
) -> AsyncIterator[FileDownloadResponse]:
    """Open a streaming download for a single file.

    Args:
        url: Location of the file; converted with ``str()`` and resolved
            through ``get_storage_resource``.
        etag_if_match: Optional value for the ``If-Match`` request header.
            The header is omitted when this is ``None``.

    Yields:
        A ``FileDownloadResponse`` wrapping the streamed HTTP response and
        the resolved filename.

    Raises:
        InvalidDialURLError: If the resolved resource has no filename.
    """
    resource = self.get_storage_resource(str(url))
    if resource.filename is None:
        raise InvalidDialURLError("URL points to a directory, not a file")

    request_options = FinalRequestOptions(
        method="GET",
        url=urljoin(API_PREFIX, resource.api_path),
        headers=remove_none({"If-Match": etag_if_match}),
    )
    streaming_request = self.http_client.stream(
        options=request_options,
        on_http_error=_files_error_processor,
    )
    async with streaming_request as raw_response:
        yield FileDownloadResponse(
            response=raw_response, filename=resource.filename
        )

async def delete(
self,
url: Union[str, PurePosixPath],
Expand Down
61 changes: 61 additions & 0 deletions tests/resources/files/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any, Dict, List, cast
from unittest.mock import AsyncMock

import httpx
import pytest

from aidial_client._client import AsyncDial
from aidial_client._exception import InvalidDialURLError
from tests.client_mock import MockStreamIterator


@pytest.mark.asyncio
async def test_stream_download_async_streams_and_closes_response():
    """Streamed download yields all chunks, sends stream=True, then closes."""
    client = AsyncDial(api_key="dummy", base_url="http://dial.core")
    client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))

    sent_requests: List[httpx.Request] = []
    send_options: List[Dict[str, Any]] = []
    produced_responses: List[httpx.Response] = []

    async def fake_send(
        request: httpx.Request, *, stream: bool = False, **kwargs: Any
    ) -> httpx.Response:
        # Record exactly what the client handed to the transport.
        sent_requests.append(request)
        send_options.append({"stream": stream, **kwargs})
        body = MockStreamIterator(mock_chunks=[b"hello ", b"world"])
        streamed = httpx.Response(
            status_code=200,
            request=request,
            stream=body,
        )
        produced_responses.append(streamed)
        return streamed

    client._http_client._internal_http_client.send = cast(Any, fake_send)

    file_url = await client.my_files_home() / "folder/file.txt"
    async with client.files.stream_download(url=file_url) as download:
        assert download.filename == "file.txt"
        received = [chunk async for chunk in download]
        assert b"".join(received) == b"hello world"

    first_request = sent_requests[0]
    assert first_request.url.path == "/v1/files/test-bucket/folder/file.txt"
    assert send_options == [{"stream": True}]
    # Leaving the context manager must have closed the response.
    assert produced_responses[0].is_closed is True


@pytest.mark.asyncio
async def test_stream_download_async_rejects_directory_url():
    """A directory URL raises InvalidDialURLError before any request is sent."""
    client = AsyncDial(api_key="dummy", base_url="http://dial.core")
    client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))
    transport_send = AsyncMock()
    client._http_client._internal_http_client.send = cast(Any, transport_send)

    expected_error = pytest.raises(
        InvalidDialURLError, match="URL points to a directory, not a file"
    )
    with expected_error:
        directory_download = client.files.stream_download(
            url="files/test-bucket/folder/"
        )
        async with directory_download:
            pass

    # The URL check must fail before the transport is touched.
    transport_send.assert_not_called()