diff --git a/.env.example b/.env.example index 8504fbf..372e856 100644 --- a/.env.example +++ b/.env.example @@ -7,18 +7,29 @@ # ============================================================================= # MODEL CONFIGURATION (choose one provider) # ============================================================================= - -# Default model to use (overridden by --model CLI flag) -LLM_MODEL=gpt-4o +# +# Model resolution order (highest to lowest priority): +# 1. Explicit --model CLI flag or API request parameter +# 2. Session-stored model (from database) +# 3. DSAGENT_DEFAULT_MODEL environment variable (recommended) +# 4. LLM_MODEL environment variable (legacy, for backward compatibility) +# 5. Fallback: gpt-4o +# +# For server/API deployments, use DSAGENT_DEFAULT_MODEL: +DSAGENT_DEFAULT_MODEL=gpt-4o + +# For CLI usage, LLM_MODEL also works (backward compatible): +# LLM_MODEL=gpt-4o # Examples: -# LLM_MODEL=gpt-4o # OpenAI GPT-4o -# LLM_MODEL=gpt-4o-mini # OpenAI GPT-4o Mini (cheaper) -# LLM_MODEL=claude-3-5-sonnet-20241022 # Anthropic Claude 3.5 Sonnet -# LLM_MODEL=claude-3-opus-20240229 # Anthropic Claude 3 Opus -# LLM_MODEL=gemini/gemini-1.5-pro # Google Gemini 1.5 Pro -# LLM_MODEL=ollama/llama3 # Ollama local model -# LLM_MODEL=ollama/codellama # Ollama CodeLlama +# DSAGENT_DEFAULT_MODEL=gpt-4o # OpenAI GPT-4o +# DSAGENT_DEFAULT_MODEL=gpt-4o-mini # OpenAI GPT-4o Mini (cheaper) +# DSAGENT_DEFAULT_MODEL=claude-3-5-sonnet-20241022 # Anthropic Claude 3.5 Sonnet +# DSAGENT_DEFAULT_MODEL=claude-3-opus-20240229 # Anthropic Claude 3 Opus +# DSAGENT_DEFAULT_MODEL=gemini/gemini-1.5-pro # Google Gemini 1.5 Pro +# DSAGENT_DEFAULT_MODEL=groq/llama-3.3-70b-versatile # Groq (fast inference) +# DSAGENT_DEFAULT_MODEL=ollama/llama3 # Ollama local model +# DSAGENT_DEFAULT_MODEL=ollama/codellama # Ollama CodeLlama # ============================================================================= # API KEYS (set the one for your chosen provider) diff --git a/pyproject.toml b/pyproject.toml index 2f4365b..36fd893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "datascience-agent" -version = "0.8.3" +version = "0.8.4" description = "AI Agent with dynamic planning and persistent Jupyter kernel execution for data analysis" readme = "README.md" license = "MIT" diff --git a/src/dsagent/__init__.py b/src/dsagent/__init__.py index f3b1f61..7407aef 100644 --- a/src/dsagent/__init__.py +++ b/src/dsagent/__init__.py @@ -62,7 +62,7 @@ MCPServerConfig = None # type: ignore _MCP_AVAILABLE = False -__version__ = "0.8.1" +__version__ = "0.8.4" __all__ = [ # Main classes diff --git a/src/dsagent/agents/base.py b/src/dsagent/agents/base.py index d7f2025..d697ece 100644 --- a/src/dsagent/agents/base.py +++ b/src/dsagent/agents/base.py @@ -9,6 +9,7 @@ from typing import Optional, Callable, Any, Generator, Dict, Union, TYPE_CHECKING from datetime import datetime +from dsagent.config import get_default_model from dsagent.schema.models import ( AgentConfig, AgentEvent, @@ -73,7 +74,7 @@ async def analyze(task: str): def __init__( self, - model: str = "gpt-4o", + model: Optional[str] = None, # Resolved via get_default_model() if None workspace: str | Path = "./workspace", data: Optional[Union[str, Path]] = None, session_id: Optional[str] = None, @@ -110,8 +111,11 @@ def __init__( ConfigurationError: If the model or API key configuration is invalid FileNotFoundError: If data path does not exist """ + # Resolve model if not explicitly set + effective_model = get_default_model(explicit=model) + # Validate model and API key before anything else - validate_configuration(model) + validate_configuration(effective_model) # Store or create context self.context = context @@ -145,7 +149,7 @@ def __init__( # Create configuration self.config = AgentConfig( - model=model, + model=effective_model, session_id=session_id, max_rounds=max_rounds, max_tokens=max_tokens, diff --git a/src/dsagent/agents/conversational.py b/src/dsagent/agents/conversational.py index ab01a30..85600e4 100644 --- a/src/dsagent/agents/conversational.py +++ b/src/dsagent/agents/conversational.py @@ -16,6 +16,7 @@ from litellm import completion +from dsagent.config import get_default_model from dsagent.kernel import LocalExecutor, ExecutorConfig, KernelIntrospector from dsagent.session import Session, SessionManager, ConversationMessage, SessionLogger from dsagent.utils.validation import validate_configuration, get_proxy_model_name @@ -63,7 +64,7 @@ class ChatResponse: class ConversationalAgentConfig: """Configuration for the conversational agent.""" - model: str = "gpt-4o" + model: Optional[str] = None # Resolved via get_default_model() if None temperature: float = 0.3 max_tokens: int = 4096 code_timeout: int = 300 @@ -89,6 +90,10 @@ class ConversationalAgentConfig: # Observability settings observability_config: Optional[Any] = None # ObservabilityConfig object or None + def get_effective_model(self) -> str: + """Get the effective model, using resolution cascade if not set.""" + return get_default_model(explicit=self.model) + @classmethod def from_agent_config(cls, config: AgentConfig) -> "ConversationalAgentConfig": """Create from AgentConfig.""" @@ -311,6 +316,10 @@ def start(self, session: Optional[Session] = None) -> None: if self._started: return + # Resolve model if not explicitly set + if self.config.model is None: + self.config.model = self.config.get_effective_model() + # Validate model configuration and apply API base mapping validate_configuration(self.config.model) diff --git a/src/dsagent/cli/main.py b/src/dsagent/cli/main.py index 7caa908..1862090 100644 --- a/src/dsagent/cli/main.py +++ b/src/dsagent/cli/main.py @@ -130,8 +130,8 @@ def create_parser() -> argparse.ArgumentParser: chat_parser.add_argument( "--model", "-m", type=str, - default=os.getenv("LLM_MODEL", "gpt-4o"), - help="LLM model to use (default: gpt-4o)", + default=None, # Resolved via get_default_model() + help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)", ) chat_parser.add_argument( "--workspace", "-w", @@ -196,8 +196,8 @@ def create_parser() -> argparse.ArgumentParser: run_parser.add_argument( "--model", "-m", type=str, - default=os.getenv("LLM_MODEL", "gpt-4o"), - help="LLM model to use (default: gpt-4o)", + default=None, # Resolved via get_default_model() + help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)", ) run_parser.add_argument( "--workspace", "-w", diff --git a/src/dsagent/cli/repl.py b/src/dsagent/cli/repl.py index b4987d8..4f1e1fa 100644 --- a/src/dsagent/cli/repl.py +++ b/src/dsagent/cli/repl.py @@ -26,6 +26,7 @@ from dsagent.cli.banner import print_welcome from dsagent.cli.commands import CommandRegistry, CommandResult, create_default_registry from dsagent.cli.renderer import CLIRenderer +from dsagent.config import get_default_model from dsagent.session import Session, SessionManager from dsagent.agents import ConversationalAgent, ConversationalAgentConfig from dsagent.schema.models import HITLMode @@ -51,7 +52,7 @@ class CLIContext: registry: CommandRegistry console: Console session: Optional[Session] = None - model: str = "gpt-4o" + model: Optional[str] = None # Resolved via get_default_model() if None data_path: Optional[str] = None workspace: Path = field(default_factory=lambda: Path("./workspace")) @@ -59,6 +60,10 @@ class CLIContext: _agent: Optional[object] = field(default=None, repr=False) _kernel_running: bool = False + def get_effective_model(self) -> str: + """Get the effective model, using resolution cascade if not set.""" + return get_default_model(explicit=self.model) + def set_session(self, session: Session) -> None: """Set the active session.""" self.session = session @@ -619,8 +624,8 @@ def main(): parser.add_argument( "--model", "-m", type=str, - default=os.getenv("LLM_MODEL", "gpt-4o"), - help="LLM model to use (default: gpt-4o)", + default=None, # Resolved via get_default_model() + help="LLM model to use (default: from DSAGENT_DEFAULT_MODEL or LLM_MODEL env var)", ) parser.add_argument( diff --git a/src/dsagent/config.py b/src/dsagent/config.py new file mode 100644 index 0000000..341065b --- /dev/null +++ b/src/dsagent/config.py @@ -0,0 +1,192 @@ +"""Centralized configuration for DSAgent. + +This module provides a single source of truth for all configuration, +supporting environment variables, .env files, and programmatic overrides. + +Configuration Priority (highest to lowest): + 1. Explicit parameters (API request, CLI --model flag) + 2. Session-stored model (from database) + 3. DSAGENT_DEFAULT_MODEL environment variable + 4. LLM_MODEL environment variable (legacy/CLI compatibility) + 5. Hardcoded fallback "gpt-4o" + +Usage: + from dsagent.config import get_settings, get_default_model + + settings = get_settings() + model = get_default_model() + + # With explicit override + model = get_default_model(explicit="claude-3-5-sonnet") + + # With session model + model = get_default_model(session_model=session.model) +""" + +from __future__ import annotations + +import logging +import os +from functools import lru_cache +from typing import Optional + +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +logger = logging.getLogger(__name__) + +# Hardcoded fallback - only used when no env vars are set +FALLBACK_MODEL = "gpt-4o" + + +class DSAgentSettings(BaseSettings): + """Unified settings for DSAgent CLI and Server. + + All settings can be configured via environment variables with the + DSAGENT_ prefix (e.g., DSAGENT_DEFAULT_MODEL). + + For backward compatibility, LLM_MODEL is also supported as a fallback + for the default_model setting. + """ + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + env_prefix="DSAGENT_", + ) + + # ─── Model Configuration ────────────────────────────────────────────────── + default_model: Optional[str] = Field( + default=None, + description="Default LLM model when not specified in request", + ) + temperature: float = Field(default=0.3, ge=0.0, le=2.0) + max_tokens: int = Field(default=4096, ge=1) + max_rounds: int = Field(default=30, ge=1) + code_timeout: int = Field(default=300, ge=1) + + # ─── Workspace ──────────────────────────────────────────────────────────── + workspace: str = Field(default="./workspace") + sessions_dir: str = Field(default="workspace") + session_backend: str = Field(default="sqlite") + + # ─── Server Settings ────────────────────────────────────────────────────── + host: str = Field(default="0.0.0.0") + port: int = Field(default=8000, ge=1, le=65535) + api_key: Optional[str] = Field(default=None) + cors_origins: str = Field(default="*") + default_hitl_mode: str = Field(default="none") + + # ─── Observability ──────────────────────────────────────────────────────── + observability_enabled: bool = Field(default=False) + observability_providers: Optional[str] = Field(default=None) + + @field_validator("default_model", mode="before") + @classmethod + def resolve_model_from_legacy(cls, v: Optional[str]) -> Optional[str]: + """Check legacy LLM_MODEL if DSAGENT_DEFAULT_MODEL not set.""" + if v is None: + legacy = os.getenv("LLM_MODEL") + if legacy: + logger.debug(f"Using legacy LLM_MODEL={legacy}") + return legacy + return v + + +@lru_cache +def get_settings() -> DSAgentSettings: + """Get cached DSAgent settings. + + Returns: + DSAgentSettings instance with resolved configuration. + + Note: + Settings are cached for performance. Use clear_settings_cache() + to reload settings (useful for testing). + """ + settings = DSAgentSettings() + + # Log configuration source for debugging + if settings.default_model: + source = "DSAGENT_DEFAULT_MODEL" + if os.getenv("LLM_MODEL") and not os.getenv("DSAGENT_DEFAULT_MODEL"): + source = "LLM_MODEL (legacy)" + logger.info(f"Configuration: default_model={settings.default_model} (from {source})") + else: + logger.info(f"Configuration: no default_model set, will use fallback={FALLBACK_MODEL}") + + return settings + + +def clear_settings_cache() -> None: + """Clear cached settings. + + Useful for testing or when environment variables change at runtime. + """ + get_settings.cache_clear() + + +def get_default_model( + explicit: Optional[str] = None, + session_model: Optional[str] = None, +) -> str: + """Get the effective default model using the resolution cascade. + + Resolution order (first non-None wins): + 1. explicit - Model passed as parameter (API request, CLI flag) + 2. session_model - Model stored in session (for resumed sessions) + 3. DSAGENT_DEFAULT_MODEL - Primary environment variable + 4. LLM_MODEL - Legacy environment variable (CLI compatibility) + 5. FALLBACK_MODEL - Hardcoded fallback ("gpt-4o") + + Args: + explicit: Explicitly specified model (highest priority) + session_model: Model from session storage + + Returns: + Resolved model name + + Example: + # Use default from environment + model = get_default_model() + + # Override with explicit model + model = get_default_model(explicit="claude-3-5-sonnet") + + # Use session model if available + model = get_default_model(session_model=session.model) + """ + if explicit: + logger.info(f"Model resolution: using explicit={explicit}") + return explicit + + if session_model: + logger.info(f"Model resolution: using session_model={session_model}") + return session_model + + settings = get_settings() + if settings.default_model: + logger.info(f"Model resolution: using settings.default_model={settings.default_model}") + return settings.default_model + + logger.info(f"Model resolution: using fallback={FALLBACK_MODEL}") + return FALLBACK_MODEL + + +def log_configuration() -> None: + """Log current configuration for debugging. + + Useful for troubleshooting deployment issues. + """ + settings = get_settings() + logger.info("=== DSAgent Configuration ===") + logger.info(f" default_model: {settings.default_model or f'(none, fallback={FALLBACK_MODEL})'}") + logger.info(f" workspace: {settings.workspace}") + logger.info(f" sessions_dir: {settings.sessions_dir}") + logger.info(f" session_backend: {settings.session_backend}") + logger.info(f" temperature: {settings.temperature}") + logger.info(f" max_tokens: {settings.max_tokens}") + logger.info(f" host: {settings.host}") + logger.info(f" port: {settings.port}") + logger.info("=============================") diff --git a/src/dsagent/schema/models.py b/src/dsagent/schema/models.py index 00f22ee..23946d2 100644 --- a/src/dsagent/schema/models.py +++ b/src/dsagent/schema/models.py @@ -84,8 +84,9 @@ class AgentConfig(BaseSettings): ) # Model configuration - checks multiple env var names for compatibility - model: str = Field( - default="gpt-4o", + # Default is None; resolved via get_default_model() at runtime + model: Optional[str] = Field( + default=None, validation_alias="LLM_MODEL", ) api_key: str = Field( diff --git a/src/dsagent/server/manager.py b/src/dsagent/server/manager.py index ce27c30..db2a512 100644 --- a/src/dsagent/server/manager.py +++ b/src/dsagent/server/manager.py @@ -9,6 +9,7 @@ from fastapi import WebSocket from dsagent.agents import ConversationalAgent, ConversationalAgentConfig +from dsagent.config import get_default_model from dsagent.schema.models import HITLMode from dsagent.server.models import ( ExecutionResultResponse, @@ -50,6 +51,7 @@ def __init__( self._session_manager = session_manager self._default_model = default_model self._default_hitl_mode = default_hitl_mode + logger.info(f"AgentConnectionManager initialized with default_model={default_model!r}") # session_id -> set of WebSocket connections self._connections: Dict[str, Set[WebSocket]] = {} @@ -260,14 +262,12 @@ async def _get_or_create_agent( # Get or create session session = self._session_manager.get_or_create(session_id) - # Create agent config - use session config, then params, then defaults - import os - effective_model = ( - model - or getattr(session, "model", None) - or self._default_model - or os.getenv("LLM_MODEL", "gpt-4o") + # Use centralized config for model resolution + effective_model = get_default_model( + explicit=model, + session_model=getattr(session, "model", None), ) + logger.info(f"Agent model resolved to: {effective_model}") # Convert hitl_mode string to HITLMode enum # Priority: parameter > session > default diff --git a/src/dsagent/server/routes/sessions.py b/src/dsagent/server/routes/sessions.py index d4bf299..61f8053 100644 --- a/src/dsagent/server/routes/sessions.py +++ b/src/dsagent/server/routes/sessions.py @@ -1,5 +1,6 @@ """Session management endpoints.""" +import logging from datetime import datetime from pathlib import Path from typing import Optional @@ -7,6 +8,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi.responses import FileResponse, JSONResponse +from dsagent.config import get_default_model from dsagent.server.deps import ( get_connection_manager, get_session_manager, @@ -22,6 +24,8 @@ ) from dsagent.session import Session, SessionManager, SessionStatus +logger = logging.getLogger(__name__) + router = APIRouter(dependencies=[Depends(verify_api_key)]) @@ -80,15 +84,19 @@ async def create_session( # Create session session = session_manager.create_session(name=request.name) + # Resolve model using centralized config (not None) + effective_model = get_default_model(explicit=request.model) + logger.info(f"Creating session {session.id} with model={effective_model}") + # Store agent configuration in session for persistence - session.model = request.model + session.model = effective_model # Store resolved model, not None session.hitl_mode = request.hitl_mode or "none" session_manager.save_session(session) # Create and start agent for this session await connection_manager.get_or_create_agent( session.id, - model=request.model, + model=effective_model, hitl_mode=request.hitl_mode, )