diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index 85b7921..fcc1b92 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -20,6 +20,7 @@ on: env: CARGO_TERM_COLOR: always VERBOSE: ${{ github.event.inputs.verbose }} + CACHE_LOCAL: "1" # job to run tests in parallel jobs: diff --git a/CHANGELOG.md b/CHANGELOG.md index 019fccf..e957143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## 1.6.4 /2025-03-25 + +## What's Changed +* Better typing for ScaleObj by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/278 +* Faster startup by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/277 +* DNS/SSL Caching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/279 +* Added info about signed commits by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/280 +* [fix] change legacy (old) runtimeApi params encoding by @camfairchild in https://github.com/opentensor/async-substrate-interface/pull/194 +* Prefer v15 metadata when available by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/282 + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.6.3...v1.6.4 + ## 1.6.3 /2025-02-24 ## What's Changed diff --git a/README.md b/README.md index 6d5ad4d..e9a0aae 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,26 @@ The following environment variables are used within async-substrate-interface Contributions are welcome! Please open an issue or submit a pull request to the `staging` branch. +### Signed Commits +All commits in pull requests must be signed. We require signed commits to verify the authenticity of contributions and ensure code integrity. + +To sign your commits, you must have GPG signing configured in Git: + +```bash +git commit -S -m "your commit message" +``` + +Or configure Git to sign all commits automatically: + +```bash +git config --global commit.gpgsign true +``` + +For instructions on setting up GPG key signing, see [GitHub's documentation on signing commits](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits). + +> **Note:** Pull requests containing unsigned commits will not be merged. + + ## License This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 90ac49d..bdaafe7 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -10,7 +10,9 @@ import os import socket import ssl +import time import warnings +from contextlib import suppress from unittest.mock import AsyncMock from hashlib import blake2b from typing import ( @@ -38,6 +40,7 @@ from websockets.asyncio.client import connect, ClientConnection from websockets.exceptions import ( ConnectionClosed, + InvalidURI, ) from websockets.protocol import State @@ -86,6 +89,7 @@ # env vars dictating the cache size of the cached methods SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512")) SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16")) +SSL_SESSION_TTL = int(os.getenv("SUBSTRATE_SSL_SESSION_TTL", "300")) class AsyncExtrinsicReceipt: @@ -576,6 +580,55 @@ def __getitem__(self, item): return self.records[item] +class _SessionResumingSSLContext(ssl.SSLContext): + """ + An SSL context that saves the last TLS session and attempts to resume it on + reconnection, as long as it is still within its TTL. + + Session resumption avoids a full TLS handshake on reconnect, reducing + latency. The effective TTL is the minimum of ``session_ttl`` and the + server-advertised session timeout. + """ + + def __new__(cls, protocol: int = ssl.PROTOCOL_TLS_CLIENT, **_kwargs): + return ssl.SSLContext.__new__(cls, protocol) + + def __init__( + self, + protocol: int = ssl.PROTOCOL_TLS_CLIENT, + *, + session_ttl: int = SSL_SESSION_TTL, + ): + self._saved_session: Optional[ssl.SSLSession] = None + self._session_established_at: Optional[float] = None + self._session_ttl = session_ttl + + def save_session(self, session: ssl.SSLSession) -> None: + self._saved_session = session + self._session_established_at = time.monotonic() + + def _session_is_valid(self) -> bool: + if self._saved_session is None or self._session_established_at is None: + return False + elapsed = time.monotonic() - self._session_established_at + effective_ttl = min(self._session_ttl, self._saved_session.timeout) + return elapsed < effective_ttl + + def wrap_bio( + self, incoming, outgoing, server_side=False, server_hostname=None, session=None + ): + if not server_side and session is None and self._session_is_valid(): + session = self._saved_session + logger.debug("Attempting TLS session resumption") + return super().wrap_bio( + incoming, + outgoing, + server_side=server_side, + server_hostname=server_hostname, + session=session, + ) + + class Websocket: def __init__( self, @@ -587,6 +640,8 @@ def __init__( _log_raw_websockets: bool = False, retry_timeout: float = 60.0, max_retries: int = 5, + ssl_context: Optional[_SessionResumingSSLContext] = None, + dns_ttl: int = 300, ): """ Websocket manager object. Allows for the use of a single websocket connection by multiple @@ -603,6 +658,10 @@ def __init__( _log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger retry_timeout: Timeout in seconds to retry websocket connection max_retries: Maximum number of retries following a timeout + ssl_context: Optional session-resuming SSL context for wss:// connections. + When provided, the context's saved TLS session is reused on reconnection + to avoid a full handshake. + dns_ttl: Seconds to cache DNS results. Set to 0 to disable caching. """ # TODO allow setting max concurrent connections and rpc subscriptions per connection self.ws_url = ws_url @@ -626,6 +685,11 @@ def __init__( self._last_activity = asyncio.Event() self._last_activity.set() self._waiting_for_response = 0 + self._ssl_context = ssl_context + if ssl_context is not None and ws_url.startswith("wss://"): + self._options["ssl"] = ssl_context + self._dns_ttl = dns_ttl + self._dns_cache: Optional[tuple[list, float]] = None @property def state(self): @@ -735,6 +799,37 @@ async def _cancel(self): f"{e} encountered while trying to close websocket connection." ) + async def _resolve_host(self) -> tuple: + """ + Resolve the websocket hostname to a (family, type, proto, canonname, sockaddr) tuple, + using a cached result if it is still within ``dns_ttl`` seconds. + + Invalidate the cache by setting ``_dns_cache = None`` before calling. + """ + from urllib.parse import urlparse + + parsed = urlparse(self.ws_url) + if parsed.scheme not in ("ws", "wss"): + raise InvalidURI(self.ws_url, f"Invalid URI scheme: {parsed.scheme!r}") + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "wss" else 80) + + now = time.monotonic() + if self._dns_cache is not None and self._dns_ttl > 0: + infos, resolved_at = self._dns_cache + if now - resolved_at < self._dns_ttl: + logger.debug(f"DNS cache hit for {host} (age={now - resolved_at:.0f}s)") + return infos[0] + + logger.debug(f"Resolving DNS for {host}:{port}") + loop = asyncio.get_running_loop() + infos = await loop.getaddrinfo( + host, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + self._dns_cache = (infos, now) + logger.debug(f"DNS resolved {host} -> {infos[0][4][0]}") + return infos[0] + async def connect(self, force=False): if not force: async with self._lock: @@ -770,8 +865,20 @@ async def _connect_internal(self, force): pass logger.debug("Attempting connection") try: + family, type_, proto, _, sockaddr = await self._resolve_host() + tcp_sock = socket.socket(family, type_, proto) + tcp_sock.setblocking(False) + loop = asyncio.get_running_loop() + try: + await asyncio.wait_for( + loop.sock_connect(tcp_sock, sockaddr), timeout=10.0 + ) + except Exception: + tcp_sock.close() + self._dns_cache = None # invalidate on TCP failure + raise connection = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10.0 + connect(self.ws_url, sock=tcp_sock, **self._options), timeout=10.0 ) except socket.gaierror: logger.debug(f"Hostname not known (this is just for testing") @@ -779,6 +886,18 @@ async def _connect_internal(self, force): return await self.connect(force=force) logger.debug("Connection established") self.ws = connection + if self._ssl_context is not None: + try: + ssl_obj = connection.transport.get_extra_info("ssl_object") + if ssl_obj is not None and ssl_obj.session is not None: + self._ssl_context.save_session(ssl_obj.session) + logger.debug( + f"Saved TLS session " + f"(reused={ssl_obj.session_reused}, " + f"timeout={ssl_obj.session.timeout}s)" + ) + except Exception as e: + logger.debug(f"Could not save TLS session: {e}") if self._send_recv_task is None or self._send_recv_task.done(): self._send_recv_task = asyncio.get_running_loop().create_task( self._handler(self.ws) @@ -1145,6 +1264,8 @@ def __init__( _log_raw_websockets: bool = False, ws_shutdown_timer: Optional[float] = 5.0, decode_ss58: bool = False, + _ssl_context: Optional[_SessionResumingSSLContext] = None, + dns_ttl: int = 300, ): """ The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to @@ -1165,6 +1286,10 @@ def __init__( _log_raw_websockets: whether to log raw websocket requests during RPC requests ws_shutdown_timer: how long after the last connection your websocket should close decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples. + _ssl_context: optional session-resuming SSL context; used internally by + DiskCachedAsyncSubstrateInterface to enable TLS session reuse. + dns_ttl: seconds to cache DNS results for the websocket URL (default 300). Set to 0 + to disable caching. """ super().__init__( @@ -1191,6 +1316,8 @@ def __init__( shutdown_timer=ws_shutdown_timer, retry_timeout=self.retry_timeout, max_retries=max_retries, + ssl_context=_ssl_context, + dns_ttl=dns_ttl, ) else: self.ws = AsyncMock(spec=Websocket) @@ -1211,6 +1338,8 @@ def __init__( self.metadata_version_hex = "0x0f000000" # v15 self._initializing = False self._mock = _mock + self.startup_runtime_task: Optional[asyncio.Task] = None + self.startup_block_hash: Optional[str] = None async def __aenter__(self): if not self._mock: @@ -1230,8 +1359,12 @@ async def _initialize(self) -> None: if not self._chain: chain = await self.rpc_request("system_chain", []) self._chain = chain.get("result") - runtime = await self.init_runtime() + self.startup_block_hash = block_hash = await self.get_chain_head() + self.startup_runtime_task = asyncio.create_task( + self.init_runtime(block_hash=block_hash, init=True) + ) if self.ss58_format is None: + runtime = await self.init_runtime(block_hash) # Check and apply runtime constants ss58_prefix_constant = await self.get_constant( "System", "SS58Prefix", runtime=runtime @@ -1438,7 +1571,10 @@ async def decode_scale( return obj async def init_runtime( - self, block_hash: Optional[str] = None, block_id: Optional[int] = None + self, + block_hash: Optional[str] = None, + block_id: Optional[int] = None, + init: bool = False, ) -> Runtime: """ This method is used by all other methods that deals with metadata and types defined in the type registry. @@ -1455,6 +1591,13 @@ async def init_runtime( Returns: Runtime object """ + if ( + not init + and self.startup_runtime_task is not None + and block_hash == self.startup_block_hash + ): + await self.startup_runtime_task + self.startup_runtime_task = None if block_id and block_hash: raise ValueError("Cannot provide block_hash and block_id at the same time") @@ -2503,6 +2646,10 @@ async def _preprocess( # SCALE type string of value param_types = storage_item.get_params_type_string() value_scale_type = storage_item.get_value_type_string() + # V14 and V15 metadata may have different portable type registry numbering. + # Use V15 type ID when available to ensure correct decoding with the V15 registry. + if v15_type_id := runtime.get_v15_storage_type_id(module, storage_function): + value_scale_type = f"scale_info::{v15_type_id}" if len(params) != len(param_types): raise ValueError( @@ -3315,7 +3462,9 @@ async def _do_runtime_call_old( param_data = b"" if "encoder" in runtime_call_def: - param_data = runtime_call_def["encoder"](params) + if runtime is None: + runtime = await self.init_runtime(block_hash=block_hash) + param_data = runtime_call_def["encoder"](params, runtime.registry) else: for idx, param in enumerate(runtime_call_def["params"]): param_type_string = f"{param['type']}" @@ -4322,6 +4471,10 @@ async def close(self): Closes the substrate connection, and the websocket connection. """ try: + if self.startup_runtime_task is not None: + self.startup_runtime_task.cancel() + with suppress(asyncio.CancelledError): + await self.startup_runtime_task await self.ws.shutdown() except AttributeError: pass @@ -4379,22 +4532,52 @@ class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface): Loads the cache from the disk at startup, where it is kept in-memory, and dumps to the disk when the connection is closed. + + For `wss://` endpoints, a persistent `_SessionResumingSSLContext` is created so + that TLS sessions are reused across reconnections. The effective session TTL is the minimum + of `ssl_session_ttl` (default `SSL_SESSION_TTL`) and the server-advertised timeout. """ + def __init__( + self, + url: str, + *args, + ssl_session_ttl: int = SSL_SESSION_TTL, + **kwargs, + ): + ssl_context: Optional[_SessionResumingSSLContext] = None + if url.startswith("wss://") and not kwargs.get("_mock", False): + ssl_context = _SessionResumingSSLContext(session_ttl=ssl_session_ttl) + ssl_context.set_default_verify_paths() + super().__init__(url, *args, _ssl_context=ssl_context, **kwargs) + async def initialize(self) -> None: + db = AsyncSqliteDB(self.url) + cached = await db.load_dns_cache(self.url) + if cached is not None: + addrinfos, saved_at_unix = cached + age = time.time() - saved_at_unix + # Reconstruct a monotonic timestamp so _resolve_host's TTL check works correctly + self.ws._dns_cache = (addrinfos, time.monotonic() - age) + logger.debug(f"Loaded DNS cache from disk (age={age:.0f}s)") await self.runtime_cache.load_from_disk(self.url) await self._initialize() async def close(self): """ - Closes the substrate connection and the websocket connection, dumps the runtime cache to disk + Closes the substrate connection and the websocket connection, dumps the runtime and DNS + caches to disk. """ + db = AsyncSqliteDB(self.url) + dns_cache = getattr(self.ws, "_dns_cache", None) + if dns_cache is not None: + addrinfos, _ = dns_cache + await db.save_dns_cache(self.url, addrinfos) try: await self.runtime_cache.dump_to_disk(self.url) await self.ws.shutdown() except AttributeError: pass - db = AsyncSqliteDB(self.url) await db.close() @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index d393c25..7ee2507 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -1793,6 +1793,12 @@ def _preprocess( # SCALE type string of value param_types = storage_item.get_params_type_string() value_scale_type = storage_item.get_value_type_string() + # V14 and V15 metadata may have different portable type registry numbering. + # Use V15 type ID when available to ensure correct decoding with the V15 registry. + if v15_type_id := self.runtime.get_v15_storage_type_id( + module, storage_function + ): + value_scale_type = f"scale_info::{v15_type_id}" if len(params) != len(param_types): raise ValueError( @@ -2556,28 +2562,21 @@ def _do_runtime_call_old( runtime_call_def = _TYPE_REGISTRY["runtime_api"][api]["methods"][method] # Encode params - param_data = b"" + param_data: Union[ScaleBytes, bytes] = b"" - if "encoder" in runtime_call_def: - param_data = runtime_call_def["encoder"](params) - else: - for idx, param in enumerate(runtime_call_def["params"]): - param_type_string = f"{param['type']}" - if isinstance(params, list): - param_data += self.encode_scale(param_type_string, params[idx]) - else: - if param["name"] not in params: - raise ValueError( - f"Runtime Call param '{param['name']}' is missing" - ) + runtime = self.init_runtime(block_hash=block_hash) - param_data += self.encode_scale( - param_type_string, params[param["name"]] - ) + if "encoder" in runtime_call_def and runtime.registry is not None: + # only works if we have metadata v15 + param_data = runtime_call_def["encoder"](params, runtime.registry) + param_hex = param_data.hex() + else: + param_data = self._encode_scale_legacy(runtime_call_def, params, runtime) + param_hex = param_data.to_hex() # RPC request result_data = self.rpc_request( - "state_call", [f"{api}_{method}", param_data.hex(), block_hash] + "state_call", [f"{api}_{method}", param_hex, block_hash] ) result_vec_u8_bytes = hex_to_bytes(result_data["result"]) result_bytes = self.decode_scale("Vec", result_vec_u8_bytes) diff --git a/async_substrate_interface/type_registry.py b/async_substrate_interface/type_registry.py index 0f224e8..7e2e246 100644 --- a/async_substrate_interface/type_registry.py +++ b/async_substrate_interface/type_registry.py @@ -1,3 +1,5 @@ +from typing import Union +from collections import namedtuple from bt_decode import ( NeuronInfo, NeuronInfoLite, @@ -8,7 +10,54 @@ SubnetInfoV2, encode, ) -from scalecodec import ss58_encode +from scalecodec import ss58_decode + + +def stake_info_decode_vec_legacy_compatibility( + item, +) -> list[dict[str, Union[str, int, bytes, bool]]]: + stake_infos: list[StakeInfo] = StakeInfo.decode_vec(item) + NewStakeInfo = namedtuple( + "NewStakeInfo", + [ + "netuid", + "hotkey", + "coldkey", + "stake", + "locked", + "emission", + "drain", + "is_registered", + ], + ) + output = [] + for stake_info in stake_infos: + output.append( + NewStakeInfo( + 0, + stake_info.hotkey, + stake_info.coldkey, + stake_info.stake, + 0, + 0, + 0, + False, + ) + ) + return output + + +def preprocess_get_stake_info_for_coldkeys(addrs): + output = [] + if isinstance(addrs[0], list): # I think + for addr in addrs[0]: + output.append(list(bytes.fromhex(ss58_decode(addr)))) + else: + if isinstance(addrs[0], dict): + for addr in addrs[0]["coldkey_accounts"]: + output.append(list(bytes.fromhex(ss58_decode(addr)))) + return output + _TYPE_REGISTRY: dict[str, dict] = { "types": { @@ -24,7 +73,9 @@ "type": "Vec", }, ], - "encoder": lambda addr: encode(ss58_encode(addr), "Vec"), + "encoder": lambda addr, reg: encode( + "Vec", reg, list(bytes.fromhex(ss58_decode(addr))) + ), "type": "Vec", "decoder": DelegateInfo.decode_delegated, }, @@ -97,8 +148,20 @@ }, ], "type": "Vec", - "encoder": lambda addr: encode(ss58_encode(addr), "Vec"), - "decoder": StakeInfo.decode_vec, + "encoder": lambda addr, reg: encode( + "Vec", + reg, + list( + bytes.fromhex( + ss58_decode( + addr[0] + if isinstance(addr, list) + else addr["coldkey_account"] + ) + ) + ), + ), + "decoder": stake_info_decode_vec_legacy_compatibility, }, "get_stake_info_for_coldkeys": { "params": [ @@ -108,8 +171,10 @@ }, ], "type": "Vec", - "encoder": lambda addrs: encode( - [ss58_encode(addr) for addr in addrs], "Vec>" + "encoder": lambda addrs, reg: encode( + "Vec>", + reg, + preprocess_get_stake_info_for_coldkeys(addrs), ), "decoder": StakeInfo.decode_vec_tuple_vec, }, diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 7af5e83..f05a44a 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -7,7 +7,7 @@ from contextlib import suppress from dataclasses import dataclass from datetime import datetime -from typing import Optional, Union, Any, Sequence +from typing import Optional, Union, Any, Sequence, Generic, TypeVar import scalecodec.types from bt_decode import PortableRegistry, encode as encode_by_type_string @@ -25,6 +25,8 @@ SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16")) SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512")) +T = TypeVar("T") + class RuntimeCache: """ @@ -225,6 +227,7 @@ def __init__( self.type_registry = type_registry self.metadata = metadata self.metadata_v15 = metadata_v15 + self._v15_storage_type_map: Optional[dict[tuple[str, str], int]] = None self.runtime_info = runtime_info self.registry = registry runtime_info = runtime_info or {} @@ -497,6 +500,38 @@ def resolve_type_definition(type_id_): self.registry_type_map = registry_type_map self.type_id_to_name = type_id_to_name + def get_v15_storage_type_id( + self, pallet: str, storage_function: str + ) -> Optional[int]: + """ + Returns the V15 type ID for a given pallet storage function. + V14 and V15 metadata may have different portable type registry numbering, + so using V15 type IDs ensures correct decoding with the V15 PortableRegistry. + """ + if self.metadata_v15 is None: + return None + if self._v15_storage_type_map is None: + self._v15_storage_type_map = {} + try: + v15_json = json.loads(self.metadata_v15.to_json()) + for p in v15_json.get("pallets", []): + storage = p.get("storage") + if not storage: + continue + for entry in storage.get("entries", []): + ty = entry.get("ty", {}) + if "Plain" in ty: + self._v15_storage_type_map[(p["name"], entry["name"])] = ty[ + "Plain" + ] + elif "Map" in ty: + self._v15_storage_type_map[(p["name"], entry["name"])] = ty[ + "Map" + ]["value"] + except Exception: + pass + return self._v15_storage_type_map.get((pallet, storage_function)) + RequestResults = dict[Union[str, int], list[Union[ScaleType, dict]]] @@ -569,7 +604,7 @@ class Preprocessed: storage_item: ScaleType -class ScaleObj: +class ScaleObj(Generic[T]): """Bittensor representation of Scale Object.""" def __init__(self, value): @@ -1092,6 +1127,27 @@ def _encode_scale( result = bytes(encode_by_type_string(type_string, runtime.registry, value)) return result + @staticmethod + def _encode_scale_legacy( + call_definition: list[dict], + params: Union[list[Any], dict[str, Any]], + runtime: Runtime, + ) -> bytes: + """Returns a hex encoded string of the params using their types.""" + param_data = scalecodec.ScaleBytes(b"") + + for i, param in enumerate(call_definition["params"]): # type: ignore + scale_obj = runtime.runtime_config.create_scale_object(param["type"]) + if type(params) is list: + param_data += scale_obj.encode(params[i]) + else: + if param["name"] not in params: + raise ValueError(f"Missing param {param['name']} in params dict.") + + param_data += scale_obj.encode(params[param["name"]]) + + return param_data + @staticmethod def _encode_account_id(account) -> bytes: """Encode an account ID into bytes. diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 8de077b..bfac941 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -1,5 +1,6 @@ import asyncio import inspect +import time import weakref from collections import OrderedDict import functools @@ -14,6 +15,7 @@ USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False +CACHE_LOCAL = os.getenv("CACHE_LOCAL") == "1" CACHE_LOCATION = ( os.path.expanduser( os.getenv("CACHE_LOCATION", "~/.cache/async-substrate-interface") @@ -30,6 +32,7 @@ class AsyncSqliteDB: _instances: dict[str, "AsyncSqliteDB"] = {} _db: Optional[aiosqlite.Connection] = None _lock: Optional[asyncio.Lock] = None + _created_tables: set def __new__(cls, chain_endpoint: str): try: @@ -37,6 +40,7 @@ def __new__(cls, chain_endpoint: str): except KeyError: instance = super().__new__(cls) instance._lock = asyncio.Lock() + instance._created_tables = set() cls._instances[chain_endpoint] = instance return instance @@ -45,8 +49,11 @@ async def close(self): if self._db: await self._db.close() self._db = None + self._created_tables.clear() async def _create_if_not_exists(self, chain: str, table_name: str): + if table_name in self._created_tables: + return _check_if_local(chain) if not (local_chain := _check_if_local(chain)) or not USE_CACHE: await self._db.execute( f""" @@ -76,6 +83,7 @@ async def _create_if_not_exists(self, chain: str, table_name: str): """ ) await self._db.commit() + self._created_tables.add(table_name) return local_chain async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]: @@ -86,18 +94,18 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] table_name = _get_table_name(func) local_chain = await self._create_if_not_exists(chain, table_name) key = pickle.dumps((args, kwargs or None)) - try: - cursor: aiosqlite.Cursor = await self._db.execute( - f"SELECT value FROM {table_name} WHERE key=? AND chain=?", - (key, chain), - ) - result = await cursor.fetchone() - await cursor.close() - if result is not None: - return pickle.loads(result[0]) - except (pickle.PickleError, sqlite3.Error) as e: - logger.exception("Cache error", exc_info=e) - pass + if not local_chain or not USE_CACHE: + try: + cursor: aiosqlite.Cursor = await self._db.execute( + f"SELECT value FROM {table_name} WHERE key=? AND chain=?", + (key, chain), + ) + result = await cursor.fetchone() + await cursor.close() + if result is not None: + return pickle.loads(result[0]) + except (pickle.PickleError, sqlite3.Error) as e: + logger.exception("Cache error", exc_info=e) result = await func(other_self, *args, **kwargs) if not local_chain or not USE_CACHE: # TODO use a task here @@ -108,6 +116,61 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] await self._db.commit() return result + async def _ensure_dns_table(self): + await self._db.execute( + """CREATE TABLE IF NOT EXISTS dns_cache ( + url TEXT PRIMARY KEY, + addrinfos BLOB, + saved_at REAL + )""" + ) + await self._db.commit() + + async def load_dns_cache(self, url: str) -> Optional[tuple[list, float]]: + """ + Load a previously saved DNS result for ``url``. + + Returns ``(addrinfos, saved_at_unix)`` where ``saved_at_unix`` is the Unix + timestamp at which the result was saved, or ``None`` if nothing is cached. + Skips localhost URLs. + """ + if _check_if_local(url): + return None + async with self._lock: + if not self._db: + _ensure_dir() + self._db = await aiosqlite.connect(CACHE_LOCATION) + await self._ensure_dns_table() + try: + cursor = await self._db.execute( + "SELECT addrinfos, saved_at FROM dns_cache WHERE url=?", (url,) + ) + row = await cursor.fetchone() + await cursor.close() + if row is not None: + return pickle.loads(row[0]), row[1] + except (pickle.PickleError, sqlite3.Error) as e: + logger.debug(f"DNS cache load error: {e}") + return None + + async def save_dns_cache(self, url: str, addrinfos: list) -> None: + """Persist DNS results for ``url`` to disk. Skips localhost URLs.""" + if _check_if_local(url): + return + async with self._lock: + if not self._db: + _ensure_dir() + self._db = await aiosqlite.connect(CACHE_LOCATION) + await self._ensure_dns_table() + try: + await self._db.execute( + "INSERT OR REPLACE INTO dns_cache (url, addrinfos, saved_at) VALUES (?,?,?)", + (url, pickle.dumps(addrinfos), time.time()), + ) + await self._db.commit() + except (pickle.PickleError, sqlite3.Error) as e: + logger.debug(f"DNS cache save error: {e}") + async def load_runtime_cache( self, chain: str ) -> tuple[OrderedDict[int, str], OrderedDict[str, int], OrderedDict[int, dict]]: @@ -202,6 +265,8 @@ def _get_table_name(func): def _check_if_local(chain: str) -> bool: + if CACHE_LOCAL: + return False return any([x in chain for x in ["127.0.0.1", "localhost", "0.0.0.0"]]) diff --git a/pyproject.toml b/pyproject.toml index a00d36b..cf30f07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.6.3" +version = "1.6.4" description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } @@ -48,7 +48,7 @@ Repository = "https://github.com/opentensor/async-substrate-interface/" async_substrate_interface = ["py.typed"] [build-system] -requires = ["setuptools~=70.0.0", "wheel"] +requires = ["setuptools>=70.0", "wheel"] build-backend = "setuptools.build_meta" [project.optional-dependencies] diff --git a/tests/integration_tests/test_async_substrate_interface.py b/tests/integration_tests/test_async_substrate_interface.py index b3f2bb7..d3f3a52 100644 --- a/tests/integration_tests/test_async_substrate_interface.py +++ b/tests/integration_tests/test_async_substrate_interface.py @@ -345,3 +345,22 @@ async def handler(_): current_block + 3, result_handler=handler, task_return=False ) assert result is True + + +@pytest.mark.asyncio +async def test_old_runtime_calls(): + from bittensor import SubtensorApi + + sub = SubtensorApi( + network=ARCHIVE_ENTRYPOINT, legacy_methods=True, async_subtensor=True + ) + await sub.initialize() + # will pass + assert sub.get_stake_info_for_coldkey( + "5CQ6dMW8JZhKCZX9kWsZRqa3kZRKmNHxbPPVFEt6FgyvGv2G", 4943592 + ) + # needs to use legacy + assert sub.get_stake_info_for_coldkey( + "5CQ6dMW8JZhKCZX9kWsZRqa3kZRKmNHxbPPVFEt6FgyvGv2G", 4670227 + ) + await sub.close() diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py index 063eca1..5d6d838 100644 --- a/tests/integration_tests/test_disk_cache.py +++ b/tests/integration_tests/test_disk_cache.py @@ -1,3 +1,11 @@ +""" +Thresholds: + DISK_CACHE_TIMEOUT – first access per method hits SQLite (aiosqlite thread-pool + overhead); must be << any real network call (~200 ms). + MEMORY_CACHE_TIMEOUT – repeat access with the same args hits the in-process LRU; + should be effectively instant. +""" + import pytest import time from async_substrate_interface.async_substrate import ( @@ -8,6 +16,10 @@ from tests.helpers.settings import LATENT_LITE_ENTRYPOINT +DISK_CACHE_TIMEOUT = 0.5 +MEMORY_CACHE_TIMEOUT = 0.002 + + @pytest.mark.asyncio async def test_disk_cache(): print("Testing test_disk_cache") @@ -81,57 +93,44 @@ async def test_disk_cache(): assert parent_block_hash == parent_block_hash_sync assert block_runtime_info == block_runtime_info_sync assert block_runtime_version_for == block_runtime_version_for_sync - # Verify data is pulling from disk cache + # Verify data is pulling from disk cache. async with DiskCachedAsyncSubstrateInterface( LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" ) as disk_cached_substrate: start = time.monotonic() new_block_hash = await disk_cached_substrate.get_block_hash(current_block) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < DISK_CACHE_TIMEOUT start = time.monotonic() - new_parent_block_hash = await disk_cached_substrate.get_parent_block_hash( - block_hash - ) + _ = await disk_cached_substrate.get_parent_block_hash(block_hash) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < DISK_CACHE_TIMEOUT start = time.monotonic() - new_block_runtime_info = await disk_cached_substrate.get_block_runtime_info( - block_hash - ) + _ = await disk_cached_substrate.get_block_runtime_info(block_hash) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < DISK_CACHE_TIMEOUT start = time.monotonic() - new_block_runtime_version_for = ( - await disk_cached_substrate.get_block_runtime_version_for(block_hash) - ) + _ = await disk_cached_substrate.get_block_runtime_version_for(block_hash) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < DISK_CACHE_TIMEOUT + # Repeat calls with the same args must come from the in-process LRU cache. start = time.monotonic() - new_block_hash_from_cache = await disk_cached_substrate.get_block_hash( - current_block - ) + _ = await disk_cached_substrate.get_block_hash(current_block) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < MEMORY_CACHE_TIMEOUT start = time.monotonic() - new_parent_block_hash_from_cache = ( - await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache) - ) + _ = await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < MEMORY_CACHE_TIMEOUT start = time.monotonic() - new_block_runtime_info_from_cache = ( - await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache) - ) + _ = await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < MEMORY_CACHE_TIMEOUT start = time.monotonic() - new_block_runtime_version_from_cache = ( - await disk_cached_substrate.get_block_runtime_version_for( - block_hash_from_cache - ) + _ = await disk_cached_substrate.get_block_runtime_version_for( + block_hash_from_cache ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < MEMORY_CACHE_TIMEOUT print("Disk Cache tests passed") diff --git a/tests/integration_tests/test_substrate_interface.py b/tests/integration_tests/test_substrate_interface.py index f6cf0eb..35e8804 100644 --- a/tests/integration_tests/test_substrate_interface.py +++ b/tests/integration_tests/test_substrate_interface.py @@ -163,3 +163,19 @@ def test_get_payment_info(): assert partial_fee_all_options > partial_fee_no_era assert partial_fee_all_options > partial_fee_era print("test_get_payment_info succeeded") + + +def test_old_runtime_calls(): + from bittensor import SubtensorApi + + sub = SubtensorApi( + network=ARCHIVE_ENTRYPOINT, legacy_methods=True, async_subtensor=False + ) + # will pass + assert sub.get_stake_info_for_coldkey( + "5CQ6dMW8JZhKCZX9kWsZRqa3kZRKmNHxbPPVFEt6FgyvGv2G", 4943592 + ) + # needs to use legacy + assert sub.get_stake_info_for_coldkey( + "5CQ6dMW8JZhKCZX9kWsZRqa3kZRKmNHxbPPVFEt6FgyvGv2G", 4670227 + ) diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index afefe7a..a0ac123 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -296,7 +296,9 @@ async def test_get_account_next_index_cached_mode_uses_internal_cache(): substrate.supports_rpc_method = AsyncMock(return_value=True) substrate.rpc_request = AsyncMock(return_value={"result": 5}) - first = await substrate.get_account_next_index("5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA") + first = await substrate.get_account_next_index( + "5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA" + ) second = await substrate.get_account_next_index( "5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA" ) @@ -331,7 +333,9 @@ async def test_get_account_next_index_bypass_mode_does_not_create_or_mutate_cach async def test_get_account_next_index_bypass_mode_raises_on_rpc_error(): substrate = AsyncSubstrateInterface("ws://localhost", _mock=True) substrate.supports_rpc_method = AsyncMock(return_value=True) - substrate.rpc_request = AsyncMock(return_value={"error": {"message": "rpc failure"}}) + substrate.rpc_request = AsyncMock( + return_value={"error": {"message": "rpc failure"}} + ) with pytest.raises(SubstrateRequestException, match="rpc failure"): await substrate.get_account_next_index(