From 5e4eabbdac51df3ee0ad7d56fbbcb20b60f4a91f Mon Sep 17 00:00:00 2001 From: hozan23 Date: Wed, 28 Jun 2023 21:57:07 +0300 Subject: [PATCH] WIP: major clean ups for the codebase --- deta/__init__.py | 19 ++++++----- deta/_async/client.py | 44 +++++++++++++------------ deta/base.py | 75 ++++++++++++++++++++++++------------------- deta/drive.py | 57 ++++++++++++++++++++------------ deta/service.py | 60 +++++++++++++++++++++------------- deta/utils.py | 6 ++-- 6 files changed, 154 insertions(+), 107 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index 1eec891..84323a9 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -2,6 +2,7 @@ import urllib.error import urllib.request import json +from typing import Union from .base import _Base from .drive import _Drive @@ -9,21 +10,20 @@ try: - from detalib.app import App + from detalib.app import App # pyright: ignore app = App() except Exception: pass try: - from ._async.client import AsyncBase + from ._async.client import AsyncBase # pyright: ignore except ImportError: pass __version__ = "1.1.0" - def Base(name: str): project_key, project_id = _get_project_key_id() return _Base(name, project_key, project_id) @@ -35,19 +35,19 @@ def Drive(name: str): class Deta: - def __init__(self, project_key: str = None, *, project_id: str = None): + def __init__(self, project_key: Union[str, None] = None, *, project_id: Union[str, None] = None): project_key, project_id = _get_project_key_id(project_key, project_id) self.project_key = project_key self.project_id = project_id - def Base(self, name: str, host: str = None): + def Base(self, name: str, host: Union[str, None] = None): return _Base(name, self.project_key, self.project_id, host) - def AsyncBase(self, name: str, host: str = None): + def AsyncBase(self, name: str, host: Union[str, None] = None): from ._async.client import _AsyncBase return _AsyncBase(name, self.project_key, self.project_id, host) - def Drive(self, name: str, host: str = None): + def Drive(self, name: str, host: Union[str, None] = None): return _Drive( name=name, project_key=self.project_key, @@ -73,9 +73,12 @@ def send_email(to, subject, message, charset="UTF-8"): "charset": charset, } + assert api_key + headers = {"X-API-Key": api_key} - req = urllib.request.Request(endpoint, json.dumps(data).encode("utf-8"), headers) + req = urllib.request.Request( + endpoint, json.dumps(data).encode("utf-8"), headers) try: resp = urllib.request.urlopen(req) diff --git a/deta/_async/client.py b/deta/_async/client.py index f183253..8d4fab4 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -1,10 +1,10 @@ -import typing - +from typing import Union, List import datetime import os -import aiohttp from urllib.parse import quote +import aiohttp + from deta.utils import _get_project_key_id from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTTRIBUTE @@ -15,7 +15,7 @@ def AsyncBase(name: str): class _AsyncBase: - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Union[str, None] = None): if not project_key: raise AssertionError("No Base name provided") @@ -56,11 +56,11 @@ async def delete(self, key: str): async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -70,7 +70,8 @@ async def insert( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) async with self._session.post( f"{self._base_url}/items", json={"item": data} ) as resp: @@ -78,11 +79,11 @@ async def insert( async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -92,7 +93,8 @@ async def put( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) async with self._session.put( f"{self._base_url}/items", json={"items": [data]} ) as resp: @@ -104,10 +106,10 @@ async def put( async def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: List[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if len(items) > 25: raise AssertionError("We can't put more than 25 items at a time.") @@ -128,10 +130,10 @@ async def put_many( async def fetch( self, - query: typing.Union[dict, list] = None, + query: Union[dict, list, None] = None, *, limit: int = 1000, - last: str = None, + last: Union[str, None] = None, ): payload = {} if query: @@ -152,8 +154,8 @@ async def update( updates: dict, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if key == "": raise ValueError("Key is empty") diff --git a/deta/base.py b/deta/base.py index c15c127..9fdd56c 100644 --- a/deta/base.py +++ b/deta/base.py @@ -1,7 +1,6 @@ import os import datetime -from re import I -import typing +from typing import Union, List from urllib.parse import quote from .service import _Service, JSON_MIME @@ -62,18 +61,18 @@ def __init__(self, value): def trim(self): return self.Trim() - def increment(self, value: typing.Union[int, float] = None): + def increment(self, value: Union[int, float, None] = None): return self.Increment(value) - def append(self, value: typing.Union[dict, list, str, int, float, bool]): + def append(self, value: Union[dict, list, str, int, float, bool]): return self.Append(value) - def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): + def prepend(self, value: Union[dict, list, str, int, float, bool]): return self.Prepend(value) class _Base(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Union[str, None] = None): assert name, "No Base name provided" host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" @@ -110,11 +109,11 @@ def delete(self, key: str): def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -124,7 +123,8 @@ def insert( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) code, res = self._request( "/items", "POST", {"item": data}, content_type=JSON_MIME ) @@ -135,11 +135,11 @@ def insert( def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): """store (put) an item in the database. Overrides an item if key already exists. `key` could be provided as function argument or a field in the data dict. @@ -154,18 +154,21 @@ def put( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) code, res = self._request( "/items", "PUT", {"items": [data]}, content_type=JSON_MIME ) - return res["processed"]["items"][0] if res and code == 207 else None + + if res and code == 207: + return res["processed"]["items"][0] # pyright: ignore def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: List[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): assert len(items) <= 25, "We can't put more than 25 items at a time." _items = [] @@ -185,10 +188,10 @@ def put_many( def _fetch( self, - query: typing.Union[dict, list] = None, - buffer: int = None, - last: str = None, - ) -> typing.Optional[typing.Tuple[int, list]]: + query: Union[dict, list, None] = None, + buffer: Union[int, None] = None, + last: Union[str, None] = None, + ): """This is where actual fetch happens.""" payload = { "limit": buffer, @@ -198,33 +201,39 @@ def _fetch( if query: payload["query"] = query if isinstance(query, list) else [query] - code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) - return code, res + _, res = self._request( + "/query", "POST", payload, content_type=JSON_MIME) + + assert res + + return res def fetch( self, - query: typing.Union[dict, list] = None, + query: Union[dict, list, None] = None, *, limit: int = 1000, - last: str = None, + last: Union[str, None] = None, ): """ fetch items from the database. `query` is an optional filter or list of filters. Without filter, it will return the whole db. """ - _, res = self._fetch(query, limit, last) + res = self._fetch(query, limit, last) - paging = res.get("paging") + paging = res.get("paging") # pyright: ignore - return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) + return FetchResponse(paging.get("size"), + paging.get("last"), + res.get("items")) # pyright: ignore def update( self, updates: dict, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): """ update an item in the database diff --git a/deta/drive.py b/deta/drive.py index 2d4db38..f50b56b 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -1,5 +1,5 @@ import os -import typing +from typing import Union, List from io import BufferedIOBase, TextIOBase, RawIOBase, StringIO, BytesIO from urllib.parse import quote_plus @@ -20,7 +20,7 @@ def __init__(self, res: BufferedIOBase): def closed(self): return self.__stream.closed - def read(self, size: int = None): + def read(self, size: Union[int, None] = None): return self.__stream.read(size) def iter_chunks(self, chunk_size: int = 1024): @@ -29,7 +29,7 @@ def iter_chunks(self, chunk_size: int = 1024): if not chunk: break yield chunk - + def iter_lines(self, chunk_size: int = 1024): while True: chunk = self.__stream.readline(chunk_size) @@ -48,14 +48,17 @@ def close(self): class _Drive(_Service): def __init__( self, - name: str = None, - project_key: str = None, - project_id: str = None, - host: str = None, + name: Union[str, None] = None, + project_key: Union[str, None] = None, + project_id: Union[str, None] = None, + host: Union[str, None] = None, ): assert name, "No Drive name provided" host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" + assert project_key, "Project key must be provided" + assert project_id, "Project id must be provided" + super().__init__( project_key=project_key, project_id=project_id, @@ -78,10 +81,10 @@ def get(self, name: str): f"/files/download?name={self._quote(name)}", "GET", stream=True ) if res: - return DriveStreamingBody(res) + return DriveStreamingBody(res) # pyright: ignore return None - def delete_many(self, names: typing.List[str]): + def delete_many(self, names: List[str]): """Delete many files from drive in single request. `names` are the names of the files to be deleted. Returns a dict with 'deleted' and 'failed' files. @@ -99,13 +102,18 @@ def delete(self, name: str): Returns the name of the file deleted. """ assert name, "Name not provided or empty" + payload = self.delete_many([name]) - failed = payload.get("failed") + + failed = payload.get("failed") # pyright: ignore + if failed: raise Exception(f"Failed to delete '{name}':{failed[name]}") + return name - def list(self, limit: int = 1000, prefix: str = None, last: str = None): + def list(self, limit: int = 1000, prefix: Union[str, None] = None, + last: Union[str, None] = None): """List file names from drive. `limit` is the limit of number of file names to get, defaults to 1000. `prefix` is the prefix of file names. @@ -122,21 +130,22 @@ def list(self, limit: int = 1000, prefix: str = None, last: str = None): def _start_upload(self, name: str): _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") - return res["upload_id"] + return res["upload_id"] # pyright: ignore def _finish_upload(self, name: str, upload_id: str): self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") def _abort_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + self._request( + f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") def _upload_part( self, name: str, - chunk: bytes, + chunk: Union[bytes, str], upload_id: str, part: int, - content_type: str = None, + content_type: Union[str, None] = None, ): self._request( f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", @@ -146,7 +155,7 @@ def _upload_part( ) def _get_content_stream( - self, data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] + self, data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] ): if isinstance(data, str): return StringIO(data) @@ -157,10 +166,11 @@ def _get_content_stream( def put( self, name: str, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] = None, + data: Union[str, bytes, TextIOBase, + BufferedIOBase, RawIOBase, None] = None, *, - path: str = None, - content_type: str = None, + path: Union[str, None] = None, + content_type: Union[str, None] = None, ) -> str: """Put a file in drive. `name` is the name of the file. @@ -175,13 +185,18 @@ def put( # start upload upload_id = self._start_upload(name) - content_stream = open(path, "rb") if path else self._get_content_stream(data) + if path: + content_stream = open(path, "rb") + else: + assert data + content_stream = self._get_content_stream(data) + part = 1 # upload chunks while True: chunk = content_stream.read(UPLOAD_CHUNK_SIZE) - ## eof stop the loop + # eof stop the loop if not chunk: self._finish_upload(name, upload_id) content_stream.close() diff --git a/deta/service.py b/deta/service.py index e1f1cdf..e0a22c6 100644 --- a/deta/service.py +++ b/deta/service.py @@ -3,7 +3,7 @@ import json import socket import struct -import typing +from typing import Union import urllib.error JSON_MIME = "application/json" @@ -24,16 +24,17 @@ def __init__( self.host = host self.timeout = timeout self.keep_alive = keep_alive - self.client = ( - http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None - ) + self.client = (http.client.HTTPSConnection( + host, timeout=timeout) if keep_alive else None) def _is_socket_closed(self): - if not self.client.sock: + if not self.client or not self.client.sock: return True + fmt = "B" * 7 + "I" * 21 tcp_info = struct.unpack( - fmt, self.client.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_INFO, 92) + fmt, self.client.sock.getsockopt( + socket.IPPROTO_TCP, socket.TCP_INFO, 92) ) # 8 = CLOSE_WAIT if len(tcp_info) > 0 and tcp_info[0] == 8: @@ -44,16 +45,20 @@ def _request( self, path: str, method: str, - data: typing.Union[str, bytes, dict] = None, - headers: dict = None, - content_type: str = None, + data: Union[str, bytes, dict, None] = None, + headers: Union[dict, None] = None, + content_type: Union[str, None] = None, stream: bool = False, ): + url = self.base_path + path + headers = headers or {} headers["X-Api-Key"] = self.project_key + if content_type: headers["Content-Type"] = content_type + if not self.keep_alive: headers["Connection"] = "close" @@ -74,39 +79,44 @@ def _request( # response res = self._send_request_with_retry(method, url, headers, body) + + assert res + status = res.status if status not in [200, 201, 202, 207]: # need to read the response so subsequent requests can be sent on the client res.read() - if not self.keep_alive: + if not self.keep_alive and self.client: self.client.close() - ## return None if not found + # return None if not found if status == 404: return status, None - raise urllib.error.HTTPError(url, status, res.reason, res.headers, res.fp) + raise urllib.error.HTTPError( + url, status, res.reason, res.headers, res.fp) - ## if stream return the response and client without reading and closing the client + # if stream return the response and client without reading and closing the client if stream: return status, res - ## return json if application/json - payload = ( - json.loads(res.read()) - if JSON_MIME in res.getheader("content-type") - else res.read() - ) + # return json if application/json + res_content_type = res.getheader("content-type") + if res_content_type and JSON_MIME in res_content_type: + payload = json.loads(res.read()) + else: + payload = res.read() - if not self.keep_alive: + if not self.keep_alive and self.client: self.client.close() + return status, payload def _send_request_with_retry( self, method: str, url: str, - headers: dict = None, - body: typing.Union[str, bytes, dict] = None, + headers: Union[dict, None] = None, + body: Union[str, bytes, dict, None] = None, retry=2, # try at least twice to regain a new connection ): reinitializeConnection = False @@ -117,6 +127,11 @@ def _send_request_with_retry( host=self.host, timeout=self.timeout ) + if headers is None: + headers = {} + + assert self.client + self.client.request( method, url, @@ -125,6 +140,7 @@ def _send_request_with_retry( ) res = self.client.getresponse() return res + except http.client.RemoteDisconnected: reinitializeConnection = True retry -= 1 diff --git a/deta/utils.py b/deta/utils.py index f94c598..97a17cb 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,7 +1,9 @@ import os +from typing import Union -def _get_project_key_id(project_key: str = None, project_id: str = None): +def _get_project_key_id(project_key: Union[str, None] = None, + project_id: Union[str, None] = None): project_key = project_key or os.getenv("DETA_PROJECT_KEY", "") if not project_key: @@ -13,4 +15,4 @@ def _get_project_key_id(project_key: str = None, project_id: str = None): if project_id == project_key: raise AssertionError("Bad project key provided") - return project_key, project_id \ No newline at end of file + return project_key, project_id