From b316035908946beb8b98785dd3e3d47dd5d8b456 Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:40:50 -0700 Subject: [PATCH 1/3] fix: enforce payload size limit and timeout on deserialization --- src/runpod_flash/runtime/config.py | 7 +- src/runpod_flash/runtime/exceptions.py | 12 +++ src/runpod_flash/runtime/serialization.py | 68 +++++++++++++-- tests/unit/runtime/test_serialization.py | 102 ++++++++++++++++++++-- 4 files changed, 173 insertions(+), 16 deletions(-) diff --git a/src/runpod_flash/runtime/config.py b/src/runpod_flash/runtime/config.py index 974bb5d5..725ffd01 100644 --- a/src/runpod_flash/runtime/config.py +++ b/src/runpod_flash/runtime/config.py @@ -9,4 +9,9 @@ DEFAULT_CACHE_TTL = 300 # seconds # Serialization limits -MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB +# max size of a single base64-encoded argument before decoding. +# base64 expands data by ~33%, so 10 MB encoded is ~7.5 MB decoded. +MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10 MB + +# max wall-clock seconds for a single cloudpickle.loads() call +DESERIALIZE_TIMEOUT_SECONDS = 30 diff --git a/src/runpod_flash/runtime/exceptions.py b/src/runpod_flash/runtime/exceptions.py index 90821d4d..520f129f 100644 --- a/src/runpod_flash/runtime/exceptions.py +++ b/src/runpod_flash/runtime/exceptions.py @@ -19,6 +19,18 @@ class SerializationError(FlashRuntimeError): pass +class PayloadTooLargeError(SerializationError): + """Raised when a serialized argument exceeds MAX_PAYLOAD_SIZE.""" + + pass + + +class DeserializeTimeoutError(SerializationError): + """Raised when cloudpickle.loads() exceeds DESERIALIZE_TIMEOUT_SECONDS.""" + + pass + + class GraphQLError(FlashRuntimeError): """Base exception for GraphQL-related errors.""" diff --git a/src/runpod_flash/runtime/serialization.py b/src/runpod_flash/runtime/serialization.py index c063feb3..e687acdb 100644 --- a/src/runpod_flash/runtime/serialization.py +++ b/src/runpod_flash/runtime/serialization.py @@ 
def _check_payload_size(data: str) -> None:
    """Reject a base64-encoded payload that exceeds MAX_PAYLOAD_SIZE.

    Only the length of the encoded string is inspected; nothing is decoded
    here.

    Args:
        data: Base64-encoded payload whose length is compared against
            MAX_PAYLOAD_SIZE.

    Raises:
        PayloadTooLargeError: If len(data) > MAX_PAYLOAD_SIZE.
    """
    size = len(data)
    if size > MAX_PAYLOAD_SIZE:
        # Report both figures in MB so the message is readable in logs.
        limit_mb = MAX_PAYLOAD_SIZE / (1024 * 1024)
        actual_mb = size / (1024 * 1024)
        raise PayloadTooLargeError(
            f"Payload size {actual_mb:.1f} MB exceeds limit of {limit_mb:.1f} MB"
        )


def _unpickle_with_timeout(data: bytes, timeout: int) -> Any:
    """Run cloudpickle.loads in a worker thread with a wall-clock timeout.

    The executor is deliberately NOT used as a context manager: leaving a
    ``with ThreadPoolExecutor(...)`` block calls ``shutdown(wait=True)``,
    which would block until the worker finishes and therefore defeat the
    timeout (the caller would still wait for the slow unpickle to end).

    Args:
        data: Pickled bytes to deserialize.
        timeout: Maximum seconds to wait for the result.

    Returns:
        Deserialized Python object.

    Raises:
        DeserializeTimeoutError: If deserialization exceeds timeout.
    """
    # NOTE(review): cloudpickle.loads can execute code embedded in the
    # payload; the timeout bounds how long the caller waits, it is not a
    # sandbox. Confirm payloads only come from trusted peers.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(cloudpickle.loads, data)
    try:
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError as exc:
        # cancel() only helps if the task has not started yet; a running
        # thread cannot be interrupted and keeps going in the background.
        future.cancel()
        raise DeserializeTimeoutError(
            f"Deserialization timed out after {timeout}s"
        ) from exc
    finally:
        # Release the executor without joining the worker so the caller is
        # unblocked as soon as the result (or the timeout) is known.
        # NOTE(review): an abandoned worker may still delay interpreter
        # exit; use a subprocess if hard termination is required.
        pool.shutdown(wait=False)
Raises: - SerializationError: If deserialization fails. + PayloadTooLargeError: If the encoded payload exceeds MAX_PAYLOAD_SIZE. + DeserializeTimeoutError: If cloudpickle.loads exceeds DESERIALIZE_TIMEOUT_SECONDS. + SerializationError: If deserialization fails for any other reason. """ try: - return cloudpickle.loads(base64.b64decode(arg_b64)) + _check_payload_size(arg_b64) + raw = base64.b64decode(arg_b64) + return _unpickle_with_timeout(raw, DESERIALIZE_TIMEOUT_SECONDS) + except (PayloadTooLargeError, DeserializeTimeoutError): + raise except Exception as e: raise SerializationError(f"Failed to deserialize argument: {e}") from e @@ -94,11 +144,13 @@ def deserialize_args(args_b64: List[str]) -> List[Any]: List of deserialized arguments. Raises: - SerializationError: If deserialization fails. + PayloadTooLargeError: If any encoded argument exceeds MAX_PAYLOAD_SIZE. + DeserializeTimeoutError: If any cloudpickle.loads exceeds the timeout. + SerializationError: If deserialization fails for any other reason. """ try: return [deserialize_arg(arg) for arg in args_b64] - except SerializationError: + except (PayloadTooLargeError, DeserializeTimeoutError, SerializationError): raise except Exception as e: raise SerializationError(f"Failed to deserialize args: {e}") from e @@ -114,11 +166,13 @@ def deserialize_kwargs(kwargs_b64: Dict[str, str]) -> Dict[str, Any]: Dictionary with deserialized values. Raises: - SerializationError: If deserialization fails. + PayloadTooLargeError: If any encoded value exceeds MAX_PAYLOAD_SIZE. + DeserializeTimeoutError: If any cloudpickle.loads exceeds the timeout. + SerializationError: If deserialization fails for any other reason. 
""" try: return {k: deserialize_arg(v) for k, v in kwargs_b64.items()} - except SerializationError: + except (PayloadTooLargeError, DeserializeTimeoutError, SerializationError): raise except Exception as e: raise SerializationError(f"Failed to deserialize kwargs: {e}") from e diff --git a/tests/unit/runtime/test_serialization.py b/tests/unit/runtime/test_serialization.py index 43246c79..d33b427a 100644 --- a/tests/unit/runtime/test_serialization.py +++ b/tests/unit/runtime/test_serialization.py @@ -1,11 +1,20 @@ """Tests for serialization utilities.""" +import time from unittest.mock import patch +import cloudpickle import pytest -from runpod_flash.runtime.exceptions import SerializationError +from runpod_flash.runtime.config import MAX_PAYLOAD_SIZE +from runpod_flash.runtime.exceptions import ( + DeserializeTimeoutError, + PayloadTooLargeError, + SerializationError, +) from runpod_flash.runtime.serialization import ( + _check_payload_size, + _unpickle_with_timeout, deserialize_arg, deserialize_args, deserialize_kwargs, @@ -22,7 +31,6 @@ def test_serialize_simple_arg(self): """Test serializing a simple argument.""" result = serialize_arg(42) assert isinstance(result, str) - # Verify it's valid base64 import base64 decoded = base64.b64decode(result) @@ -106,22 +114,88 @@ def test_serialize_kwargs_unexpected_error(self): serialize_kwargs({"key": 42}) +class TestCheckPayloadSize: + """Test _check_payload_size function.""" + + def test_within_limit(self): + """Payloads within MAX_PAYLOAD_SIZE pass silently.""" + _check_payload_size("a" * 100) + + def test_at_limit(self): + """Payload exactly at MAX_PAYLOAD_SIZE passes.""" + _check_payload_size("a" * MAX_PAYLOAD_SIZE) + + def test_over_limit(self): + """Payload exceeding MAX_PAYLOAD_SIZE raises PayloadTooLargeError.""" + with pytest.raises(PayloadTooLargeError, match="exceeds limit"): + _check_payload_size("a" * (MAX_PAYLOAD_SIZE + 1)) + + def test_error_message_includes_sizes(self): + """Error message reports actual 
and limit sizes in MB.""" + oversized = "a" * (MAX_PAYLOAD_SIZE + 1) + with pytest.raises(PayloadTooLargeError) as exc_info: + _check_payload_size(oversized) + msg = str(exc_info.value) + assert "MB" in msg + assert "10.0 MB" in msg + + +class TestUnpickleWithTimeout: + """Test _unpickle_with_timeout function.""" + + def test_normal_deserialization(self): + """Small payloads deserialize within the timeout.""" + data = cloudpickle.dumps(42) + assert _unpickle_with_timeout(data, 5) == 42 + + def test_timeout_raises(self): + """A slow unpickle triggers DeserializeTimeoutError.""" + + def slow_loads(data): + time.sleep(5) + return None + + with patch("runpod_flash.runtime.serialization.cloudpickle") as mock_cp: + mock_cp.loads = slow_loads + with pytest.raises(DeserializeTimeoutError, match="timed out"): + _unpickle_with_timeout(b"fake", 1) + + class TestDeserializeArg: """Test deserialize_arg function.""" - def test_deserialize_simple_arg(self): - """Test deserializing a simple argument.""" - # First serialize something + def test_roundtrip(self): + """Serialize then deserialize returns the original value.""" serialized = serialize_arg(42) - # Then deserialize it result = deserialize_arg(serialized) assert result == 42 - def test_deserialize_raises_on_invalid_base64(self): - """Test deserialize_arg raises on invalid base64.""" + def test_raises_on_invalid_base64(self): + """Invalid base64 raises SerializationError.""" with pytest.raises(SerializationError, match="Failed to deserialize argument"): deserialize_arg("not-valid-base64!!!") + def test_rejects_oversized_payload(self): + """Payload larger than MAX_PAYLOAD_SIZE raises PayloadTooLargeError.""" + oversized = "A" * (MAX_PAYLOAD_SIZE + 1) + with pytest.raises(PayloadTooLargeError): + deserialize_arg(oversized) + + def test_timeout_on_slow_unpickle(self): + """Slow cloudpickle.loads raises DeserializeTimeoutError.""" + valid_b64 = serialize_arg("hello") + + def slow_loads(data): + time.sleep(5) + + with 
patch("runpod_flash.runtime.serialization.cloudpickle") as mock_cp: + mock_cp.loads = slow_loads + with patch( + "runpod_flash.runtime.serialization.DESERIALIZE_TIMEOUT_SECONDS", 1 + ): + with pytest.raises(DeserializeTimeoutError): + deserialize_arg(valid_b64) + class TestDeserializeArgs: """Test deserialize_args function.""" @@ -137,6 +211,12 @@ def test_deserialize_empty_args(self): result = deserialize_args([]) assert result == [] + def test_propagates_payload_too_large(self): + """PayloadTooLargeError from a single arg propagates.""" + oversized = "A" * (MAX_PAYLOAD_SIZE + 1) + with pytest.raises(PayloadTooLargeError): + deserialize_args([oversized]) + def test_deserialize_args_propagates_serialization_error(self): """Test deserialize_args propagates SerializationError.""" with patch( @@ -170,6 +250,12 @@ def test_deserialize_empty_kwargs(self): result = deserialize_kwargs({}) assert result == {} + def test_propagates_payload_too_large(self): + """PayloadTooLargeError from a single kwarg value propagates.""" + oversized = "A" * (MAX_PAYLOAD_SIZE + 1) + with pytest.raises(PayloadTooLargeError): + deserialize_kwargs({"big": oversized}) + def test_deserialize_kwargs_propagates_serialization_error(self): """Test deserialize_kwargs propagates SerializationError.""" with patch( From 5aef72f0bd48741a61d68bf48a04c6e75812edea Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:42:31 -0700 Subject: [PATCH 2/3] fix: ruff formatting --- src/runpod_flash/runtime/serialization.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/runpod_flash/runtime/serialization.py b/src/runpod_flash/runtime/serialization.py index e687acdb..1cbae431 100644 --- a/src/runpod_flash/runtime/serialization.py +++ b/src/runpod_flash/runtime/serialization.py @@ -7,7 +7,11 @@ import cloudpickle from .config import DESERIALIZE_TIMEOUT_SECONDS, MAX_PAYLOAD_SIZE -from .exceptions import DeserializeTimeoutError, 
PayloadTooLargeError, SerializationError +from .exceptions import ( + DeserializeTimeoutError, + PayloadTooLargeError, + SerializationError, +) def serialize_arg(arg: Any) -> str: @@ -102,9 +106,7 @@ def _unpickle_with_timeout(data: bytes, timeout: int) -> Any: return future.result(timeout=timeout) except concurrent.futures.TimeoutError: future.cancel() - raise DeserializeTimeoutError( - f"Deserialization timed out after {timeout}s" - ) + raise DeserializeTimeoutError(f"Deserialization timed out after {timeout}s") def deserialize_arg(arg_b64: str) -> Any: From c679c46898723558bc3fe92ccc5cbe3f29bb089d Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:47:22 -0700 Subject: [PATCH 3/3] fix: update large payload regression test for size limit --- tests/unit/test_regressions.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_regressions.py b/tests/unit/test_regressions.py index ea272f79..282ad420 100644 --- a/tests/unit/test_regressions.py +++ b/tests/unit/test_regressions.py @@ -207,20 +207,33 @@ async def test_quota_error_propagates_cleanly(self): # REG-007: Large base64 payload (>10MB) handling # --------------------------------------------------------------------------- class TestREG007LargePayload: - """Large serialized payloads should serialize without silent corruption.""" + """Large serialized payloads should round-trip without silent corruption, + and payloads exceeding MAX_PAYLOAD_SIZE are rejected before decoding.""" def test_large_payload_roundtrip(self): - """10MB+ payload survives serialize → deserialize without corruption.""" + """Payload near the size limit survives serialize -> deserialize without corruption.""" from runpod_flash.runtime.serialization import deserialize_arg, serialize_arg - # Create a ~10MB payload - large_data = b"x" * (10 * 1024 * 1024) + # ~7 MB raw -> ~9.3 MB base64, stays under the 10 MB limit + large_data = b"x" * (7 * 
1024 * 1024) serialized = serialize_arg(large_data) restored = deserialize_arg(serialized) assert restored == large_data assert len(restored) == len(large_data) + def test_oversized_payload_rejected(self): + """Payloads exceeding MAX_PAYLOAD_SIZE raise PayloadTooLargeError.""" + from runpod_flash.runtime.exceptions import PayloadTooLargeError + from runpod_flash.runtime.serialization import deserialize_arg, serialize_arg + + # 10 MB raw -> ~13.3 MB base64, exceeds the 10 MB limit + large_data = b"x" * (10 * 1024 * 1024) + serialized = serialize_arg(large_data) + + with pytest.raises(PayloadTooLargeError): + deserialize_arg(serialized) + # --------------------------------------------------------------------------- # REG-008: flash env delete works in non-interactive mode (no TTY)