Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 17 additions & 28 deletions tests/data/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from factorium.data.cache import BarCache

from factorium.storage import LocalStorageBackend


Expand All @@ -20,6 +19,13 @@ def temp_cache_dir():
yield Path(tmpdir)


@pytest.fixture
def cache(temp_cache_dir):
"""Create cache using non-deprecated storage API."""
storage = LocalStorageBackend(str(temp_cache_dir))
return BarCache(storage=storage, cache_prefix="")


@pytest.fixture
def sample_bar_df():
"""Create sample bar DataFrame in Polars format."""
Expand All @@ -41,16 +47,13 @@ def sample_bar_df():
class TestBarCache:
"""Tests for BarCache."""

def test_cache_initialization(self, temp_cache_dir):
def test_cache_initialization(self, temp_cache_dir, cache):
"""Test cache initializes correctly."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))
assert isinstance(cache.storage, LocalStorageBackend)
assert cache.cache_dir is None
assert temp_cache_dir.exists()

def test_cache_miss_returns_none(self, temp_cache_dir):
def test_cache_miss_returns_none(self, cache):
"""Test that cache miss returns None."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

result = cache.get(
exchange="binance",
symbols=["BTCUSDT"],
Expand All @@ -62,10 +65,8 @@ def test_cache_miss_returns_none(self, temp_cache_dir):

assert result is None

def test_cache_put_and_get(self, temp_cache_dir, sample_bar_df):
def test_cache_put_and_get(self, cache, sample_bar_df):
"""Test putting and getting from cache."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df,
exchange="binance",
Expand All @@ -88,10 +89,8 @@ def test_cache_put_and_get(self, temp_cache_dir, sample_bar_df):
assert result is not None
assert len(result) == len(sample_bar_df)

def test_cache_key_different_symbols(self, temp_cache_dir, sample_bar_df):
def test_cache_key_different_symbols(self, cache, sample_bar_df):
"""Test that different symbols produce different cache keys."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df,
exchange="binance",
Expand All @@ -113,10 +112,8 @@ def test_cache_key_different_symbols(self, temp_cache_dir, sample_bar_df):

assert result is None

def test_cache_key_different_interval(self, temp_cache_dir, sample_bar_df):
def test_cache_key_different_interval(self, cache, sample_bar_df):
"""Test that different intervals produce different cache keys."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df,
exchange="binance",
Expand All @@ -138,10 +135,8 @@ def test_cache_key_different_interval(self, temp_cache_dir, sample_bar_df):

assert result is None

def test_cache_daily_files(self, temp_cache_dir, sample_bar_df):
def test_cache_daily_files(self, temp_cache_dir, cache, sample_bar_df):
"""Test that cache creates daily files."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df,
exchange="binance",
Expand All @@ -156,10 +151,8 @@ def test_cache_daily_files(self, temp_cache_dir, sample_bar_df):
assert len(cache_files) == 1
assert "2024-01-15" in cache_files[0].name

def test_get_date_range(self, temp_cache_dir, sample_bar_df):
def test_get_date_range(self, cache, sample_bar_df):
"""Test getting data for a date range."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

for day in range(1, 4):
cache.put(
df=sample_bar_df,
Expand All @@ -184,10 +177,8 @@ def test_get_date_range(self, temp_cache_dir, sample_bar_df):
assert result is not None
assert len(result) == len(sample_bar_df) * 3

def test_get_date_range_partial_miss(self, temp_cache_dir, sample_bar_df):
def test_get_date_range_partial_miss(self, cache, sample_bar_df):
"""Test that partial cache miss returns None for range."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

for day in [1, 3]:
cache.put(
df=sample_bar_df,
Expand All @@ -211,10 +202,8 @@ def test_get_date_range_partial_miss(self, temp_cache_dir, sample_bar_df):

assert result is None

def test_clear_cache(self, temp_cache_dir, sample_bar_df):
def test_clear_cache(self, cache, sample_bar_df):
"""Test clearing the cache."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df,
exchange="binance",
Expand Down
27 changes: 12 additions & 15 deletions tests/data/test_cache_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def temp_cache_dir():
yield Path(tmpdir)


@pytest.fixture
def cache(temp_cache_dir):
"""Create cache using non-deprecated storage API."""
storage = LocalStorageBackend(str(temp_cache_dir))
return BarCache(storage=storage, cache_prefix="")


@pytest.fixture
def sample_bar_df_polars():
"""Create sample bar DataFrame in Polars format."""
Expand All @@ -40,10 +47,8 @@ def sample_bar_df_polars():
class TestBarCachePolars:
"""Tests for BarCache with Polars DataFrames."""

def test_put_and_get_polars(self, temp_cache_dir, sample_bar_df_polars):
def test_put_and_get_polars(self, cache, sample_bar_df_polars):
"""Test storing and retrieving Polars DataFrame from cache."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

cache.put(
df=sample_bar_df_polars,
exchange="binance",
Expand All @@ -68,10 +73,8 @@ def test_put_and_get_polars(self, temp_cache_dir, sample_bar_df_polars):
assert len(result) == len(sample_bar_df_polars)
assert result.shape == sample_bar_df_polars.shape

def test_get_returns_none_when_not_cached(self, temp_cache_dir):
def test_get_returns_none_when_not_cached(self, cache):
"""Test that get returns None if data not in cache."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

result = cache.get(
exchange="binance",
symbols=["BTCUSDT"],
Expand All @@ -83,10 +86,8 @@ def test_get_returns_none_when_not_cached(self, temp_cache_dir):

assert result is None

def test_get_range_returns_polars(self, temp_cache_dir, sample_bar_df_polars):
def test_get_range_returns_polars(self, cache, sample_bar_df_polars):
"""Test get_range returns concatenated Polars DataFrame."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

# Store data for 3 consecutive days
for day in range(1, 4):
cache.put(
Expand All @@ -113,10 +114,8 @@ def test_get_range_returns_polars(self, temp_cache_dir, sample_bar_df_polars):
assert isinstance(result, pl.DataFrame)
assert len(result) == len(sample_bar_df_polars) * 3

def test_get_range_returns_none_if_any_missing(self, temp_cache_dir, sample_bar_df_polars):
def test_get_range_returns_none_if_any_missing(self, cache, sample_bar_df_polars):
"""Test get_range returns None if any day missing from range."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

# Store data for days 1 and 3, but skip day 2
for day in [1, 3]:
cache.put(
Expand All @@ -142,10 +141,8 @@ def test_get_range_returns_none_if_any_missing(self, temp_cache_dir, sample_bar_
# Should return None because day 2 is missing
assert result is None

def test_put_and_get_preserves_data_types(self, temp_cache_dir):
def test_put_and_get_preserves_data_types(self, cache):
"""Test that data types are preserved through cache round-trip."""
cache = BarCache(storage=LocalStorageBackend(str(temp_cache_dir)))

df = pl.DataFrame(
{
"symbol": ["BTCUSDT", "ETHUSDT"],
Expand Down