diff --git a/paimon-python/pypaimon/api/rest_util.py b/paimon-python/pypaimon/api/rest_util.py index 97a709ecc34c..8eb4092c26a4 100644 --- a/paimon-python/pypaimon/api/rest_util.py +++ b/paimon-python/pypaimon/api/rest_util.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict -from urllib.parse import unquote +from typing import Dict, Optional from pypaimon.common.options import Options @@ -31,6 +30,7 @@ def encode_string(value: str) -> str: @staticmethod def decode_string(encoded: str) -> str: """Decode URL-encoded string""" + from urllib.parse import unquote return unquote(encoded) @staticmethod @@ -46,21 +46,21 @@ def extract_prefix_map( @staticmethod def merge( - base_properties: Dict[str, str], - override_properties: Dict[str, str]) -> Dict[str, str]: + base_properties: Optional[Dict[str, str]], + override_properties: Optional[Dict[str, str]]) -> Dict[str, str]: if override_properties is None: override_properties = {} if base_properties is None: base_properties = {} - + result = {} - + for key, value in base_properties.items(): if value is not None and key not in override_properties: result[key] = value - + for key, value in override_properties.items(): if value is not None: result[key] = value - + return result diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py index 2cec5df7216c..c3bffc38c789 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py +++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py @@ -18,20 +18,30 @@ import logging import threading import time +from pathlib import Path from typing import Optional +import pyarrow +import pyarrow.fs +from pypaimon.common.options.config import CatalogOptions, OssOptions from pyarrow._fs import FileSystem from pypaimon.api.rest_api import RESTApi +from pypaimon.common.options import Options from pypaimon.api.rest_util import RESTUtil from pypaimon.catalog.rest.rest_token import RESTToken from pypaimon.common.file_io import FileIO from pypaimon.common.identifier import Identifier -from pypaimon.common.options import Options -from pypaimon.common.options.config import CatalogOptions, OssOptions + +from cachetools import TTLCache class RESTTokenFileIO(FileIO): + _FILESYSTEM_CACHE: TTLCache = TTLCache(maxsize=1000, ttl=36000) # 10 hours TTL + _CACHE_LOCK = threading.Lock() + _TOKEN_CACHE: dict = {} + _TOKEN_LOCKS: dict = {} + _TOKEN_LOCKS_LOCK = threading.Lock() def __init__(self, identifier: Identifier, path: str, catalog_options: Optional[Options] = None): @@ -39,7 +49,7 @@ def __init__(self, identifier: Identifier, path: str, self.path = path self.token: Optional[RESTToken] = None self.api_instance: Optional[RESTApi] = None - self.lock = threading.Lock() + self.lock = threading.RLock() self.log = logging.getLogger(__name__) super().__init__(path, catalog_options) @@ -54,42 +64,221 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) # Recreate lock after deserialization - self.lock = threading.Lock() + self.lock = threading.RLock() # api_instance will be recreated when needed self.api_instance = None - def _initialize_oss_fs(self, path) -> FileSystem: + def _initialize_oss_fs(self, path, properties: Optional[Options] = None) -> FileSystem: + if properties is not None: + file_io = FileIO(path, properties) + return file_io.filesystem + self.try_to_refresh_token() - merged_token = 
self._merge_token_with_catalog_options(self.token.token)
+        if self.token is None:
+            return super()._initialize_oss_fs(path)
+
         merged_properties = RESTUtil.merge(
             self.properties.to_map() if self.properties else {},
-            merged_token
+            self._merge_token_with_catalog_options(self.token.token)
         )
         merged_options = Options(merged_properties)
-        original_properties = self.properties
-        self.properties = merged_options
-        try:
-            return super()._initialize_oss_fs(path)
-        finally:
-            self.properties = original_properties
+
+        file_io = FileIO(path, merged_options)
+        return file_io.filesystem
+
+    def _initialize_s3_fs(self, properties: Optional[Options] = None) -> FileSystem:
+        if properties is not None:
+            file_io = FileIO(self.path, properties)
+            return file_io.filesystem
+
+        self.try_to_refresh_token()
+        if self.token is None:
+            return super()._initialize_s3_fs()
+
+        merged_properties = RESTUtil.merge(
+            self.properties.to_map() if self.properties else {},
+            self._merge_token_with_catalog_options(self.token.token)
+        )
+        merged_options = Options(merged_properties)
+
+        file_io = FileIO(self.path, merged_options)
+        return file_io.filesystem
 
     def _merge_token_with_catalog_options(self, token: dict) -> dict:
         """Merge token with catalog options, DLF OSS endpoint should override
         the standard OSS endpoint."""
         merged_token = dict(token)
-        dlf_oss_endpoint = self.properties.get(CatalogOptions.DLF_OSS_ENDPOINT)
+        with self.lock:
+            dlf_oss_endpoint = self.properties.get(CatalogOptions.DLF_OSS_ENDPOINT) if self.properties else None
         if dlf_oss_endpoint and dlf_oss_endpoint.strip():
             merged_token[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint
         return merged_token
 
+    def get_merged_properties(self) -> Options:
+        self.try_to_refresh_token()
+        if self.token is None:
+            with self.lock:
+                return self.properties
+
+        with self.lock:
+            properties_map = self.properties.to_map() if self.properties else {}
+
+        merged_properties = RESTUtil.merge(
+            properties_map,
+            self.token.token
+        )
+        return Options(merged_properties)
+
+    def _get_filesystem(self) -> FileSystem:
+        self.try_to_refresh_token()
+
+        if self.token is None:
+            return self.filesystem
+
+        with self._CACHE_LOCK:
+            filesystem = self._FILESYSTEM_CACHE.get(self.token)
+            if filesystem is not None:
+                return filesystem
+
+        with self.lock:
+            properties_map = self.properties.to_map() if self.properties else {}
+
+        merged_properties = RESTUtil.merge(
+            properties_map,
+            self.token.token
+        )
+        merged_options = Options(merged_properties)
+
+        scheme, netloc, _ = self.parse_location(self.path)
+        if scheme in {"oss"}:
+            filesystem = self._initialize_oss_fs(self.path, merged_options)
+        elif scheme in {"s3", "s3a", "s3n"}:
+            filesystem = self._initialize_s3_fs(merged_options)
+        else:
+            filesystem = self.filesystem
+
+        with self._CACHE_LOCK:
+            # TTLCache is not thread-safe; guard the write just like the read above
+            self._FILESYSTEM_CACHE[self.token] = filesystem
+        return filesystem
+
+    def new_input_stream(self, path: str):
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        return filesystem.open_input_file(path_str)
+
     def new_output_stream(self, path: str):
-        # Call parent class method to ensure path conversion and parent directory creation
-        return super().new_output_stream(path)
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        parent_dir = Path(path_str).parent
+        if str(parent_dir) and not self.exists(str(parent_dir)):
+            self.mkdirs(str(parent_dir))
+        return filesystem.open_output_stream(path_str)
+
+    def get_file_status(self, path: str):
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        file_infos = filesystem.get_file_info([path_str])
+        return file_infos[0]
+
+    def list_status(self, path: str):
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        selector = pyarrow.fs.FileSelector(path_str, recursive=False, allow_not_found=True)
+        return filesystem.get_file_info(selector)
+
+    def exists(self, path: str) -> bool:
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        file_info = filesystem.get_file_info([path_str])[0]
+        return file_info.type != pyarrow.fs.FileType.NotFound
+
+    def delete(self, path: str, recursive: bool = False) -> bool:
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        file_info = filesystem.get_file_info([path_str])[0]
+        if file_info.type == pyarrow.fs.FileType.NotFound:
+            return False
+        if file_info.type == pyarrow.fs.FileType.Directory:
+            if not recursive:
+                # pyarrow's delete_dir always removes the whole tree; refuse to
+                # do that implicitly when the directory still has children
+                selector = pyarrow.fs.FileSelector(path_str, recursive=False, allow_not_found=True)
+                if filesystem.get_file_info(selector):
+                    raise OSError(f"Directory {path} is not empty; use recursive=True")
+            filesystem.delete_dir(path_str)
+        else:
+            filesystem.delete_file(path_str)
+        return True
+
+    def mkdirs(self, path: str) -> bool:
+        filesystem = self._get_filesystem()
+        path_str = self.to_filesystem_path(path)
+        filesystem.create_dir(path_str, recursive=True)
+        return True
+
+    def rename(self, src: str, dst: str) -> bool:
+        filesystem = self._get_filesystem()
+        dst_str = self.to_filesystem_path(dst)
+        dst_parent = Path(dst_str).parent
+        if str(dst_parent) and not self.exists(str(dst_parent)):
+            self.mkdirs(str(dst_parent))
+
+        src_str = self.to_filesystem_path(src)
+        filesystem.move(src_str, dst_str)
+        return True
+
+    def copy_file(self, source_path: str, target_path: str, overwrite: bool = False):
+        if not overwrite and self.exists(target_path):
+            raise FileExistsError(f"Target file {target_path} already exists and overwrite=False")
+
+        filesystem = self._get_filesystem()
+        source_str = self.to_filesystem_path(source_path)
+        target_str = self.to_filesystem_path(target_path)
+        filesystem.copy_file(source_str, target_str)
 
     def try_to_refresh_token(self):
-        if self.should_refresh():
-            with self.lock:
-                if self.should_refresh():
-                    self.refresh_token()
+        identifier_str = str(self.identifier)
+
+        # Fast path 1: Check instance token
+        if self.token is not None and not self._is_token_expired(self.token):
+            return
+
+        # Fast path 2: Check global cache
+        cached_token = self._get_cached_token(identifier_str)
+        if cached_token and not self._is_token_expired(cached_token):
+            self.token = cached_token
+            return
+
+        # Slow path: Acquire global lock for this identifier
+        global_lock = self._get_global_token_lock()
+        with global_lock:
+            cached_token = self._get_cached_token(identifier_str)
+            if cached_token and not self._is_token_expired(cached_token):
+                self.token = cached_token
+                return
+
+            # A valid cached token would have been adopted above, so whatever
+            # remains here is missing or expired: refresh and publish it
+            self.refresh_token()
+            self._set_cached_token(identifier_str, self.token)
+
+    def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
+        with self._TOKEN_LOCKS_LOCK:
+            return self._TOKEN_CACHE.get(identifier_str)
+
+    def _set_cached_token(self, identifier_str: str, token: RESTToken):
+        with self._TOKEN_LOCKS_LOCK:
+            self._TOKEN_CACHE[identifier_str] = token
+
+    def _get_global_token_lock(self) -> threading.Lock:
+        identifier_str = str(self.identifier)
+        with self._TOKEN_LOCKS_LOCK:
+            if identifier_str not in self._TOKEN_LOCKS:
+                self._TOKEN_LOCKS[identifier_str] = 
threading.Lock()
+        return self._TOKEN_LOCKS[identifier_str]
+
+    def _is_token_expired(self, token: RESTToken) -> bool:
+        if token is None:
+            return True
+        current_time = int(time.time() * 1000)
+        return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
 
     def should_refresh(self):
         if self.token is None:
@@ -100,13 +289,50 @@ def should_refresh(self):
     def refresh_token(self):
         self.log.info(f"begin refresh data token for identifier [{self.identifier}]")
         if self.api_instance is None:
-            self.api_instance = RESTApi(self.properties, False)
+            if not self.properties:
+                raise RuntimeError(
+                    "Cannot refresh token: properties is None or empty. "
+                    "This may indicate a serialization issue when passing RESTTokenFileIO to Ray workers."
+                )
+
+            if not self.properties.contains(CatalogOptions.URI):
+                available_keys = list(self.properties.data.keys())
+                raise RuntimeError(
+                    f"Cannot refresh token: missing required configuration '{CatalogOptions.URI.key()}' in properties. "
+                    "This is required to create RESTApi for token refresh. "
+                    f"Available configuration keys: {available_keys}. "
+                    "This may indicate that catalog options were not properly serialized when passing to Ray workers."
+                )
+
+            uri = self.properties.get(CatalogOptions.URI)
+            if not uri or not uri.strip():
+                raise RuntimeError(
+                    f"Cannot refresh token: '{CatalogOptions.URI.key()}' is empty or whitespace. "
+                    f"Value: '{uri}'. Please ensure the REST catalog URI is properly configured."
+                )
+
+            try:
+                self.api_instance = RESTApi(self.properties, False)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to create RESTApi for token refresh: {e}. "
+                    "Please check that all required catalog options are properly configured. "
+                    f"Identifier: {self.identifier}"
+                ) from e
+
+        try:
+            response = self.api_instance.load_table_token(self.identifier)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to load table token from REST API: {e}. "
+                f"Identifier: {self.identifier}, URI: {self.properties.get(CatalogOptions.URI)}"
+            ) from e
 
-        response = self.api_instance.load_table_token(self.identifier)
         self.log.info(
             f"end refresh data token for identifier [{self.identifier}] expiresAtMillis [{response.expires_at_millis}]"
         )
-        self.token = RESTToken(response.token, response.expires_at_millis)
+        merged_token = self._merge_token_with_catalog_options(response.token)
+        self.token = RESTToken(merged_token, response.expires_at_millis)
 
     def valid_token(self):
         self.try_to_refresh_token()
diff --git a/paimon-python/pypaimon/read/reader/lance_utils.py b/paimon-python/pypaimon/read/reader/lance_utils.py
index 60c7763aa3c2..845a75e7de79 100644
--- a/paimon-python/pypaimon/read/reader/lance_utils.py
+++ b/paimon-python/pypaimon/read/reader/lance_utils.py
@@ -25,7 +25,13 @@
 
 def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[Dict[str, str]]]:
-    """Convert path and extract storage options for Lance format."""
+    """Convert path and extract storage options for Lance format.
+ """ + if hasattr(file_io, 'get_merged_properties'): + properties = file_io.get_merged_properties() + else: + properties = file_io.properties if hasattr(file_io, 'properties') and file_io.properties else None + scheme, _, _ = file_io.parse_location(file_path) storage_options = None file_path_for_lance = file_io.to_filesystem_path(file_path) @@ -37,37 +43,40 @@ def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[D file_path_for_lance = file_path if scheme == 'oss': - storage_options = {} - if hasattr(file_io, 'properties'): - for key, value in file_io.properties.data.items(): + parsed = urlparse(file_path) + bucket = parsed.netloc + path = parsed.path.lstrip('/') + + if properties: + storage_options = {} + for key, value in properties.to_map().items(): if str(key).startswith('fs.'): storage_options[key] = value - parsed = urlparse(file_path) - bucket = parsed.netloc - path = parsed.path.lstrip('/') - - endpoint = file_io.properties.get(OssOptions.OSS_ENDPOINT) + endpoint = properties.get(OssOptions.OSS_ENDPOINT) if endpoint: endpoint_clean = endpoint.replace('http://', '').replace('https://', '') storage_options['endpoint'] = f"https://{bucket}.{endpoint_clean}" - if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_ID): - storage_options['access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID) - storage_options['oss_access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID) - if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET): - storage_options['secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET) - storage_options['oss_secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET) - if file_io.properties.contains(OssOptions.OSS_SECURITY_TOKEN): - storage_options['session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN) - storage_options['oss_session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN) - if file_io.properties.contains(OssOptions.OSS_ENDPOINT): - storage_options['oss_endpoint'] = file_io.properties.get(OssOptions.OSS_ENDPOINT) + if properties.contains(OssOptions.OSS_ACCESS_KEY_ID): + storage_options['access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID) + storage_options['oss_access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID) + if properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET): + storage_options['secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET) + storage_options['oss_secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET) + if properties.contains(OssOptions.OSS_SECURITY_TOKEN): + storage_options['session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN) + storage_options['oss_session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN) + if properties.contains(OssOptions.OSS_ENDPOINT): + storage_options['oss_endpoint'] = properties.get(OssOptions.OSS_ENDPOINT) + storage_options['virtual_hosted_style_request'] = 'true' - + if bucket and path: file_path_for_lance = f"oss://{bucket}/{path}" elif bucket: file_path_for_lance = f"oss://{bucket}" + else: + storage_options = None return file_path_for_lance, storage_options diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py index e0bedeac1bb9..2e2a85bce99f 100755 --- a/paimon-python/pypaimon/tests/rest/rest_server.py +++ b/paimon-python/pypaimon/tests/rest/rest_server.py @@ -24,9 +24,12 @@ from dataclasses import dataclass from http.server import 
BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from urllib.parse import urlparse +if TYPE_CHECKING: + from pypaimon.catalog.rest.rest_token import RESTToken + from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest, CreateTableRequest, RenameTableRequest) from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse, @@ -213,6 +216,7 @@ def __init__(self, data_path: str, auth_provider, config: ConfigResponse, wareho self.table_partitions_store: Dict[str, List] = {} self.no_permission_databases: List[str] = [] self.no_permission_tables: List[str] = [] + self.table_token_store: Dict[str, "RESTToken"] = {} # Initialize mock catalog (simplified) self.data_path = data_path @@ -469,10 +473,12 @@ def _handle_table_resource(self, method: str, path_parts: List[str], # Basic table operations (GET, DELETE, etc.) return self._table_handle(method, data, lookup_identifier) elif len(path_parts) == 4: - # Extended operations (e.g., commit) + # Extended operations (e.g., commit, token) operation = path_parts[3] if operation == "commit": return self._table_commit_handle(method, data, lookup_identifier, branch_part) + elif operation == "token": + return self._table_token_handle(method, lookup_identifier) else: return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404) return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404) @@ -574,6 +580,44 @@ def _table_handle(self, method: str, data: str, identifier: Identifier) -> Tuple return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405) + def _table_token_handle(self, method: str, identifier: Identifier) -> Tuple[str, int]: + if method != "GET": + return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405) + + if identifier.get_full_name() not in self.table_metadata_store: + raise TableNotExistException(identifier) + + from pypaimon.api.api_response import GetTableTokenResponse + + token_key = identifier.get_full_name() + if token_key in self.table_token_store: + rest_token = self.table_token_store[token_key] + response = GetTableTokenResponse( + token=rest_token.token, + expires_at_millis=rest_token.expire_at_millis + ) + else: + default_token = { + "akId": "akId" + str(int(time.time() * 1000)), + "akSecret": "akSecret" + str(int(time.time() * 1000)) + } + response = GetTableTokenResponse( + token=default_token, + expires_at_millis=int(time.time() * 1000) + 3600_000 # 1 hour from now + ) + + return self._mock_response(response, 200) + + def set_table_token(self, identifier: Identifier, token: "RESTToken") -> None: + self.table_token_store[identifier.get_full_name()] = token + + def get_table_token(self, identifier: Identifier) -> Optional["RESTToken"]: + return self.table_token_store.get(identifier.get_full_name()) + + def reset_table_token(self, identifier: Identifier) -> None: + if identifier.get_full_name() in self.table_token_store: + del self.table_token_store[identifier.get_full_name()] + def _table_commit_handle(self, method: str, data: str, identifier: Identifier, branch: str = None) -> Tuple[str, int]: """Handle table commit operations""" diff --git a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py index 07e445b12a4f..3c521f2dfbe9 100644 --- a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py +++ 
b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py @@ -197,6 +197,7 @@ def test_catalog_options_not_modified(self): from pypaimon.api.rest_util import RESTUtil from pypaimon.catalog.rest.rest_token import RESTToken from pyarrow.fs import LocalFileSystem + import time original_catalog_options = Options({ CatalogOptions.URI.key(): "http://test-uri", @@ -238,6 +239,237 @@ def test_catalog_options_not_modified(self): self.assertIn(OssOptions.OSS_ACCESS_KEY_ID.key(), merged_properties) self.assertEqual(merged_properties[OssOptions.OSS_ACCESS_KEY_ID.key()], "token-access-key") + def test_filesystem_cache_token_consistency(self): + import time + from unittest.mock import patch, MagicMock + from pypaimon.catalog.rest.rest_token import RESTToken + + current_time = int(time.time() * 1000) + original_token = RESTToken( + token={"fs.oss.accessKeyId": "old-key", "fs.oss.accessKeySecret": "old-secret"}, + expire_at_millis=current_time + 2 * 3600 * 1000 + ) + refreshed_token = RESTToken( + token={"fs.oss.accessKeyId": "new-key", "fs.oss.accessKeySecret": "new-secret"}, + expire_at_millis=current_time + 3 * 3600 * 1000 + ) + + def mock_try_to_refresh_token(self): + if self.token == original_token: + self.token = refreshed_token + + original_parse_location = FileIO.parse_location + + def mock_parse_location(self_or_location, path=None): + if path is None: + location = self_or_location + scheme, netloc, path_part = original_parse_location(location) + else: + scheme, netloc, path_part = original_parse_location(path) + return ("oss", netloc, path_part) + + def mock_initialize_oss_fs(self, path, properties=None): + return MagicMock() + + with patch.object(RESTTokenFileIO, 'parse_location', mock_parse_location), \ + patch.object(RESTTokenFileIO, 'try_to_refresh_token', mock_try_to_refresh_token), \ + patch.object(RESTTokenFileIO, '_initialize_oss_fs', mock_initialize_oss_fs): + file_io = RESTTokenFileIO( + self.identifier, + self.warehouse_path, + self.catalog_options + ) + + file_io.token = original_token + RESTTokenFileIO._FILESYSTEM_CACHE.clear() + + fs1 = file_io._get_filesystem() + + self.assertEqual( + file_io.token, refreshed_token, + "Token should be refreshed to refreshed_token" + ) + + cached_with_refreshed = RESTTokenFileIO._FILESYSTEM_CACHE.get(refreshed_token) + self.assertIsNotNone( + cached_with_refreshed, + "Filesystem should be cached with refreshed token as key" + ) + self.assertIs( + cached_with_refreshed, fs1, + "Cached filesystem should be the same instance" + ) + + cached_with_original = RESTTokenFileIO._FILESYSTEM_CACHE.get(original_token) + self.assertIsNone( + cached_with_original, + "Filesystem should NOT be cached with original token as key" + ) + + fs2 = file_io._get_filesystem() + self.assertIs(fs1, fs2, "Should return same filesystem from cache") + + cached_after_second_call = RESTTokenFileIO._FILESYSTEM_CACHE.get(refreshed_token) + self.assertIsNotNone( + cached_after_second_call, + "Cache should still contain filesystem with refreshed token key" + ) + self.assertIs( + cached_after_second_call, fs1, + "Cached filesystem should be the same instance" + ) + + def test_filesystem_cache_reuse(self): + import time + from pypaimon.catalog.rest.rest_token import RESTToken + + current_time = int(time.time() * 1000) + token = RESTToken( + token={"fs.oss.accessKeyId": "test-key", "fs.oss.accessKeySecret": "test-secret"}, + expire_at_millis=current_time + 2 * 3600 * 1000 + ) + + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + 
self.identifier, + self.warehouse_path, + self.catalog_options + ) + + file_io.token = token + + RESTTokenFileIO._FILESYSTEM_CACHE.clear() + fs1 = file_io._get_filesystem() + fs2 = file_io._get_filesystem() + + self.assertIs(fs1, fs2, "Same filesystem instance should be returned from cache") + cached = RESTTokenFileIO._FILESYSTEM_CACHE.get(token) + self.assertIsNotNone(cached, "Filesystem should be in cache") + self.assertIs(cached, fs1, "Cached filesystem should be the same instance") + + def test_error_handling_raises_exception(self): + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, + self.warehouse_path, + self.catalog_options + ) + + with patch.object(file_io, '_get_filesystem', side_effect=Exception("Test error")): + with self.assertRaises(Exception) as context: + file_io.exists("test_path") + self.assertIn("Test error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', side_effect=Exception("Test error")): + with self.assertRaises(Exception) as context: + file_io.delete("test_path") + self.assertIn("Test error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', side_effect=Exception("Test error")): + with self.assertRaises(Exception) as context: + file_io.mkdirs("test_path") + self.assertIn("Test error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', side_effect=Exception("Test error")): + with self.assertRaises(Exception) as context: + file_io.rename("src", "dst") + self.assertIn("Test error", str(context.exception)) + + def test_error_handling_with_filesystem_errors(self): + import pyarrow.fs as pafs + from unittest.mock import MagicMock + + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, + self.warehouse_path, + self.catalog_options + ) + + mock_fs_exists = MagicMock() + mock_fs_exists.get_file_info.side_effect = Exception("Filesystem error") + + mock_fs_delete = MagicMock() + mock_fs_delete.get_file_info.return_value = [pafs.FileInfo("test_path", pafs.FileType.File)] + mock_fs_delete.delete_file.side_effect = Exception("Delete error") + + mock_fs_mkdirs = MagicMock() + mock_fs_mkdirs.create_dir.side_effect = Exception("Create dir error") + + mock_fs_rename = MagicMock() + mock_fs_rename.get_file_info.return_value = [pafs.FileInfo("dst_parent", pafs.FileType.NotFound)] + mock_fs_rename.move.side_effect = Exception("Move error") + + with patch.object(file_io, '_get_filesystem', return_value=mock_fs_exists): + with self.assertRaises(Exception) as context: + file_io.exists("test_path") + self.assertIn("Filesystem error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', return_value=mock_fs_delete): + with self.assertRaises(Exception) as context: + file_io.delete("test_path") + self.assertIn("Delete error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', return_value=mock_fs_mkdirs): + with self.assertRaises(Exception) as context: + file_io.mkdirs("test_path") + self.assertIn("Create dir error", str(context.exception)) + + with patch.object(file_io, '_get_filesystem', return_value=mock_fs_rename): + with self.assertRaises(Exception) as context: + file_io.rename("src", "dst") + self.assertIn("Move error", str(context.exception)) + + def test_cache_lazy_expiration(self): + import time + from pypaimon.catalog.rest.rest_token import RESTToken + from unittest.mock import patch + + current_time = int(time.time() * 1000) + token = RESTToken( + 
token={"fs.oss.accessKeyId": "test-key", "fs.oss.accessKeySecret": "test-secret"}, + expire_at_millis=current_time + 2 * 3600 * 1000 + ) + + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, + self.warehouse_path, + self.catalog_options + ) + + file_io.token = token + RESTTokenFileIO._FILESYSTEM_CACHE.clear() + + fs1 = file_io._get_filesystem() + self.assertIsNotNone(fs1) + self.assertEqual(len(RESTTokenFileIO._FILESYSTEM_CACHE), 1) + + fs2 = file_io._get_filesystem() + self.assertIs(fs1, fs2) + self.assertEqual(len(RESTTokenFileIO._FILESYSTEM_CACHE), 1) + + cached_fs = RESTTokenFileIO._FILESYSTEM_CACHE.get(token) + self.assertIsNotNone(cached_fs) + self.assertIs(cached_fs, fs1) + + def test_token_none_handling(self): + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, + self.warehouse_path, + self.catalog_options + ) + + file_io.token = None + + filesystem = file_io._get_filesystem() + self.assertIsNotNone(filesystem, "Should return default filesystem when token is None") + self.assertIs( + filesystem, file_io.filesystem, + "Should return default filesystem instance" + ) + if __name__ == '__main__': unittest.main()
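Reviewer note: the patch layers two caches — a per-identifier, double-check-locked data-token cache (try_to_refresh_token) and a token-keyed TTLCache of pyarrow filesystems (_get_filesystem). The sketch below models that pattern in isolation so the locking order is easy to audit. It is illustrative only, not pypaimon code: TokenCachingIO, _fetch_token (a stand-in for RESTApi.load_table_token), and the object() placeholder filesystem are all hypothetical.

import threading
import time
from typing import Dict, Optional, Tuple

from cachetools import TTLCache

# stand-in for RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
SAFE_MILLIS = 60_000

# (credentials, expire_at_millis); frozenset keeps the token hashable,
# so it can serve as a TTLCache key the way RESTToken does above
Token = Tuple[frozenset, int]


class TokenCachingIO:
    _FS_CACHE: TTLCache = TTLCache(maxsize=1000, ttl=36000)  # filesystems keyed by token
    _CACHE_LOCK = threading.Lock()                           # TTLCache is not thread-safe
    _TOKEN_CACHE: Dict[str, Token] = {}                      # tokens shared across instances
    _TOKEN_LOCKS: Dict[str, threading.Lock] = {}             # one refresh lock per identifier
    _LOCKS_LOCK = threading.Lock()

    def __init__(self, identifier: str):
        self.identifier = identifier
        self.token: Optional[Token] = None

    def _expired(self, token: Optional[Token]) -> bool:
        return token is None or token[1] - int(time.time() * 1000) < SAFE_MILLIS

    def _refresh_lock(self) -> threading.Lock:
        with self._LOCKS_LOCK:
            return self._TOKEN_LOCKS.setdefault(self.identifier, threading.Lock())

    def _fetch_token(self) -> Token:
        # hypothetical stand-in for RESTApi.load_table_token(identifier)
        creds = frozenset({("akId", f"ak-{time.time_ns()}")})
        return (creds, int(time.time() * 1000) + 3_600_000)

    def ensure_token(self) -> Token:
        # fast path 1: instance token still valid
        if not self._expired(self.token):
            return self.token
        # fast path 2: another instance already refreshed this identifier
        with self._LOCKS_LOCK:
            cached = self._TOKEN_CACHE.get(self.identifier)
        if not self._expired(cached):
            self.token = cached
            return cached
        # slow path: double-checked refresh under the per-identifier lock
        with self._refresh_lock():
            with self._LOCKS_LOCK:
                cached = self._TOKEN_CACHE.get(self.identifier)
            if self._expired(cached):
                cached = self._fetch_token()
                with self._LOCKS_LOCK:
                    self._TOKEN_CACHE[self.identifier] = cached
            self.token = cached
            return cached

    def filesystem(self):
        token = self.ensure_token()
        with self._CACHE_LOCK:       # guard both read and write of the TTLCache
            fs = self._FS_CACHE.get(token)
            if fs is None:
                fs = object()        # stand-in for a pyarrow FileSystem
                self._FS_CACHE[token] = fs
            return fs


if __name__ == "__main__":
    io1, io2 = TokenCachingIO("db.tbl"), TokenCachingIO("db.tbl")
    assert io1.filesystem() is io2.filesystem()  # same token -> same cached filesystem

Keying the filesystem cache by the token rather than by the identifier is what test_filesystem_cache_token_consistency asserts: once a token rotates, the entry stored under the old token is never consulted again and simply ages out of the TTLCache.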