diff --git a/src/upbeat/__init__.py b/src/upbeat/__init__.py index d6e42d4..09a6baa 100644 --- a/src/upbeat/__init__.py +++ b/src/upbeat/__init__.py @@ -42,6 +42,7 @@ RemainingRequest, UnprocessableEntityError, UpbeatError, + ValidationError, WebSocketClosedError, WebSocketConnectionError, WebSocketError, @@ -137,6 +138,7 @@ "RemainingRequest", "UnprocessableEntityError", "UpbeatError", + "ValidationError", "WebSocketClosedError", "WebSocketConnectionError", "WebSocketError", diff --git a/src/upbeat/_client.py b/src/upbeat/_client.py index 59d7489..6a1d704 100644 --- a/src/upbeat/_client.py +++ b/src/upbeat/_client.py @@ -42,6 +42,7 @@ def __init__( logger: Logger | None = None, http_client: httpx.Client | None = None, event_hooks: dict[str, list[Any]] | None = None, + validate_min_order: bool = False, ) -> None: if (access_key is None) != (secret_key is None): raise ValueError( @@ -59,6 +60,7 @@ def __init__( self._max_retries = max_retries self._auto_throttle = auto_throttle self._logger = logger + self._validate_min_order = validate_min_order self._owns_http_client = http_client is None self._closed = False @@ -84,7 +86,11 @@ def accounts(self) -> AccountsAPI: @cached_property def orders(self) -> OrdersAPI: - return OrdersAPI(self._transport, self._credentials) + return OrdersAPI( + self._transport, + self._credentials, + validate_min_order=self._validate_min_order, + ) @cached_property def deposits(self) -> DepositsAPI: @@ -131,6 +137,7 @@ def with_options( max_retries: int | None = None, auto_throttle: bool | None = None, logger: Logger | None = None, + validate_min_order: bool | None = None, ) -> Upbeat: new = Upbeat.__new__(Upbeat) new._credentials = self._credentials @@ -141,6 +148,11 @@ def with_options( auto_throttle if auto_throttle is not None else self._auto_throttle ) new._logger = logger if logger is not None else self._logger + new._validate_min_order = ( + validate_min_order + if validate_min_order is not None + else self._validate_min_order + ) new._owns_http_client = False new._closed = False @@ -176,6 +188,7 @@ def __init__( logger: Logger | None = None, http_client: httpx.AsyncClient | None = None, event_hooks: dict[str, list[Any]] | None = None, + validate_min_order: bool = False, ) -> None: if (access_key is None) != (secret_key is None): raise ValueError( @@ -193,6 +206,7 @@ def __init__( self._max_retries = max_retries self._auto_throttle = auto_throttle self._logger = logger + self._validate_min_order = validate_min_order self._owns_http_client = http_client is None self._closed = False @@ -218,7 +232,11 @@ def accounts(self) -> AsyncAccountsAPI: @cached_property def orders(self) -> AsyncOrdersAPI: - return AsyncOrdersAPI(self._transport, self._credentials) + return AsyncOrdersAPI( + self._transport, + self._credentials, + validate_min_order=self._validate_min_order, + ) @cached_property def deposits(self) -> AsyncDepositsAPI: @@ -269,6 +287,7 @@ def with_options( max_retries: int | None = None, auto_throttle: bool | None = None, logger: Logger | None = None, + validate_min_order: bool | None = None, ) -> AsyncUpbeat: new = AsyncUpbeat.__new__(AsyncUpbeat) new._credentials = self._credentials @@ -279,6 +298,11 @@ def with_options( auto_throttle if auto_throttle is not None else self._auto_throttle ) new._logger = logger if logger is not None else self._logger + new._validate_min_order = ( + validate_min_order + if validate_min_order is not None + else self._validate_min_order + ) new._owns_http_client = False new._closed = False diff --git a/src/upbeat/_errors.py b/src/upbeat/_errors.py index 75d9591..c85e600 100644 --- a/src/upbeat/_errors.py +++ b/src/upbeat/_errors.py @@ -39,6 +39,30 @@ def __init__(self, message: str) -> None: self.message = message +# ── Validation errors ───────────────────────────────────────────────── + + +class ValidationError(UpbeatError): + """Raised when client-side validation catches an invalid order before sending.""" + + market: str + total: str + min_total: str + + def __init__( + self, + message: str, + *, + market: str, + total: str, + min_total: str, + ) -> None: + super().__init__(message) + self.market = market + self.total = total + self.min_total = min_total + + # ── API errors ─────────────────────────────────────────────────────────── diff --git a/src/upbeat/api/orders.py b/src/upbeat/api/orders.py index a700bfe..77ddd4c 100644 --- a/src/upbeat/api/orders.py +++ b/src/upbeat/api/orders.py @@ -1,8 +1,12 @@ from __future__ import annotations +from decimal import Decimal from typing import Any +from upbeat._auth import Credentials from upbeat._base import _AsyncAPIResource, _SyncAPIResource +from upbeat._errors import ValidationError +from upbeat._http import AsyncTransport, SyncTransport from upbeat.types.order import ( CancelAndNewOrderResponse, CancelResult, @@ -20,7 +24,54 @@ def _filter_params(**kwargs: Any) -> dict[str, Any]: return {k: v for k, v in kwargs.items() if v is not None} +def _compute_bid_total( + price: str | None, volume: str | None, ord_type: str +) -> Decimal | None: + """Return the total KRW value of a bid order, or None if indeterminate.""" + if price is None: + return None + if ord_type == "limit": + return Decimal(price) * Decimal(volume) if volume is not None else None + return Decimal(price) + + class OrdersAPI(_SyncAPIResource): + _validate_min_order: bool + + def __init__( + self, + transport: SyncTransport, + credentials: Credentials | None, + *, + validate_min_order: bool = False, + ) -> None: + super().__init__(transport, credentials) + self._validate_min_order = validate_min_order + + def _check_min_order( + self, + market: str, + side: str, + price: str | None, + volume: str | None, + ord_type: str, + ) -> None: + if not self._validate_min_order or side != "bid": + return + total = _compute_bid_total(price, volume, ord_type) + if total is None: + return + chance = self.get_chance(market=market) + if chance.market.bid is not None: + min_total = Decimal(chance.market.bid.min_total) + if total < min_total: + raise ValidationError( + f"Order total {total} is below minimum {min_total} for {market}", + market=market, + total=str(total), + min_total=chance.market.bid.min_total, + ) + def create( self, *, @@ -33,6 +84,7 @@ def create( time_in_force: str | None = None, smp_type: str | None = None, ) -> OrderCreated: + self._check_min_order(market, side, price, volume, ord_type) json_body = _filter_params( market=market, side=side, @@ -60,6 +112,7 @@ def create_test( time_in_force: str | None = None, smp_type: str | None = None, ) -> OrderCreated: + self._check_min_order(market, side, price, volume, ord_type) json_body = _filter_params( market=market, side=side, @@ -244,6 +297,42 @@ def get_chance(self, *, market: str) -> OrderChance: class AsyncOrdersAPI(_AsyncAPIResource): + _validate_min_order: bool + + def __init__( + self, + transport: AsyncTransport, + credentials: Credentials | None, + *, + validate_min_order: bool = False, + ) -> None: + super().__init__(transport, credentials) + self._validate_min_order = validate_min_order + + async def _check_min_order( + self, + market: str, + side: str, + price: str | None, + volume: str | None, + ord_type: str, + ) -> None: + if not self._validate_min_order or side != "bid": + return + total = _compute_bid_total(price, volume, ord_type) + if total is None: + return + chance = await self.get_chance(market=market) + if chance.market.bid is not None: + min_total = Decimal(chance.market.bid.min_total) + if total < min_total: + raise ValidationError( + f"Order total {total} is below minimum {min_total} for {market}", + market=market, + total=str(total), + min_total=chance.market.bid.min_total, + ) + async def create( self, *, @@ -256,6 +345,7 @@ async def create( time_in_force: str | None = None, smp_type: str | None = None, ) -> OrderCreated: + await self._check_min_order(market, side, price, volume, ord_type) json_body = _filter_params( market=market, side=side, @@ -283,6 +373,7 @@ async def create_test( time_in_force: str | None = None, smp_type: str | None = None, ) -> OrderCreated: + await self._check_min_order(market, side, price, volume, ord_type) json_body = _filter_params( market=market, side=side, diff --git a/tests/api/test_orders.py b/tests/api/test_orders.py index e0cc632..4961f2a 100644 --- a/tests/api/test_orders.py +++ b/tests/api/test_orders.py @@ -8,6 +8,7 @@ from upbeat._auth import Credentials from upbeat._constants import API_BASE_URL +from upbeat._errors import ValidationError from upbeat._http import AsyncTransport, SyncTransport from upbeat.api.orders import AsyncOrdersAPI, OrdersAPI from upbeat.types.order import ( @@ -583,3 +584,119 @@ async def handler(request: httpx.Request) -> httpx.Response: api = AsyncOrdersAPI(transport, CREDENTIALS) result = await api.get_chance(market="KRW-BTC") assert isinstance(result, OrderChance) + + +# ── TestMinOrderValidation ────────────────────────────────────────────── + + +def _multi_handler(request: httpx.Request) -> httpx.Response: + """Handle both /v1/orders/chance and /v1/orders endpoints.""" + if request.url.path == "/v1/orders/chance": + return _json_response(ORDER_CHANCE_DATA) + if request.url.path in ("/v1/orders", "/v1/orders/test"): + return _json_response(ORDER_CREATED_DATA, status_code=201) + return httpx.Response(404) + + +async def _async_multi_handler(request: httpx.Request) -> httpx.Response: + return _multi_handler(request) + + +class TestMinOrderValidation: + def test_validation_disabled_by_default(self) -> None: + called_chance = False + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal called_chance + if request.url.path == "/v1/orders/chance": + called_chance = True + return _multi_handler(request) + + transport = _make_transport(handler) + api = OrdersAPI(transport, CREDENTIALS) + api.create(market="KRW-BTC", side="bid", ord_type="price", price="3000") + assert not called_chance + + def test_validation_raises_for_low_bid_market_order(self) -> None: + transport = _make_transport(_multi_handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + with pytest.raises(ValidationError) as exc_info: + api.create(market="KRW-BTC", side="bid", ord_type="price", price="3000") + assert exc_info.value.market == "KRW-BTC" + assert exc_info.value.total == "3000" + assert exc_info.value.min_total == "5000" + + def test_validation_raises_for_low_bid_limit_order(self) -> None: + transport = _make_transport(_multi_handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + with pytest.raises(ValidationError): + api.create( + market="KRW-BTC", side="bid", ord_type="limit", + price="1000", volume="3", + ) + + def test_validation_passes_for_sufficient_bid(self) -> None: + transport = _make_transport(_multi_handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + result = api.create( + market="KRW-BTC", side="bid", ord_type="price", price="6000" + ) + assert isinstance(result, OrderCreated) + + def test_validation_skips_ask_orders(self) -> None: + called_chance = False + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal called_chance + if request.url.path == "/v1/orders/chance": + called_chance = True + return _multi_handler(request) + + transport = _make_transport(handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + api.create( + market="KRW-BTC", side="ask", ord_type="market", volume="0.001" + ) + assert not called_chance + + def test_validation_skips_when_price_none(self) -> None: + called_chance = False + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal called_chance + if request.url.path == "/v1/orders/chance": + called_chance = True + return _multi_handler(request) + + transport = _make_transport(handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + api.create(market="KRW-BTC", side="bid", ord_type="market", volume="0.001") + assert not called_chance + + def test_validation_on_create_test(self) -> None: + transport = _make_transport(_multi_handler) + api = OrdersAPI(transport, CREDENTIALS, validate_min_order=True) + with pytest.raises(ValidationError): + api.create_test( + market="KRW-BTC", side="bid", ord_type="price", price="3000" + ) + + @pytest.mark.asyncio + async def test_async_validation_raises(self) -> None: + transport = _make_async_transport(_async_multi_handler) + api = AsyncOrdersAPI(transport, CREDENTIALS, validate_min_order=True) + with pytest.raises(ValidationError) as exc_info: + await api.create( + market="KRW-BTC", side="bid", ord_type="price", price="3000" + ) + assert exc_info.value.market == "KRW-BTC" + assert exc_info.value.min_total == "5000" + + @pytest.mark.asyncio + async def test_async_validation_passes(self) -> None: + transport = _make_async_transport(_async_multi_handler) + api = AsyncOrdersAPI(transport, CREDENTIALS, validate_min_order=True) + result = await api.create( + market="KRW-BTC", side="bid", ord_type="price", price="6000" + ) + assert isinstance(result, OrderCreated)