Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,29 @@
# =============================================================================
# MODEL CONFIGURATION (choose one provider)
# =============================================================================

# Default model to use (overridden by --model CLI flag)
LLM_MODEL=gpt-4o
#
# Model resolution order (highest to lowest priority):
# 1. Explicit --model CLI flag or API request parameter
# 2. Session-stored model (from database)
# 3. DSAGENT_DEFAULT_MODEL environment variable (recommended)
# 4. LLM_MODEL environment variable (legacy, for backward compatibility)
# 5. Fallback: gpt-4o
#
# For server/API deployments, use DSAGENT_DEFAULT_MODEL:
DSAGENT_DEFAULT_MODEL=gpt-4o

# For CLI usage, LLM_MODEL also works (backward compatible):
# LLM_MODEL=gpt-4o

# Examples:
# LLM_MODEL=gpt-4o # OpenAI GPT-4o
# LLM_MODEL=gpt-4o-mini # OpenAI GPT-4o Mini (cheaper)
# LLM_MODEL=claude-3-5-sonnet-20241022 # Anthropic Claude 3.5 Sonnet
# LLM_MODEL=claude-3-opus-20240229 # Anthropic Claude 3 Opus
# LLM_MODEL=gemini/gemini-1.5-pro # Google Gemini 1.5 Pro
# LLM_MODEL=ollama/llama3 # Ollama local model
# LLM_MODEL=ollama/codellama # Ollama CodeLlama
# DSAGENT_DEFAULT_MODEL=gpt-4o # OpenAI GPT-4o
# DSAGENT_DEFAULT_MODEL=gpt-4o-mini # OpenAI GPT-4o Mini (cheaper)
# DSAGENT_DEFAULT_MODEL=claude-3-5-sonnet-20241022 # Anthropic Claude 3.5 Sonnet
# DSAGENT_DEFAULT_MODEL=claude-3-opus-20240229 # Anthropic Claude 3 Opus
# DSAGENT_DEFAULT_MODEL=gemini/gemini-1.5-pro # Google Gemini 1.5 Pro
# DSAGENT_DEFAULT_MODEL=groq/llama-3.3-70b-versatile # Groq (fast inference)
# DSAGENT_DEFAULT_MODEL=ollama/llama3 # Ollama local model
# DSAGENT_DEFAULT_MODEL=ollama/codellama # Ollama CodeLlama

# =============================================================================
# API KEYS (set the one for your chosen provider)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "datascience-agent"
version = "0.8.3"
version = "0.8.4"
description = "AI Agent with dynamic planning and persistent Jupyter kernel execution for data analysis"
readme = "README.md"
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion src/dsagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
MCPServerConfig = None # type: ignore
_MCP_AVAILABLE = False

__version__ = "0.8.1"
__version__ = "0.8.4"

__all__ = [
# Main classes
Expand Down
10 changes: 7 additions & 3 deletions src/dsagent/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Optional, Callable, Any, Generator, Dict, Union, TYPE_CHECKING
from datetime import datetime

from dsagent.config import get_default_model
from dsagent.schema.models import (
AgentConfig,
AgentEvent,
Expand Down Expand Up @@ -73,7 +74,7 @@ async def analyze(task: str):

def __init__(
self,
model: str = "gpt-4o",
model: Optional[str] = None, # Resolved via get_default_model() if None
workspace: str | Path = "./workspace",
data: Optional[Union[str, Path]] = None,
session_id: Optional[str] = None,
Expand Down Expand Up @@ -110,8 +111,11 @@ def __init__(
ConfigurationError: If the model or API key configuration is invalid
FileNotFoundError: If data path does not exist
"""
# Resolve model if not explicitly set
effective_model = get_default_model(explicit=model)

# Validate model and API key before anything else
validate_configuration(model)
validate_configuration(effective_model)

# Store or create context
self.context = context
Expand Down Expand Up @@ -145,7 +149,7 @@ def __init__(

# Create configuration
self.config = AgentConfig(
model=model,
model=effective_model,
session_id=session_id,
max_rounds=max_rounds,
max_tokens=max_tokens,
Expand Down
11 changes: 10 additions & 1 deletion src/dsagent/agents/conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from litellm import completion

from dsagent.config import get_default_model
from dsagent.kernel import LocalExecutor, ExecutorConfig, KernelIntrospector
from dsagent.session import Session, SessionManager, ConversationMessage, SessionLogger
from dsagent.utils.validation import validate_configuration, get_proxy_model_name
Expand Down Expand Up @@ -63,7 +64,7 @@ class ChatResponse:
class ConversationalAgentConfig:
"""Configuration for the conversational agent."""

model: str = "gpt-4o"
model: Optional[str] = None # Resolved via get_default_model() if None
temperature: float = 0.3
max_tokens: int = 4096
code_timeout: int = 300
Expand All @@ -89,6 +90,10 @@ class ConversationalAgentConfig:
# Observability settings
observability_config: Optional[Any] = None # ObservabilityConfig object or None

def get_effective_model(self) -> str:
"""Get the effective model, using resolution cascade if not set."""
return get_default_model(explicit=self.model)

@classmethod
def from_agent_config(cls, config: AgentConfig) -> "ConversationalAgentConfig":
"""Create from AgentConfig."""
Expand Down Expand Up @@ -311,6 +316,10 @@ def start(self, session: Optional[Session] = None) -> None:
if self._started:
return

# Resolve model if not explicitly set
if self.config.model is None:
self.config.model = self.config.get_effective_model()

# Validate model configuration and apply API base mapping
validate_configuration(self.config.model)

Expand Down
8 changes: 4 additions & 4 deletions src/dsagent/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def create_parser() -> argparse.ArgumentParser:
chat_parser.add_argument(
"--model", "-m",
type=str,
default=os.getenv("LLM_MODEL", "gpt-4o"),
help="LLM model to use (default: gpt-4o)",
default=None, # Resolved via get_default_model()
help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)",
)
chat_parser.add_argument(
"--workspace", "-w",
Expand Down Expand Up @@ -196,8 +196,8 @@ def create_parser() -> argparse.ArgumentParser:
run_parser.add_argument(
"--model", "-m",
type=str,
default=os.getenv("LLM_MODEL", "gpt-4o"),
help="LLM model to use (default: gpt-4o)",
default=None, # Resolved via get_default_model()
help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)",
)
run_parser.add_argument(
"--workspace", "-w",
Expand Down
11 changes: 8 additions & 3 deletions src/dsagent/cli/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dsagent.cli.banner import print_welcome
from dsagent.cli.commands import CommandRegistry, CommandResult, create_default_registry
from dsagent.cli.renderer import CLIRenderer
from dsagent.config import get_default_model
from dsagent.session import Session, SessionManager
from dsagent.agents import ConversationalAgent, ConversationalAgentConfig
from dsagent.schema.models import HITLMode
Expand All @@ -51,14 +52,18 @@ class CLIContext:
registry: CommandRegistry
console: Console
session: Optional[Session] = None
model: str = "gpt-4o"
model: Optional[str] = None # Resolved via get_default_model() if None
data_path: Optional[str] = None
workspace: Path = field(default_factory=lambda: Path("./workspace"))

# Agent components (initialized lazily)
_agent: Optional[object] = field(default=None, repr=False)
_kernel_running: bool = False

def get_effective_model(self) -> str:
"""Get the effective model, using resolution cascade if not set."""
return get_default_model(explicit=self.model)

def set_session(self, session: Session) -> None:
"""Set the active session."""
self.session = session
Expand Down Expand Up @@ -619,8 +624,8 @@ def main():
parser.add_argument(
"--model", "-m",
type=str,
default=os.getenv("LLM_MODEL", "gpt-4o"),
help="LLM model to use (default: gpt-4o)",
default=None, # Resolved via get_default_model()
help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)",
)

parser.add_argument(
Expand Down
192 changes: 192 additions & 0 deletions src/dsagent/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Centralized configuration for DSAgent.

This module provides a single source of truth for all configuration,
supporting environment variables, .env files, and programmatic overrides.

Configuration Priority (highest to lowest):
1. Explicit parameters (API request, CLI --model flag)
2. Session-stored model (from database)
3. DSAGENT_DEFAULT_MODEL environment variable
4. LLM_MODEL environment variable (legacy/CLI compatibility)
5. Hardcoded fallback "gpt-4o"

Usage:
from dsagent.config import get_settings, get_default_model

settings = get_settings()
model = get_default_model()

# With explicit override
model = get_default_model(explicit="claude-3-5-sonnet")

# With session model
model = get_default_model(session_model=session.model)
"""

from __future__ import annotations

import logging
import os
from functools import lru_cache
from typing import Optional

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

logger = logging.getLogger(__name__)

# Hardcoded fallback - only used when no env vars are set
FALLBACK_MODEL = "gpt-4o"


class DSAgentSettings(BaseSettings):
    """Unified settings for DSAgent CLI and Server.

    All settings can be configured via environment variables with the
    DSAGENT_ prefix (e.g., DSAGENT_DEFAULT_MODEL). Values may also come
    from a `.env` file in the working directory.

    For backward compatibility, LLM_MODEL is also supported as a fallback
    for the default_model setting (see resolve_model_from_legacy).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",  # tolerate unrelated keys in shared .env files
        env_prefix="DSAGENT_",
    )

    # ─── Model Configuration ──────────────────────────────────────────────────
    # None means "not configured"; get_default_model() applies the fallback.
    default_model: Optional[str] = Field(
        default=None,
        description="Default LLM model when not specified in request",
    )
    temperature: float = Field(default=0.3, ge=0.0, le=2.0)
    max_tokens: int = Field(default=4096, ge=1)
    max_rounds: int = Field(default=30, ge=1)
    code_timeout: int = Field(default=300, ge=1)  # seconds

    # ─── Workspace ────────────────────────────────────────────────────────────
    workspace: str = Field(default="./workspace")
    sessions_dir: str = Field(default="workspace")
    session_backend: str = Field(default="sqlite")

    # ─── Server Settings ──────────────────────────────────────────────────────
    # NOTE(review): 0.0.0.0 binds all interfaces (flagged by Bandit B104).
    # Presumably intentional for containerized deployments — confirm, or
    # default to 127.0.0.1 and let deployments opt in via DSAGENT_HOST.
    host: str = Field(default="0.0.0.0")  # nosec B104
    port: int = Field(default=8000, ge=1, le=65535)
    api_key: Optional[str] = Field(default=None)
    cors_origins: str = Field(default="*")
    default_hitl_mode: str = Field(default="none")

    # ─── Observability ────────────────────────────────────────────────────────
    observability_enabled: bool = Field(default=False)
    observability_providers: Optional[str] = Field(default=None)

    @field_validator("default_model", mode="before")
    @classmethod
    def resolve_model_from_legacy(cls, v: Optional[str]) -> Optional[str]:
        """Check legacy LLM_MODEL if DSAGENT_DEFAULT_MODEL not set.

        Only consults the process environment (not the .env file), so a
        legacy LLM_MODEL defined solely in .env is not picked up here.
        """
        if v is None:
            legacy = os.getenv("LLM_MODEL")
            if legacy:
                logger.debug(f"Using legacy LLM_MODEL={legacy}")
                return legacy
        return v


@lru_cache
def get_settings() -> DSAgentSettings:
    """Return the process-wide DSAgentSettings, constructing it once.

    Returns:
        DSAgentSettings instance with resolved configuration.

    Note:
        The result is memoized via lru_cache; call clear_settings_cache()
        to force a reload (useful for testing or env changes at runtime).
    """
    settings = DSAgentSettings()

    # Report where the model default came from, to aid deployment debugging.
    # NOTE(review): os.getenv only sees the process environment, so a value
    # loaded from a .env file may be attributed imprecisely — confirm this
    # labeling is acceptable for a diagnostic log.
    if settings.default_model:
        legacy_only = os.getenv("LLM_MODEL") and not os.getenv("DSAGENT_DEFAULT_MODEL")
        source = "LLM_MODEL (legacy)" if legacy_only else "DSAGENT_DEFAULT_MODEL"
        logger.info(f"Configuration: default_model={settings.default_model} (from {source})")
    else:
        logger.info(f"Configuration: no default_model set, will use fallback={FALLBACK_MODEL}")

    return settings


def clear_settings_cache() -> None:
    """Clear cached settings.

    Drops the lru_cache memo on get_settings so the next call rebuilds
    DSAgentSettings from the current environment and .env file.

    Useful for testing or when environment variables change at runtime.
    """
    get_settings.cache_clear()


def get_default_model(
    explicit: Optional[str] = None,
    session_model: Optional[str] = None,
) -> str:
    """Resolve the effective model name via the priority cascade.

    The first truthy source wins, in this order:
        1. explicit              — model passed directly (API request, CLI flag)
        2. session_model         — model stored with a resumed session
        3. DSAGENT_DEFAULT_MODEL — primary environment variable
        4. LLM_MODEL             — legacy environment variable (CLI compatibility)
        5. FALLBACK_MODEL        — hardcoded default ("gpt-4o")

    Args:
        explicit: Explicitly specified model (highest priority).
        session_model: Model recorded in session storage.

    Returns:
        The resolved model name; never None.

    Example:
        model = get_default_model()                          # env/fallback
        model = get_default_model(explicit="claude-3-5-sonnet")
        model = get_default_model(session_model=session.model)
    """
    # Caller-supplied values take precedence over any environment config.
    if explicit:
        logger.info(f"Model resolution: using explicit={explicit}")
        return explicit

    if session_model:
        logger.info(f"Model resolution: using session_model={session_model}")
        return session_model

    # Fall back to the (cached) settings, which already fold in the
    # legacy LLM_MODEL variable via the field validator.
    configured = get_settings().default_model
    if configured:
        logger.info(f"Model resolution: using settings.default_model={configured}")
        return configured

    logger.info(f"Model resolution: using fallback={FALLBACK_MODEL}")
    return FALLBACK_MODEL


def log_configuration() -> None:
    """Emit the resolved configuration at INFO level.

    Intended for troubleshooting deployment issues; logs one line per
    setting between header and footer markers.
    """
    settings = get_settings()
    model_display = settings.default_model or f"(none, fallback={FALLBACK_MODEL})"

    logger.info("=== DSAgent Configuration ===")
    # Table-driven so adding a setting is a one-line change.
    for label, value in (
        ("default_model", model_display),
        ("workspace", settings.workspace),
        ("sessions_dir", settings.sessions_dir),
        ("session_backend", settings.session_backend),
        ("temperature", settings.temperature),
        ("max_tokens", settings.max_tokens),
        ("host", settings.host),
        ("port", settings.port),
    ):
        logger.info(f"  {label}: {value}")
    logger.info("=============================")
Loading