Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/factorium/backtest/vectorized.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

這兩件事情並不等價,需要查看 backtest 中處理 mask 的方式有沒有對應的改動

Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def _calculate_weights(self, df: pl.DataFrame) -> pl.DataFrame:
)

if self._mask is not None:
df = df.with_columns(
pl.when(pl.col(self._mask).fill_null(False)).then(pl.col("weight")).otherwise(0.0).alias("weight")
).drop("_masked_signal")
df = df.drop("_masked_signal")

# Apply constraints
for constraint in self.constraints:
Expand Down
4 changes: 4 additions & 0 deletions src/factorium/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
MIN_PERIODS_PER_YEAR = 1.0
MAX_PERIODS_PER_YEAR = 365.25 * 24 * 60 # Minutes in a year

# External API URLs
COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3"

__all__ = [
"EPSILON",
"SECONDS_PER_YEAR",
"MIN_PERIODS_PER_YEAR",
"MAX_PERIODS_PER_YEAR",
"COINGECKO_BASE_URL",
]
11 changes: 6 additions & 5 deletions src/factorium/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone

logger = logging.getLogger(__name__)

Expand All @@ -23,7 +23,7 @@ def calculate_date_range(
producing duplicate bars with partial OHLCV data.

Priority:
1. If both start_date and end_date are provided: [start, end]
1. If both start_date and end_date are provided: [start, end + 1 day)
2. If start_date and days are provided: [start, start + days)
3. If neither: [end - default_days, end) where end = today_midnight (UTC) + 1 day
4. If only days: [end - days, end) where end = today_midnight (UTC) + 1 day
Expand All @@ -40,9 +40,10 @@ def calculate_date_range(
try:
if start_date and end_date:
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
if start > end:
end_inclusive = datetime.strptime(end_date, "%Y-%m-%d")
if start > end_inclusive:
raise ValueError("Start date must be earlier than or equal to end date")
end = end_inclusive + timedelta(days=1)
return start, end

if start_date and days:
Expand All @@ -52,7 +53,7 @@ def calculate_date_range(
return start, start + timedelta(days=days)

# Snap to UTC midnight for consistent daily boundaries
today_midnight = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
today_midnight = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
# end = start of tomorrow (exclusive) to include today's full data
end = today_midnight + timedelta(days=1)

Expand Down
9 changes: 7 additions & 2 deletions src/factorium/factors/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,13 @@ def __init__(self, factor: Factor, prices: AggBar | Factor, quantiles: int = 5,

def _ensure_data_prepared(self, periods: list[int] | None = None, price_col: str | None = None) -> None:
"""Ensure data is prepared. Auto-calls prepare_data() if needed."""
if not hasattr(self, "_clean_data"):
logger.info("Data not prepared. Auto-calling prepare_data()...")
has_missing_period = bool(
periods
and hasattr(self, "_clean_data")
and any(f"period_{p}" not in self._clean_data.columns for p in periods)
)
if not hasattr(self, "_clean_data") or has_missing_period:
logger.info("Data not prepared or missing requested periods. Auto-calling prepare_data()...")
self.prepare_data(periods=periods, price_col=price_col)

def analyze(self, price_col: str = "close", periods: int | list[int] = 1) -> FactorAnalysisResult:
Expand Down
4 changes: 2 additions & 2 deletions src/factorium/universe/metadata.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import asyncio
import json
import time
from pathlib import Path

import aiohttp

from ..data.loader import _run_async
from .rules import KNOWN_STABLECOINS, LEVERAGED_PATTERNS, SymbolMetadata


Expand Down Expand Up @@ -47,7 +47,7 @@ async def fetch_async(self) -> dict[str, SymbolMetadata]:
return parsed

def fetch(self) -> dict[str, SymbolMetadata]:
return asyncio.run(self.fetch_async())
return _run_async(self.fetch_async())

def _parse_exchange_info(self, data: dict) -> dict[str, SymbolMetadata]:
output: dict[str, SymbolMetadata] = {}
Expand Down
4 changes: 2 additions & 2 deletions src/factorium/universe/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def apply(
listing_map[sym] = int(listing_date)

if not listing_map:
return pl.lit(True)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要檢查這塊的邏輯,為什麼是把 True 直接改成 False?

return pl.lit(False)

listing_expr = pl.col("symbol").replace_strict(listing_map, default=None).cast(pl.Int64, strict=False)
return ((pl.col("start_time") - listing_expr) >= self._min_ms) | listing_expr.is_null()
return ((pl.col("start_time") - listing_expr) >= self._min_ms).fill_null(False)
34 changes: 23 additions & 11 deletions src/factorium/universe/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@

import asyncio
import json
import logging
import time
from pathlib import Path

import aiohttp

from ..constants import COINGECKO_BASE_URL
from ..data.loader import _run_async

COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3"

logger = logging.getLogger(__name__)


class TagProvider:
"""Fetch and cache token categories from CoinGecko."""
"""Fetch and cache token categories from CoinGecko.

Note:
``symbols`` must be explicitly provided to avoid full-market
category fetching from CoinGecko, which can be very slow.
"""

def __init__(
self,
Expand All @@ -37,12 +46,13 @@ async def _request_json(
return await response.json()

async def fetch_async(self, symbols: list[str] | None = None) -> dict[str, list[str]]:
requested = [s.upper() for s in symbols] if symbols is not None else None
if symbols is None:
raise ValueError("symbols must be provided to avoid fetching the entire CoinGecko database")

requested = [s.upper() for s in symbols]
cached = self._load_cache()

if cached is not None:
if requested is None:
return cached
if all(sym in cached for sym in requested):
return {sym: cached[sym] for sym in requested}

Expand All @@ -56,10 +66,14 @@ async def fetch_async(self, symbols: list[str] | None = None) -> dict[str, list[
for item in raw_list if isinstance(raw_list, list) else []:
symbol = str(item.get("symbol", "")).upper()
coin_id = item.get("id")
if symbol and coin_id and symbol not in symbol_to_id:
symbol_to_id[symbol] = str(coin_id)
if not symbol or not coin_id:
continue

coin_id_str = str(coin_id)
if symbol not in symbol_to_id or coin_id_str == symbol.lower():
symbol_to_id[symbol] = coin_id_str

targets = requested or sorted(symbol_to_id.keys())
targets = requested
result: dict[str, list[str]] = {} if cached is None else dict(cached)

for symbol in targets:
Expand All @@ -76,12 +90,10 @@ async def fetch_async(self, symbols: list[str] | None = None) -> dict[str, list[
await asyncio.sleep(0.12)

self._save_cache(result)
if requested is None:
return result
return {sym: result.get(sym, []) for sym in requested if sym in result}

def fetch(self, symbols: list[str] | None = None) -> dict[str, list[str]]:
return asyncio.run(self.fetch_async(symbols=symbols))
return _run_async(self.fetch_async(symbols=symbols))

def _load_cache(self) -> dict[str, list[str]] | None:
if not self._cache_path.exists():
Expand Down
35 changes: 35 additions & 0 deletions tests/backtest/test_vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,41 @@ def test_long_only_weights_sum_to_one(self):
if ws > 0:
assert abs(ws - 1.0) < 1e-10

def test_calculate_weights_masked_assets_remain_zero_after_neutralize(self):
    """Rows outside the mask must end up with zero weight under market neutralization."""
    bar_starts = [1704067200000, 1704070800000, 1704074400000]
    # (symbol, base price, in_universe flag); "C" is the only masked-out asset.
    assets = (
        ("A", 100.0, True),
        ("B", 80.0, True),
        ("C", 60.0, False),
    )

    def _flat_bar(ts, symbol, price, in_universe):
        # Flat OHLC bar: open/high/low/close all share the same price.
        return {
            "start_time": ts,
            "end_time": ts + 3600000,
            "symbol": symbol,
            "open": price,
            "high": price,
            "low": price,
            "close": price,
            "volume": 1000.0,
            "in_universe": in_universe,
        }

    # Timestamp-major, symbol-minor ordering; price grows 1% per bar index.
    records = [
        _flat_bar(ts, symbol, base_price * (1 + 0.01 * i), in_universe)
        for i, ts in enumerate(bar_starts)
        for symbol, base_price, in_universe in assets
    ]

    prices = AggBar(pl.DataFrame(records))
    signal = prices["close"].cs_rank()
    backtester = VectorizedBacktester(
        prices=prices,
        signal=signal,
        neutralization="market",
        mask="in_universe",
    )

    weighted = backtester._calculate_weights(backtester._prepare_data())

    # Null mask values are treated as "not in universe" when selecting masked rows.
    outside_universe = weighted.filter(~pl.col("in_universe").fill_null(False))
    assert outside_universe["weight"].abs().max() == 0.0
    # The temporary helper column must not leak into the result.
    assert "_masked_signal" not in weighted.columns


class TestMetricsCalculation:
"""Tests for metrics calculation."""
Expand Down
21 changes: 21 additions & 0 deletions tests/data/test_timestamp_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# tests/data/test_timestamp_utils.py
from datetime import datetime, timedelta, timezone

import polars as pl
import pytest

from factorium.data import utils as data_utils
from factorium.data.loader import (
_convert_to_target_unit,
_detect_timestamp_unit,
_normalize_timestamps_to_ms,
)
from factorium.data.utils import calculate_date_range


def test_detect_timestamp_unit_seconds():
Expand Down Expand Up @@ -39,6 +43,23 @@ def test_convert_to_target_unit_invalid_unit():
_convert_to_target_unit(1704067200000, "invalid")


def test_calculate_date_range_uses_utc_midnight(monkeypatch):
    """calculate_date_range(days=1) must span the current UTC day, midnight to midnight."""
    utc_plus_8 = timezone(timedelta(hours=8))

    class FakeDateTime(datetime):
        @classmethod
        def now(cls, tz=None):
            # Both branches describe the same instant (15:30 UTC == 23:30 UTC+8);
            # the tz-less branch reports it in a UTC+8 offset.
            if tz is not None:
                return cls(2026, 2, 14, 15, 30, tzinfo=timezone.utc).astimezone(tz)
            return cls(2026, 2, 14, 23, 30, tzinfo=utc_plus_8)

    # Patch the name looked up inside the utils module, not the stdlib class itself.
    monkeypatch.setattr(data_utils, "datetime", FakeDateTime)

    range_start, range_end = calculate_date_range(days=1)

    utc_midnight = datetime(2026, 2, 14, 0, 0, tzinfo=timezone.utc)
    assert range_start == utc_midnight
    assert range_end == utc_midnight + timedelta(days=1)


class TestNormalizeTimestampsToMs:
"""Tests for _normalize_timestamps_to_ms function."""

Expand Down
14 changes: 14 additions & 0 deletions tests/factors/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,17 @@ def test_analyze_empty_periods_list_raises_error(sample_data):

with pytest.raises(ValueError, match="Periods list cannot be empty"):
analyzer.analyze(periods=[])


def test_ensure_data_prepared_reprepare_when_period_missing(sample_data):
    """_ensure_data_prepared must add a period column the earlier prepare_data call lacked."""
    bars = AggBar(sample_data)
    analyzer = FactorAnalyzer(bars["my_factor"], bars["close"])

    # First preparation only covers period 1.
    analyzer.prepare_data(periods=[1])
    initial_columns = analyzer._clean_data.columns
    assert "period_1" in initial_columns
    assert "period_5" not in initial_columns

    # Requesting an extra period must make its column available.
    analyzer._ensure_data_prepared(periods=[1, 5])
    assert "period_5" in analyzer._clean_data.columns
4 changes: 2 additions & 2 deletions tests/factors/test_safe_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _make_factor(
for s_idx, sym in enumerate(symbols):
rows.append(
{
"start_time": t * 60000,
"end_time": (t + 1) * 60000,
"start_time": t * 60_000,
"end_time": (t + 1) * 60_000,
"symbol": sym,
"factor": values[t * n_symbols + s_idx],
}
Expand Down
21 changes: 10 additions & 11 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
from pathlib import Path
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import patch, MagicMock, AsyncMock
from freezegun import freeze_time

Expand Down Expand Up @@ -136,7 +136,7 @@ def test_with_start_and_end_date(self):
start_dt, end_dt = calculate_date_range(start_date="2024-01-01", end_date="2024-01-07", days=None)

assert start_dt == datetime(2024, 1, 1)
assert end_dt == datetime(2024, 1, 7)
assert end_dt == datetime(2024, 1, 8)

def test_with_start_date_and_days(self):
"""Test with start_date and days specified."""
Expand All @@ -151,8 +151,8 @@ def test_with_only_days(self):
start_dt, end_dt = calculate_date_range(start_date=None, end_date=None, days=7)

# end = start of tomorrow (exclusive), start = end - days
assert end_dt == datetime(2024, 6, 16, 0, 0, 0)
assert start_dt == datetime(2024, 6, 9, 0, 0, 0)
assert end_dt == datetime(2024, 6, 16, 0, 0, 0, tzinfo=timezone.utc)
assert start_dt == datetime(2024, 6, 9, 0, 0, 0, tzinfo=timezone.utc)
# Both must be midnight-aligned
assert start_dt.hour == 0 and start_dt.minute == 0 and start_dt.second == 0
assert end_dt.hour == 0 and end_dt.minute == 0 and end_dt.second == 0
Expand All @@ -162,8 +162,8 @@ def test_default_7_days(self):
"""Test default behavior (no params = 7 days ending tomorrow midnight)."""
start_dt, end_dt = calculate_date_range(start_date=None, end_date=None, days=None)

assert end_dt == datetime(2024, 6, 16, 0, 0, 0)
assert start_dt == datetime(2024, 6, 9, 0, 0, 0)
assert end_dt == datetime(2024, 6, 16, 0, 0, 0, tzinfo=timezone.utc)
assert start_dt == datetime(2024, 6, 9, 0, 0, 0, tzinfo=timezone.utc)
# Both must be midnight-aligned
assert start_dt.hour == 0 and start_dt.minute == 0 and start_dt.second == 0
assert end_dt.hour == 0 and end_dt.minute == 0 and end_dt.second == 0
Expand All @@ -174,8 +174,8 @@ def test_midnight_alignment_regardless_of_time(self):
start_dt, end_dt = calculate_date_range(start_date=None, end_date=None, days=3)

# Should snap to midnight boundaries
assert start_dt == datetime(2024, 6, 13, 0, 0, 0)
assert end_dt == datetime(2024, 6, 16, 0, 0, 0)
assert start_dt == datetime(2024, 6, 13, 0, 0, 0, tzinfo=timezone.utc)
assert end_dt == datetime(2024, 6, 16, 0, 0, 0, tzinfo=timezone.utc)
assert start_dt.microsecond == 0
assert end_dt.microsecond == 0

Expand All @@ -191,14 +191,14 @@ def test_cross_year_range(self):
start_dt, end_dt = calculate_date_range(start_date="2023-12-28", end_date="2024-01-05", days=None)

assert start_dt == datetime(2023, 12, 28)
assert end_dt == datetime(2024, 1, 5)
assert end_dt == datetime(2024, 1, 6)

def test_single_day_range(self):
"""Test single day range (start == end)."""
start_dt, end_dt = calculate_date_range(start_date="2024-01-01", end_date="2024-01-01", days=None)

assert start_dt == datetime(2024, 1, 1)
assert end_dt == datetime(2024, 1, 1)
assert end_dt == datetime(2024, 1, 2)

def test_start_date_with_one_day(self):
"""Test start_date with days=1."""
Expand All @@ -208,7 +208,6 @@ def test_start_date_with_one_day(self):
assert end_dt == datetime(2024, 1, 2)



# =============================================================================
# TestBuildDateFilter - 日期過濾條件測試
# =============================================================================
Expand Down
Loading