Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/runpod_flash/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import difflib
import inspect
import logging
import os
from functools import wraps
from typing import Any, List, Optional

Expand Down Expand Up @@ -95,6 +96,19 @@ async def _resolve_deployed_endpoint_id(func_name: str) -> Optional[str]:
return None


def _reject_unknown_kwargs(extra: dict[str, Any], known: set[str]) -> None:
"""Raise TypeError for unknown kwargs with 'did you mean?' suggestions."""
names = sorted(extra)
parts: list[str] = []
for name in names:
close = difflib.get_close_matches(name, sorted(known), n=1, cutoff=0.6)
hint = f" (Did you mean '{close[0]}'?)" if close else ""
parts.append(f"'{name}'{hint}")

noun = "argument" if len(names) == 1 else "arguments"
raise TypeError(f"remote() got unknown keyword {noun}: {', '.join(parts)}")


def remote(
resource_config: ServerlessResource,
dependencies: Optional[List[str]] = None,
Expand All @@ -104,6 +118,8 @@ def remote(
method: Optional[str] = None,
path: Optional[str] = None,
_internal: bool = False,
# **extra is retained (rather than removing it and relying on Python's own
# TypeError) so we can provide "did you mean?" suggestions for typos.
**extra,
):
"""
Expand Down Expand Up @@ -142,7 +158,6 @@ def remote(
Ignored for queue-based endpoints. Defaults to None.
_internal (bool, optional): suppress deprecation warning when called from
Endpoint internals. not part of the public API. Defaults to False.
extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict.

Returns:
Callable: A decorator that wraps the target function, enabling remote execution with the specified
Expand Down Expand Up @@ -180,6 +195,8 @@ async def my_test_function(data):
pass
```
"""
if extra:
_reject_unknown_kwargs(extra, _REMOTE_KNOWN_KWARGS)

if not _internal:
import warnings
Expand Down Expand Up @@ -268,7 +285,6 @@ def decorator(func_or_class):
dependencies,
system_dependencies,
accelerate_downloads,
extra,
)
wrapped_class.__remote_config__ = routing_config
return wrapped_class
Expand All @@ -287,7 +303,7 @@ async def wrapper(*args, **kwargs):
resource_config
)

stub = stub_resource(remote_resource, **extra)
stub = stub_resource(remote_resource)
return await stub(
func_or_class,
dependencies,
Expand All @@ -302,3 +318,11 @@ async def wrapper(*args, **kwargs):
return wrapper

return decorator


# Derived from remote()'s signature so it stays in sync automatically.
_REMOTE_KNOWN_KWARGS = {
p.name
for p in inspect.signature(remote).parameters.values()
if p.kind != inspect.Parameter.VAR_KEYWORD
}
4 changes: 1 addition & 3 deletions src/runpod_flash/execute_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def create_remote_class(
dependencies: Optional[List[str]],
system_dependencies: Optional[List[str]],
accelerate_downloads: bool,
extra: dict,
):
"""
Create a remote class wrapper.
Expand All @@ -211,7 +210,6 @@ def __init__(self, *args, **kwargs):
self._dependencies = dependencies or []
self._system_dependencies = system_dependencies or []
self._accelerate_downloads = accelerate_downloads
self._extra = extra
self._constructor_args = args
self._constructor_kwargs = kwargs
self._instance_id = (
Expand All @@ -235,7 +233,7 @@ async def _ensure_initialized(self):
remote_resource = await resource_manager.get_or_deploy_resource(
self._resource_config
)
self._stub = stub_resource(remote_resource, **self._extra)
self._stub = stub_resource(remote_resource)

# Create the remote instance by calling a method (which will trigger instance creation)
# We'll do this on first method call
Expand Down
26 changes: 13 additions & 13 deletions src/runpod_flash/stubs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@


@singledispatch
def stub_resource(resource, **extra):
def stub_resource(resource):
async def fallback(*args, **kwargs):
return {"error": f"Cannot stub {resource.__class__.__name__}."}

return fallback


def _create_live_serverless_stub(resource, **extra):
def _create_live_serverless_stub(resource):
"""Create a live serverless stub for both LiveServerless and CpuLiveServerless."""
stub = LiveServerlessStub(resource)

Expand Down Expand Up @@ -105,17 +105,17 @@ async def wrapped_class_method(request):


@stub_resource.register(LiveServerless)
def _(resource, **extra):
return _create_live_serverless_stub(resource, **extra)
def _(resource):
return _create_live_serverless_stub(resource)


@stub_resource.register(CpuLiveServerless)
def _(resource, **extra):
return _create_live_serverless_stub(resource, **extra)
def _(resource):
return _create_live_serverless_stub(resource)


@stub_resource.register(ServerlessEndpoint)
def _(resource, **extra):
def _(resource):
async def stubbed_resource(
func,
dependencies,
Expand All @@ -132,14 +132,14 @@ async def stubbed_resource(

stub = ServerlessEndpointStub(resource)
payload = stub.prepare_payload(func, *args, **kwargs)
response = await stub.execute(payload, sync=extra.get("sync", False))
response = await stub.execute(payload, sync=False)
return stub.handle_response(response)

return stubbed_resource


@stub_resource.register(CpuServerlessEndpoint)
def _(resource, **extra):
def _(resource):
async def stubbed_resource(
func,
dependencies,
Expand All @@ -156,14 +156,14 @@ async def stubbed_resource(

stub = ServerlessEndpointStub(resource)
payload = stub.prepare_payload(func, *args, **kwargs)
response = await stub.execute(payload, sync=extra.get("sync", False))
response = await stub.execute(payload, sync=False)
return stub.handle_response(response)

return stubbed_resource


@stub_resource.register(LoadBalancerSlsResource)
def _(resource, **extra):
def _(resource):
"""Create stub for LoadBalancerSlsResource (HTTP-based execution)."""
stub = LoadBalancerSlsStub(resource)

Expand All @@ -188,7 +188,7 @@ async def stubbed_resource(


@stub_resource.register(LiveLoadBalancer)
def _(resource, **extra):
def _(resource):
"""Create stub for LiveLoadBalancer (HTTP-based execution, local testing)."""
stub = LoadBalancerSlsStub(resource)

Expand All @@ -213,7 +213,7 @@ async def stubbed_resource(


@stub_resource.register(CpuLiveLoadBalancer)
def _(resource, **extra):
def _(resource):
"""Create stub for CpuLiveLoadBalancer (HTTP-based execution, local testing)."""
stub = LoadBalancerSlsStub(resource)

Expand Down
21 changes: 7 additions & 14 deletions tests/integration/test_class_execution_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ async def test_remote_decorator_on_class(self):
resource_config=self.mock_resource_config,
dependencies=self.dependencies,
system_dependencies=self.system_dependencies,
timeout=60,
)
class RemoteCalculator:
def __init__(self, initial_value=0):
Expand Down Expand Up @@ -206,7 +205,7 @@ def get_state(self):
}

RemoteCounter = create_remote_class(
StatefulCounter, self.mock_resource_config, [], [], True, {}
StatefulCounter, self.mock_resource_config, [], [], True
)

counter = RemoteCounter(5)
Expand Down Expand Up @@ -278,7 +277,7 @@ def get_completed_count(self):
return self.tasks_completed

RemoteWorker = create_remote_class(
AsyncWorker, self.mock_resource_config, [], [], True, {}
AsyncWorker, self.mock_resource_config, [], [], True
)

worker = RemoteWorker()
Expand Down Expand Up @@ -378,7 +377,6 @@ def process_with_config(self, input_data):
["scikit-learn", "pandas"],
[], # system_dependencies
True, # accelerate_downloads
{}, # extra
)

model = RemoteModel(
Expand Down Expand Up @@ -479,7 +477,7 @@ def get_service_info(self):
api_keys = ["key1", "key2", "key3"]

RemoteDataService = create_remote_class(
DataService, self.mock_resource_config, ["psycopg2"], [], True, {}
DataService, self.mock_resource_config, ["psycopg2"], [], True
)

service = RemoteDataService(db_conn, cache_conf, api_keys=api_keys)
Expand Down Expand Up @@ -550,7 +548,7 @@ def safe_method(self):
return "This always works"

RemoteErrorProneClass = create_remote_class(
ErrorProneClass, self.mock_resource_config, [], [], True, {}
ErrorProneClass, self.mock_resource_config, [], [], True
)

error_instance = RemoteErrorProneClass(should_fail=True)
Expand Down Expand Up @@ -586,7 +584,7 @@ def simple_method(self):
return "hello"

RemoteSimpleClass = create_remote_class(
SimpleClass, self.mock_resource_config, [], [], True, {}
SimpleClass, self.mock_resource_config, [], [], True
)

instance = RemoteSimpleClass()
Expand Down Expand Up @@ -622,7 +620,7 @@ def process_file(self):

with tempfile.NamedTemporaryFile() as temp_file:
RemoteUnserializableClass = create_remote_class(
UnserializableClass, self.mock_resource_config, [], [], True, {}
UnserializableClass, self.mock_resource_config, [], [], True
)

# This should not fail during initialization (lazy serialization)
Expand Down Expand Up @@ -674,7 +672,6 @@ def slow_method(self, duration):
[],
[],
True,
{"timeout": 5}, # 5 second timeout
)

instance = RemoteSlowClass()
Expand Down Expand Up @@ -709,17 +706,14 @@ def test_invalid_class_type_error(self):
[],
[],
True,
{},
)

# Test with function instead of class
def not_a_class():
pass

with pytest.raises(TypeError, match="Expected a class"):
create_remote_class(
not_a_class, self.mock_resource_config, [], [], True, {}
)
create_remote_class(not_a_class, self.mock_resource_config, [], [], True)

# Note: Testing class without __name__ is not practically possible
# since Python classes always have __name__ attribute
Expand All @@ -741,7 +735,6 @@ def use_dependency(self):
["nonexistent-package==999.999.999"], # Invalid package
[],
True,
{},
)

instance = RemoteDependentClass()
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/test_class_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(self, value):
self.value = value

RemoteCacheTestClass = create_remote_class(
CacheTestClass, self.mock_resource_config, [], [], True, {}
CacheTestClass, self.mock_resource_config, [], [], True
)

# First instance - should be cache miss
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, x, y=None):
self.y = y

RemoteMultiArgClass = create_remote_class(
MultiArgClass, self.mock_resource_config, [], [], True, {}
MultiArgClass, self.mock_resource_config, [], [], True
)

# Different args should create different cache entries
Expand All @@ -198,7 +198,7 @@ def __init__(self, file_handle, name="default"):
self.name = name

RemoteFileHandlerClass = create_remote_class(
FileHandlerClass, self.mock_resource_config, [], [], True, {}
FileHandlerClass, self.mock_resource_config, [], [], True
)

with tempfile.NamedTemporaryFile() as temp_file:
Expand All @@ -224,7 +224,7 @@ def __init__(self, value):
self.value = value

RemoteOptimizationTestClass = create_remote_class(
OptimizationTestClass, self.mock_resource_config, [], [], True, {}
OptimizationTestClass, self.mock_resource_config, [], [], True
)

with patch(
Expand Down Expand Up @@ -252,7 +252,7 @@ def get_value(self):
return self.value

RemoteConsistencyTestClass = create_remote_class(
ConsistencyTestClass, self.mock_resource_config, [], [], True, {}
ConsistencyTestClass, self.mock_resource_config, [], [], True
)

instance1 = RemoteConsistencyTestClass(1)
Expand All @@ -275,7 +275,7 @@ def __init__(self, file_handle):
self.file_handle = file_handle

RemoteUUIDFallbackClass = create_remote_class(
UUIDFallbackClass, self.mock_resource_config, [], [], True, {}
UUIDFallbackClass, self.mock_resource_config, [], [], True
)

with (
Expand All @@ -301,7 +301,7 @@ def __init__(self, value):
self.value = value

RemoteMemoryTestClass = create_remote_class(
MemoryTestClass, self.mock_resource_config, [], [], True, {}
MemoryTestClass, self.mock_resource_config, [], [], True
)

# Create many instances with same args - should only create one cache entry
Expand All @@ -325,10 +325,10 @@ def __init__(self, value):
self.value = value

RemoteClassTypeA = create_remote_class(
ClassTypeA, self.mock_resource_config, [], [], True, {}
ClassTypeA, self.mock_resource_config, [], [], True
)
RemoteClassTypeB = create_remote_class(
ClassTypeB, self.mock_resource_config, [], [], True, {}
ClassTypeB, self.mock_resource_config, [], [], True
)

instanceA = RemoteClassTypeA(42)
Expand Down Expand Up @@ -360,7 +360,7 @@ def __init__(self, value, config=None):
)

RemoteStructureTestClass = create_remote_class(
StructureTestClass, resource_config, [], [], True, {}
StructureTestClass, resource_config, [], [], True
)

instance = RemoteStructureTestClass(42, config={"key": "value"})
Expand Down Expand Up @@ -403,7 +403,7 @@ def __init__(self, data):
)

RemoteSerializationTestClass = create_remote_class(
SerializationTestClass, resource_config, [], [], True, {}
SerializationTestClass, resource_config, [], [], True
)

test_data = {"test": [1, 2, 3]}
Expand Down
Loading
Loading