Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 97 additions & 21 deletions src/sqrl/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
Uses LiteLLM for provider-agnostic embeddings.
"""

import asyncio
import logging
import os
import struct
from dataclasses import dataclass
import time
from dataclasses import dataclass, field
from typing import Optional

import litellm

logger = logging.getLogger(__name__)


class EmbeddingError(Exception):
"""Error during embedding generation."""
Expand All @@ -34,26 +39,51 @@ class EmbeddingConfig:
Attributes:
model: Embedding model to use (default: text-embedding-3-small)
dimensions: Expected embedding dimensions (default: 1536)
max_retries: Maximum retry attempts on failure (default: 3)
retry_delay: Base delay between retries in seconds (default: 1.0)
retry_backoff: Backoff multiplier for retry delay (default: 2.0)
"""

model: str = "text-embedding-3-small"
dimensions: int = 1536
max_retries: int = 3
retry_delay: float = 1.0
retry_backoff: float = 2.0

@classmethod
def from_env(cls) -> "EmbeddingConfig":
"""Create config from environment variables."""
return cls(
model=os.getenv("SQRL_EMBEDDING_MODEL", "text-embedding-3-small"),
dimensions=int(os.getenv("SQRL_EMBEDDING_DIMS", "1536")),
max_retries=int(os.getenv("SQRL_EMBEDDING_MAX_RETRIES", "3")),
retry_delay=float(os.getenv("SQRL_EMBEDDING_RETRY_DELAY", "1.0")),
retry_backoff=float(os.getenv("SQRL_EMBEDDING_RETRY_BACKOFF", "2.0")),
)


def _is_retryable_error(e: Exception) -> bool:
"""Check if an error is retryable (transient network/API issues)."""
error_str = str(e).lower()
retryable_patterns = [
"rate limit",
"timeout",
"connection",
"temporarily unavailable",
"503",
"502",
"500",
"429",
]
return any(pattern in error_str for pattern in retryable_patterns)


async def embed_text(
text: str,
config: Optional[EmbeddingConfig] = None,
) -> list[float]:
"""
Generate embedding for text (IPC-002).
Generate embedding for text (IPC-002) with retry logic.

Args:
text: Text to embed (must not be empty)
Expand All @@ -63,29 +93,52 @@ async def embed_text(
1536-dim float32 vector

Raises:
EmbeddingError: With code -32040 if text empty, -32041 if API fails
EmbeddingError: With code -32040 if text empty, -32041 if API fails after retries
"""
if not text or not text.strip():
raise EmbeddingError(ERROR_EMPTY_TEXT, "Empty text")

cfg = config or EmbeddingConfig.from_env()

try:
response = await litellm.aembedding(
model=cfg.model,
input=text,
)
return response.data[0]["embedding"]
except Exception as e:
raise EmbeddingError(ERROR_EMBEDDING_FAILED, f"Embedding error: {e}") from e
last_error: Optional[Exception] = None
delay = cfg.retry_delay

for attempt in range(cfg.max_retries + 1):
try:
response = await litellm.aembedding(
model=cfg.model,
input=text,
)
return response.data[0]["embedding"]
except Exception as e:
last_error = e

# Don't retry on non-retryable errors
if not _is_retryable_error(e):
logger.warning(f"Embedding failed (non-retryable): {e}")
break

# Don't wait after the last attempt
if attempt < cfg.max_retries:
logger.warning(
f"Embedding attempt {attempt + 1}/{cfg.max_retries + 1} failed: {e}. "
f"Retrying in {delay:.1f}s..."
)
await asyncio.sleep(delay)
delay *= cfg.retry_backoff

raise EmbeddingError(
ERROR_EMBEDDING_FAILED,
f"Embedding failed after {cfg.max_retries + 1} attempts: {last_error}",
) from last_error


def embed_text_sync(
text: str,
config: Optional[EmbeddingConfig] = None,
) -> list[float]:
"""
Synchronous version of embed_text (IPC-002).
Synchronous version of embed_text (IPC-002) with retry logic.

Args:
text: Text to embed (must not be empty)
Expand All @@ -95,21 +148,44 @@ def embed_text_sync(
1536-dim float32 vector

Raises:
EmbeddingError: With code -32040 if text empty, -32041 if API fails
EmbeddingError: With code -32040 if text empty, -32041 if API fails after retries
"""
if not text or not text.strip():
raise EmbeddingError(ERROR_EMPTY_TEXT, "Empty text")

cfg = config or EmbeddingConfig.from_env()

try:
response = litellm.embedding(
model=cfg.model,
input=text,
)
return response.data[0]["embedding"]
except Exception as e:
raise EmbeddingError(ERROR_EMBEDDING_FAILED, f"Embedding error: {e}") from e
last_error: Optional[Exception] = None
delay = cfg.retry_delay

for attempt in range(cfg.max_retries + 1):
try:
response = litellm.embedding(
model=cfg.model,
input=text,
)
return response.data[0]["embedding"]
except Exception as e:
last_error = e

# Don't retry on non-retryable errors
if not _is_retryable_error(e):
logger.warning(f"Embedding failed (non-retryable): {e}")
break

# Don't wait after the last attempt
if attempt < cfg.max_retries:
logger.warning(
f"Embedding attempt {attempt + 1}/{cfg.max_retries + 1} failed: {e}. "
f"Retrying in {delay:.1f}s..."
)
time.sleep(delay)
delay *= cfg.retry_backoff

raise EmbeddingError(
ERROR_EMBEDDING_FAILED,
f"Embedding failed after {cfg.max_retries + 1} attempts: {last_error}",
) from last_error


def embedding_to_bytes(embedding: list[float]) -> bytes:
Expand Down
36 changes: 31 additions & 5 deletions src/sqrl/ipc/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,16 @@ class SearchMemoriesHandler:
Handler for IPC-004: search_memories.

Note: In production, search is handled by Rust daemon with sqlite-vec.
This handler is for testing/development with Python-side search.
This handler provides a Python-side implementation for testing/development
or as a fallback when Rust daemon is unavailable.
"""

search_fn: Optional[callable] = None
_warned_no_search_fn: bool = False

    def __init__(self, search_fn: Optional[callable] = None):
        """Initialize the handler.

        Args:
            search_fn: Optional async callable implementing Python-side
                search. When None, the handler returns empty results
                (in production, search is expected to be handled by the
                Rust daemon instead).
        """
        self.search_fn = search_fn
        # Tracks whether the missing-search_fn warning has been logged,
        # so it is emitted at most once per handler instance.
        self._warned_no_search_fn = False

async def __call__(self, params: dict) -> dict:
"""
Expand All @@ -272,8 +277,12 @@ async def __call__(self, params: dict) -> dict:
params: Request params with project_id, query, top_k, filters

Returns:
Dict with results array
Dict with results array and metadata
"""
import logging

logger = logging.getLogger(__name__)

query = params.get("query", "")
if not query:
raise IPCError(ERROR_SEARCH_EMPTY_QUERY, "Empty query")
Expand All @@ -293,10 +302,27 @@ async def __call__(self, params: dict) -> dict:
top_k=top_k,
filters=filters,
)
return {"results": results}
return {
"results": results,
"search_backend": "python",
"warning": None,
}

# Otherwise return empty (search handled by Rust daemon)
return {"results": []}
# Log warning once about missing search function
if not self._warned_no_search_fn:
logger.warning(
"SearchMemoriesHandler: No search_fn configured. "
"In production, search should be handled by Rust daemon with sqlite-vec. "
"Returning empty results."
)
self._warned_no_search_fn = True

# Return empty with warning
return {
"results": [],
"search_backend": "none",
"warning": "Search not configured. Use Rust daemon for production search.",
}


@dataclass
Expand Down
7 changes: 6 additions & 1 deletion src/sqrl/memory_writer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
MemoryWriterOutput,
OpType,
)
from sqrl.memory_writer.writer import MemoryWriter, MemoryWriterConfig
from sqrl.memory_writer.writer import (
ConfigurationError,
MemoryWriter,
MemoryWriterConfig,
)

__all__ = [
"ConfigurationError",
"MemoryWriter",
"MemoryWriterConfig",
"MemoryWriterOutput",
Expand Down
67 changes: 55 additions & 12 deletions src/sqrl/memory_writer/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,30 @@
# Supported PydanticAI providers
SUPPORTED_PROVIDERS = {"openrouter", "openai", "anthropic", "ollama", "together", "fireworks"}

# Default model when SQRL_STRONG_MODEL is not set
# Using DeepSeek R1 distilled model - free on OpenRouter with strong reasoning
# Alternatives:
# - "openrouter/deepseek/deepseek-chat-v2.5" (free, 128K context)
# - "openrouter/anthropic/claude-3.5-haiku" ($0.80/$4 per M tokens, best for JSON)
# - "openrouter/deepseek/deepseek-r1" ($0.30/$1.20 per M tokens, advanced reasoning)
# - "anthropic/claude-3-5-haiku-latest" (direct Anthropic API)
DEFAULT_STRONG_MODEL = "openrouter/deepseek/deepseek-r1-distill-qwen-32b"

# Recommended models by use case (for documentation)
RECOMMENDED_MODELS = {
"free_best": "openrouter/deepseek/deepseek-r1-distill-qwen-32b", # Free, 64K, strong reasoning
"free_long_context": "openrouter/deepseek/deepseek-chat-v2.5", # Free, 128K context
"budget": "openrouter/deepseek/deepseek-r1-0528-qwen3-8b", # $0.02/$0.10, 33K
"json_best": "openrouter/anthropic/claude-3.5-haiku", # $0.80/$4, 200K, best structured output
"reasoning": "openrouter/deepseek/deepseek-r1", # $0.30/$1.20, 164K, o1-level reasoning
}


class ConfigurationError(Exception):
    """Raised when configuration is invalid or missing."""


@dataclass
class MemoryWriterConfig:
Expand All @@ -24,7 +48,7 @@ class MemoryWriterConfig:
Model format: 'provider/model-name'
- Examples: 'openrouter/anthropic/claude-3-haiku', 'openai/gpt-4o'
- For OpenRouter: 'openrouter/<provider>/<model>'
- Set SQRL_STRONG_MODEL env var
- Set SQRL_STRONG_MODEL env var to override default

max_memories_per_episode: Default 5. Limits noise while capturing key learnings.
- Set SQRL_MAX_MEMORIES_PER_EPISODE env var to override
Expand All @@ -33,37 +57,56 @@ class MemoryWriterConfig:
provider: str = field(default="")
model_name: str = field(default="")
max_memories_per_episode: int = 5
_initialized: bool = field(default=False, repr=False)

def __post_init__(self):
# Model from env var (required)
env_model = os.getenv("SQRL_STRONG_MODEL")
if not env_model:
raise ValueError(
"SQRL_STRONG_MODEL env var required. "
"Format: 'provider/model' (e.g., 'openrouter/anthropic/claude-3-haiku')"
)
# Skip if already initialized (allows explicit provider/model_name)
if self._initialized:
return

# Model from env var or default
env_model = os.getenv("SQRL_STRONG_MODEL", DEFAULT_STRONG_MODEL)

# Parse provider/model format
parts = env_model.split("/", 1)
if len(parts) < 2:
raise ValueError(
raise ConfigurationError(
f"Invalid model format: {env_model}. "
"Expected 'provider/model' (e.g., 'openrouter/anthropic/claude-3-haiku')"
"Expected 'provider/model' (e.g., 'openai/gpt-4o-mini')"
)

self.provider = parts[0]
self.model_name = parts[1]

if self.provider not in SUPPORTED_PROVIDERS:
raise ValueError(
raise ConfigurationError(
f"Unsupported provider: {self.provider}. "
f"Supported: {', '.join(sorted(SUPPORTED_PROVIDERS))}"
)

# max_memories_per_episode from env or default (5)
env_max = os.getenv("SQRL_MAX_MEMORIES_PER_EPISODE")
if env_max:
self.max_memories_per_episode = int(env_max)
try:
self.max_memories_per_episode = int(env_max)
except ValueError:
raise ConfigurationError(
f"Invalid SQRL_MAX_MEMORIES_PER_EPISODE: {env_max}. Must be an integer."
)

self._initialized = True

@classmethod
def from_explicit(
cls, provider: str, model_name: str, max_memories_per_episode: int = 5
) -> "MemoryWriterConfig":
"""Create config with explicit values (ignores env vars)."""
config = cls.__new__(cls)
config.provider = provider
config.model_name = model_name
config.max_memories_per_episode = max_memories_per_episode
config._initialized = True
return config

def create_model(self) -> Any:
"""Create PydanticAI model instance."""
Expand Down