Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/runpod_flash/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@
DEFAULT_CACHE_TTL = 300  # seconds

# Serialization limits
# Max size of a single base64-encoded argument before decoding.
# base64 expands data by ~33%, so 10 MB encoded is ~7.5 MB decoded.
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024  # 10 MB

# Max wall-clock seconds for a single cloudpickle.loads() call.
DESERIALIZE_TIMEOUT_SECONDS = 30
12 changes: 12 additions & 0 deletions src/runpod_flash/runtime/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ class SerializationError(FlashRuntimeError):
pass


class PayloadTooLargeError(SerializationError):
    """Raised when a serialized argument exceeds MAX_PAYLOAD_SIZE."""
    # The docstring is the class body; a trailing `pass` is redundant.


class DeserializeTimeoutError(SerializationError):
    """Raised when cloudpickle.loads() exceeds DESERIALIZE_TIMEOUT_SECONDS."""
    # The docstring is the class body; a trailing `pass` is redundant.


class GraphQLError(FlashRuntimeError):
"""Base exception for GraphQL-related errors."""

Expand Down
70 changes: 63 additions & 7 deletions src/runpod_flash/runtime/serialization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Shared serialization utilities for cloudpickle + base64 encoding."""

import base64
import concurrent.futures
from typing import Any, Dict, List

import cloudpickle

from .exceptions import SerializationError
from .config import DESERIALIZE_TIMEOUT_SECONDS, MAX_PAYLOAD_SIZE
from .exceptions import (
DeserializeTimeoutError,
PayloadTooLargeError,
SerializationError,
)


def serialize_arg(arg: Any) -> str:
Expand Down Expand Up @@ -66,20 +72,66 @@ def serialize_kwargs(kwargs: dict) -> Dict[str, str]:
raise SerializationError(f"Failed to serialize kwargs: {e}") from e


def _check_payload_size(data: str) -> None:
    """Validate that a base64-encoded payload fits within MAX_PAYLOAD_SIZE.

    Raises:
        PayloadTooLargeError: If len(data) > MAX_PAYLOAD_SIZE.
    """
    encoded_len = len(data)
    if encoded_len <= MAX_PAYLOAD_SIZE:
        return
    mb = 1024 * 1024
    raise PayloadTooLargeError(
        f"Payload size {encoded_len / mb:.1f} MB exceeds limit of {MAX_PAYLOAD_SIZE / mb:.1f} MB"
    )


def _unpickle_with_timeout(data: bytes, timeout: int) -> Any:
    """Run cloudpickle.loads in a worker thread with a wall-clock timeout.

    A running worker thread cannot be killed, so on timeout it is
    abandoned: the pool is shut down with ``wait=False`` so the caller
    regains control after ``timeout`` seconds. The previous ``with``
    form called ``shutdown(wait=True)`` on exit, which blocked until the
    slow unpickle finished and defeated the timeout entirely (and
    ``future.cancel()`` is a no-op on an already-running future).

    Args:
        data: Pickled bytes to deserialize.
        timeout: Maximum seconds to allow.

    Returns:
        Deserialized Python object.

    Raises:
        DeserializeTimeoutError: If deserialization exceeds timeout.
    """
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = pool.submit(cloudpickle.loads, data)
        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError as e:
            raise DeserializeTimeoutError(
                f"Deserialization timed out after {timeout}s"
            ) from e
    finally:
        # Do not wait for the worker: the abandoned thread finishes (or
        # leaks) in the background instead of blocking the caller.
        pool.shutdown(wait=False)


def deserialize_arg(arg_b64: str) -> Any:
    """Deserialize a single base64-encoded cloudpickle argument.

    Validates payload size before decoding and applies a wall-clock
    timeout to the unpickle step.

    Args:
        arg_b64: Base64-encoded serialized argument.

    Returns:
        Deserialized argument.

    Raises:
        PayloadTooLargeError: If the encoded payload exceeds MAX_PAYLOAD_SIZE.
        DeserializeTimeoutError: If cloudpickle.loads exceeds DESERIALIZE_TIMEOUT_SECONDS.
        SerializationError: If deserialization fails for any other reason.
    """
    try:
        _check_payload_size(arg_b64)
        raw = base64.b64decode(arg_b64)
        return _unpickle_with_timeout(raw, DESERIALIZE_TIMEOUT_SECONDS)
    except (PayloadTooLargeError, DeserializeTimeoutError):
        # These are already specific; re-raise untouched so callers can
        # distinguish them from generic serialization failures.
        raise
    except Exception as e:
        raise SerializationError(f"Failed to deserialize argument: {e}") from e

Expand All @@ -94,11 +146,13 @@ def deserialize_args(args_b64: List[str]) -> List[Any]:
List of deserialized arguments.

Raises:
SerializationError: If deserialization fails.
PayloadTooLargeError: If any encoded argument exceeds MAX_PAYLOAD_SIZE.
DeserializeTimeoutError: If any cloudpickle.loads exceeds the timeout.
SerializationError: If deserialization fails for any other reason.
"""
try:
return [deserialize_arg(arg) for arg in args_b64]
except SerializationError:
except (PayloadTooLargeError, DeserializeTimeoutError, SerializationError):
raise
except Exception as e:
raise SerializationError(f"Failed to deserialize args: {e}") from e
Expand All @@ -114,11 +168,13 @@ def deserialize_kwargs(kwargs_b64: Dict[str, str]) -> Dict[str, Any]:
Dictionary with deserialized values.

Raises:
SerializationError: If deserialization fails.
PayloadTooLargeError: If any encoded value exceeds MAX_PAYLOAD_SIZE.
DeserializeTimeoutError: If any cloudpickle.loads exceeds the timeout.
SerializationError: If deserialization fails for any other reason.
"""
try:
return {k: deserialize_arg(v) for k, v in kwargs_b64.items()}
except SerializationError:
except (PayloadTooLargeError, DeserializeTimeoutError, SerializationError):
raise
except Exception as e:
raise SerializationError(f"Failed to deserialize kwargs: {e}") from e
102 changes: 94 additions & 8 deletions tests/unit/runtime/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
"""Tests for serialization utilities."""

import time
from unittest.mock import patch

import cloudpickle
import pytest

from runpod_flash.runtime.exceptions import SerializationError
from runpod_flash.runtime.config import MAX_PAYLOAD_SIZE
from runpod_flash.runtime.exceptions import (
DeserializeTimeoutError,
PayloadTooLargeError,
SerializationError,
)
from runpod_flash.runtime.serialization import (
_check_payload_size,
_unpickle_with_timeout,
deserialize_arg,
deserialize_args,
deserialize_kwargs,
Expand All @@ -22,7 +31,6 @@ def test_serialize_simple_arg(self):
"""Test serializing a simple argument."""
result = serialize_arg(42)
assert isinstance(result, str)
# Verify it's valid base64
import base64

decoded = base64.b64decode(result)
Expand Down Expand Up @@ -106,22 +114,88 @@ def test_serialize_kwargs_unexpected_error(self):
serialize_kwargs({"key": 42})


class TestCheckPayloadSize:
    """Test _check_payload_size function."""

    def test_within_limit(self):
        """Payloads within MAX_PAYLOAD_SIZE pass silently."""
        _check_payload_size("a" * 100)

    def test_at_limit(self):
        """Payload exactly at MAX_PAYLOAD_SIZE passes (limit is inclusive)."""
        _check_payload_size("a" * MAX_PAYLOAD_SIZE)

    def test_over_limit(self):
        """Payload exceeding MAX_PAYLOAD_SIZE raises PayloadTooLargeError."""
        with pytest.raises(PayloadTooLargeError, match="exceeds limit"):
            _check_payload_size("a" * (MAX_PAYLOAD_SIZE + 1))

    def test_error_message_includes_sizes(self):
        """Error message reports actual and limit sizes in MB."""
        oversized = "a" * (MAX_PAYLOAD_SIZE + 1)
        with pytest.raises(PayloadTooLargeError) as exc_info:
            _check_payload_size(oversized)
        msg = str(exc_info.value)
        assert "MB" in msg
        # Derive the expected limit string from config instead of
        # hard-coding "10.0 MB", so the test survives a config change.
        assert f"{MAX_PAYLOAD_SIZE / (1024 * 1024):.1f} MB" in msg


class TestUnpickleWithTimeout:
    """Test _unpickle_with_timeout function."""

    def test_normal_deserialization(self):
        """Small payloads deserialize within the timeout."""
        payload = cloudpickle.dumps(42)
        assert _unpickle_with_timeout(payload, 5) == 42

    def test_timeout_raises(self):
        """A slow unpickle triggers DeserializeTimeoutError."""

        def slow_loads(data):
            # Sleep just long enough to exceed the 1 s timeout; the
            # previous 5 s sleep only slowed the suite without making
            # the test any more reliable.
            time.sleep(2)
            return None

        with patch("runpod_flash.runtime.serialization.cloudpickle") as mock_cp:
            mock_cp.loads = slow_loads
            with pytest.raises(DeserializeTimeoutError, match="timed out"):
                _unpickle_with_timeout(b"fake", 1)


class TestDeserializeArg:
    """Test deserialize_arg function."""

    def test_roundtrip(self):
        """Serialize then deserialize returns the original value."""
        serialized = serialize_arg(42)
        result = deserialize_arg(serialized)
        assert result == 42

    def test_raises_on_invalid_base64(self):
        """Invalid base64 raises SerializationError."""
        with pytest.raises(SerializationError, match="Failed to deserialize argument"):
            deserialize_arg("not-valid-base64!!!")

    def test_rejects_oversized_payload(self):
        """Payload larger than MAX_PAYLOAD_SIZE raises PayloadTooLargeError."""
        oversized = "A" * (MAX_PAYLOAD_SIZE + 1)
        with pytest.raises(PayloadTooLargeError):
            deserialize_arg(oversized)

    def test_timeout_on_slow_unpickle(self):
        """Slow cloudpickle.loads raises DeserializeTimeoutError."""
        valid_b64 = serialize_arg("hello")

        def slow_loads(data):
            # Exceeds the patched 1 s timeout; kept short to keep the
            # suite fast.
            time.sleep(2)

        with patch("runpod_flash.runtime.serialization.cloudpickle") as mock_cp:
            mock_cp.loads = slow_loads
            with patch(
                "runpod_flash.runtime.serialization.DESERIALIZE_TIMEOUT_SECONDS", 1
            ):
                with pytest.raises(DeserializeTimeoutError):
                    deserialize_arg(valid_b64)


class TestDeserializeArgs:
"""Test deserialize_args function."""
Expand All @@ -137,6 +211,12 @@ def test_deserialize_empty_args(self):
result = deserialize_args([])
assert result == []

def test_propagates_payload_too_large(self):
"""PayloadTooLargeError from a single arg propagates."""
oversized = "A" * (MAX_PAYLOAD_SIZE + 1)
with pytest.raises(PayloadTooLargeError):
deserialize_args([oversized])

def test_deserialize_args_propagates_serialization_error(self):
"""Test deserialize_args propagates SerializationError."""
with patch(
Expand Down Expand Up @@ -170,6 +250,12 @@ def test_deserialize_empty_kwargs(self):
result = deserialize_kwargs({})
assert result == {}

def test_propagates_payload_too_large(self):
"""PayloadTooLargeError from a single kwarg value propagates."""
oversized = "A" * (MAX_PAYLOAD_SIZE + 1)
with pytest.raises(PayloadTooLargeError):
deserialize_kwargs({"big": oversized})

def test_deserialize_kwargs_propagates_serialization_error(self):
"""Test deserialize_kwargs propagates SerializationError."""
with patch(
Expand Down
21 changes: 17 additions & 4 deletions tests/unit/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,33 @@ async def test_quota_error_propagates_cleanly(self):
# REG-007: Large base64 payload (>10MB) handling
# ---------------------------------------------------------------------------
class TestREG007LargePayload:
    """Large serialized payloads should round-trip without silent corruption,
    and payloads exceeding MAX_PAYLOAD_SIZE are rejected before decoding."""

    def test_large_payload_roundtrip(self):
        """Payload near the size limit survives serialize -> deserialize without corruption."""
        from runpod_flash.runtime.serialization import deserialize_arg, serialize_arg

        # ~7 MB raw -> ~9.3 MB base64, stays under the 10 MB limit
        large_data = b"x" * (7 * 1024 * 1024)
        serialized = serialize_arg(large_data)
        restored = deserialize_arg(serialized)

        assert restored == large_data
        assert len(restored) == len(large_data)

    def test_oversized_payload_rejected(self):
        """Payloads exceeding MAX_PAYLOAD_SIZE raise PayloadTooLargeError."""
        from runpod_flash.runtime.exceptions import PayloadTooLargeError
        from runpod_flash.runtime.serialization import deserialize_arg, serialize_arg

        # 10 MB raw -> ~13.3 MB base64, exceeds the 10 MB limit
        large_data = b"x" * (10 * 1024 * 1024)
        serialized = serialize_arg(large_data)

        with pytest.raises(PayloadTooLargeError):
            deserialize_arg(serialized)


# ---------------------------------------------------------------------------
# REG-008: flash env delete works in non-interactive mode (no TTY)
Expand Down
Loading