diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83e1588..97f16a1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,6 +96,19 @@ jobs: steps: - uses: actions/checkout@v4 + # The ML stack (torch, transformers, bitsandbytes, …) is large. + # Reclaim ~25 GB by removing tools the Docker build doesn't need. + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + - name: Set up Buildx uses: docker/setup-buildx-action@v3 @@ -104,7 +117,7 @@ jobs: with: context: . push: false - load: true + load: false tags: tuneos:ci cache-from: type=gha cache-to: type=gha,mode=max diff --git a/README.md b/README.md index ed4a4f8..c8dd8bb 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,14 @@ The application ships in two forms from one codebase: --- +## What's New + +- **7-step fine-tuning wizard** — a guided end-to-end flow from model selection through dataset, technique (LoRA/QLoRA), hyperparameters, live training, and deployment. Opens as a first-class workspace tab. +- **Experiment tracking** — every training run is recorded in a local SQLite database (`storage/experiments.db`). Run history, hyperparameters, loss curves, and final metrics persist across restarts and are browsable in the Experiments view. +- **Deploy tab** — after training completes, step 7 provides one-click actions: download the adapter weights, push to Hugging Face Hub, export to GGUF for local inference engines, push to a GitHub repository, and test the model in a built-in chat interface. + +--- + ## Capabilities | Domain | Description | @@ -34,6 +42,8 @@ The application ships in two forms from one codebase: | Dataset preparation | Generate, format, and validate instruction and chat datasets prior to training. 
| | Model conversion | Convert weights between Hugging Face, SafeTensors, and GGUF formats for downstream inference engines. | | Training analysis | Track loss curves, evaluation metrics, and run history in real time. | +| Experiment tracking | Persist every fine-tuning run (hyperparameters, loss history, metrics) in a local SQLite database, with comparison and filtering across runs. | +| Model deployment | Download adapter weights, push to Hugging Face Hub or GitHub, export to GGUF, and test the fine-tuned model via a built-in inference chat. | | Model inspection | Explore architecture, tokenization behavior, and configuration of any supported checkpoint. | --- diff --git a/app/api.py b/app/api.py deleted file mode 100644 index b7dd41e..0000000 --- a/app/api.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -TuneOS — REST API endpoints (FastAPI Router). - -Mounted under ``/api`` by the Reflex application. Provides health -checks, GPU detection, model listing, and CRUD placeholders for -fine-tuning jobs. -""" - -from __future__ import annotations - -import io -import json -import os -import platform -import subprocess -import uuid -import zipfile - -from fastapi import FastAPI, HTTPException -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field - -# ── Router ─────────────────────────────────────────────────────── -app_api = FastAPI(title="TuneOS API") - -# ── Constants ──────────────────────────────────────────────────── -_VERSION = "0.1.0" -REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") -OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs") - -# Process-level inference cache: job_id -> (model, tokenizer) -# Max 1 entry — evicted when a different job_id is requested. 
-_INFER_CACHE: dict = {} - -_SUPPORTED_MODELS: list[dict[str, str]] = [ - { - "name": "Mistral 7B", - "hf_id": "mistralai/Mistral-7B-v0.1", - "notes": "Primary target, well-tested with QLoRA", - }, - { - "name": "Llama 3 8B", - "hf_id": "meta-llama/Meta-Llama-3-8B", - "notes": "Requires HF token", - }, - { - "name": "Phi-3 Mini", - "hf_id": "microsoft/Phi-3-mini-4k-instruct", - "notes": "Fast, runs on smaller GPUs", - }, - { - "name": "Gemma 2B", - "hf_id": "google/gemma-2b", - "notes": "Good for low-VRAM environments", - }, -] - - -# ── Pydantic Schemas ───────────────────────────────────────────── -class HealthResponse(BaseModel): - """Health-check response.""" - - status: str = "ok" - version: str = _VERSION - - -class GpuInfo(BaseModel): - """GPU detection result.""" - - available: bool - backend: str - name: str - detail: str = "" - - -class ModelInfo(BaseModel): - """A supported base model.""" - - name: str - hf_id: str - notes: str = "" - - -class JobConfig(BaseModel): - """Request body for creating a new fine-tuning job.""" - - model_id: str = Field(..., description="Hugging Face model ID") - dataset_path: str = Field(..., description="Path to the uploaded dataset") - lora_rank: int = Field(default=8, ge=1, le=256) - lora_alpha: int = Field(default=16, ge=1) - learning_rate: float = Field(default=2e-4, gt=0) - epochs: int = Field(default=3, ge=1, le=100) - batch_size: int = Field(default=4, ge=1) - - -class JobStatus(BaseModel): - """Response schema for job status.""" - - job_id: str - status: str - progress: float = 0.0 - message: str = "" - - -class JobCreated(BaseModel): - """Response returned when a job is successfully queued.""" - - job_id: str - status: str = "queued" - - -# ── GPU Detection ──────────────────────────────────────────────── -def _detect_gpu() -> GpuInfo: - """Detect the available GPU (NVIDIA via nvidia-smi, Apple MPS via sysctl).""" - # NVIDIA check - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", 
"--format=csv,noheader,nounits"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0 and result.stdout.strip(): - gpu_name = result.stdout.strip().split("\n")[0] - return GpuInfo( - available=True, - backend="cuda", - name=gpu_name, - detail=result.stdout.strip(), - ) - except (FileNotFoundError, subprocess.TimeoutExpired): - pass - - # Apple Silicon MPS check - if platform.system() == "Darwin": - try: - result = subprocess.run( - ["sysctl", "-n", "machdep.cpu.brand_string"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - cpu = result.stdout.strip() - if "Apple" in cpu: - return GpuInfo( - available=True, - backend="mps", - name=cpu, - detail="Apple Metal Performance Shaders", - ) - except (FileNotFoundError, subprocess.TimeoutExpired): - pass - - return GpuInfo(available=False, backend="cpu", name="CPU only") - - -# ── Endpoints ──────────────────────────────────────────────────── -@app_api.get("/health", response_model=HealthResponse) -async def health() -> HealthResponse: - """Basic liveness / readiness check.""" - return HealthResponse() - - -@app_api.get("/gpu", response_model=GpuInfo) -async def gpu_info() -> GpuInfo: - """Detect and return GPU information.""" - return _detect_gpu() - - -@app_api.get("/models", response_model=list[ModelInfo]) -async def list_models() -> list[ModelInfo]: - """Return the list of supported base models.""" - return [ModelInfo(**m) for m in _SUPPORTED_MODELS] - - -# ── Celery + Redis wiring ──────────────────────────────────────── -def _get_celery(): - """Import Celery app lazily so the API starts even if Redis is down.""" - from workers.celery_app import celery_app - - return celery_app - - -def _get_job_status_from_redis(job_id: str) -> dict: - """Read job status from Redis. 
Returns dict with at least 'status' key.""" - try: - from workers.status import get_job_status - - return get_job_status(job_id) - except Exception: - return {"status": "unknown", "job_id": job_id} - - -# ── Job CRUD ───────────────────────────────────────────────────── -@app_api.get("/jobs", response_model=list[JobStatus]) -async def list_jobs() -> list[JobStatus]: - """List all fine-tuning jobs (placeholder — returns empty list).""" - return [] - - -@app_api.post("/jobs", response_model=JobCreated, status_code=201) -async def create_job(config: JobConfig) -> JobCreated: - """Create and enqueue a new fine-tuning job.""" - job_id = str(uuid.uuid4()) - - model_cfg = { - "model_name": config.model_id, - "use_4bit": True, - "use_8bit": False, - "trust_remote_code": False, - } - lora_cfg = { - "r": config.lora_rank, - "lora_alpha": config.lora_alpha, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj"], - } - train_cfg = { - "output_dir": f"./outputs/{job_id}", - "num_train_epochs": config.epochs, - "per_device_train_batch_size": config.batch_size, - "learning_rate": config.learning_rate, - } - - try: - from workers.train_task import run_finetune - - run_finetune.apply_async( - args=[job_id, model_cfg, lora_cfg, train_cfg, config.dataset_path], - task_id=job_id, - ) - except Exception as exc: - raise HTTPException(status_code=503, detail=f"Could not enqueue job: {exc}") from exc - - return JobCreated(job_id=job_id) - - -@app_api.get("/jobs/{job_id}", response_model=JobStatus) -async def get_job(job_id: str) -> JobStatus: - """Get real-time status of a job from Redis.""" - state = _get_job_status_from_redis(job_id) - return JobStatus( - job_id=job_id, - status=state.get("status", "unknown"), - progress=state.get("progress", 0.0), - message=state.get("error", ""), - ) - - -@app_api.delete("/jobs/{job_id}", response_model=JobStatus) -async def cancel_job(job_id: str) -> JobStatus: - """Revoke a queued or running Celery task.""" - try: - celery_app = 
_get_celery() - celery_app.control.revoke(job_id, terminate=True, signal="SIGTERM") - except Exception: - pass # Best-effort cancellation - return JobStatus(job_id=job_id, status="cancelled") - - -# ── Post-training endpoints ─────────────────────────────────────── - -@app_api.get("/jobs/{job_id}/download") -async def download_adapter(job_id: str): - """Zip and stream the adapter weights directory.""" - adapter_dir = os.path.join(OUTPUT_DIR, job_id) - if not os.path.isdir(adapter_dir): - raise HTTPException(status_code=404, detail="Adapter not found") - - buf = io.BytesIO() - with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: - for fname in os.listdir(adapter_dir): - full = os.path.join(adapter_dir, fname) - if os.path.isfile(full): - zf.write(full, fname) - buf.seek(0) - - return StreamingResponse( - buf, - media_type="application/zip", - headers={"Content-Disposition": f"attachment; filename=adapter_{job_id[:8]}.zip"}, - ) - - -class PushHubRequest(BaseModel): - repo_name: str - hf_token: str = "" - - -@app_api.post("/jobs/{job_id}/push_hub") -async def push_to_hub(job_id: str, req: PushHubRequest): - """Push the adapter to Hugging Face Hub.""" - token = req.hf_token or os.getenv("HF_TOKEN", "") - if not token: - raise HTTPException(status_code=400, detail="HF token required") - - adapter_dir = os.path.join(OUTPUT_DIR, job_id) - if not os.path.isdir(adapter_dir): - raise HTTPException(status_code=404, detail="Adapter not found") - - try: - from huggingface_hub import HfApi - - api = HfApi(token=token) - api.create_repo(repo_id=req.repo_name, repo_type="model", exist_ok=True, private=True) - api.upload_folder(folder_path=adapter_dir, repo_id=req.repo_name, repo_type="model") - except Exception as exc: - raise HTTPException(status_code=500, detail=str(exc)) from exc - - return {"status": "pushed", "repo_url": f"https://huggingface.co/{req.repo_name}"} - - -@app_api.get("/jobs/{job_id}/eval") -async def get_eval(job_id: str): - """Read evaluation metrics 
written by the training worker.""" - try: - import redis as _redis - - r = _redis.from_url(REDIS_URL) - raw = r.get(f"job:{job_id}:eval") - if not raw: - return {"status": "not_ready", "perplexity": None, "bleu": None} - data = json.loads(raw) - return {"status": "done", **data} - except Exception as exc: - raise HTTPException(status_code=500, detail=str(exc)) from exc - - -class InferRequest(BaseModel): - prompt: str - max_new_tokens: int = 200 - temperature: float = 0.7 - - -@app_api.post("/jobs/{job_id}/infer") -async def infer(job_id: str, req: InferRequest): - """Run inference using the fine-tuned adapter (loaded lazily, cached in-process).""" - import torch - - state = _get_job_status_from_redis(job_id) - if state.get("status") != "done": - raise HTTPException(status_code=400, detail="Job not complete") - - adapter_dir = state.get("output_path") or os.path.join(OUTPUT_DIR, job_id) - if not os.path.isdir(adapter_dir): - raise HTTPException(status_code=404, detail="Adapter directory not found") - - global _INFER_CACHE - - if job_id not in _INFER_CACHE: - # Evict any previously cached model to free VRAM - _INFER_CACHE.clear() - - config_path = os.path.join(adapter_dir, "adapter_config.json") - if not os.path.exists(config_path): - raise HTTPException(status_code=404, detail="adapter_config.json not found") - - with open(config_path) as f: - adapter_cfg = json.load(f) - base_model_name = adapter_cfg.get("base_model_name_or_path", "") - if not base_model_name: - raise HTTPException(status_code=500, detail="base_model_name_or_path missing in adapter_config.json") - - try: - from peft import PeftModel - from transformers import AutoModelForCausalLM, AutoTokenizer - - base_model = AutoModelForCausalLM.from_pretrained( - base_model_name, - device_map="auto", - torch_dtype=torch.float16, - ) - tokenizer = AutoTokenizer.from_pretrained(base_model_name) - model = PeftModel.from_pretrained(base_model, adapter_dir) - model.eval() - _INFER_CACHE[job_id] = (model, tokenizer) 
- except Exception as exc: - raise HTTPException(status_code=500, detail=f"Failed to load model: {exc}") from exc - - model, tokenizer = _INFER_CACHE[job_id] - - try: - inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device) - with torch.no_grad(): - out = model.generate( - **inputs, - max_new_tokens=req.max_new_tokens, - temperature=req.temperature, - do_sample=True, - pad_token_id=tokenizer.eos_token_id, - ) - response = tokenizer.decode( - out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True - ) - except Exception as exc: - raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc - - return {"response": response} diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..5cb56ed --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,26 @@ +"""TuneOS REST API package. + +Exports ``app_api`` — a FastAPI instance with all routers mounted. +""" + +from __future__ import annotations + +from fastapi import FastAPI + +from app.api import ( + datasets_routes, + experiments_routes, + jobs_routes, + models_routes, + system, +) + +app_api = FastAPI(title="TuneOS API") + +app_api.include_router(system.router) +app_api.include_router(models_routes.router) +app_api.include_router(datasets_routes.router) +app_api.include_router(jobs_routes.router) +app_api.include_router(experiments_routes.router) + +__all__ = ["app_api"] diff --git a/app/api/datasets_routes.py b/app/api/datasets_routes.py new file mode 100644 index 0000000..81cf057 --- /dev/null +++ b/app/api/datasets_routes.py @@ -0,0 +1,294 @@ +"""Dataset search, preview, and generation API routes.""" + +from __future__ import annotations + +import json +import os +import re +import uuid + +from fastapi import APIRouter, HTTPException, Query + +from app.api.deps import DATASET_DIR +from app.api.schemas import DatasetGenRequest + +router = APIRouter() + + +@router.get("/datasets/search") +async def search_datasets(q: str = Query(default="", 
@router.get("/datasets/{dataset_id:path}/preview")
async def preview_dataset(dataset_id: str):
    """Return the column names and leading rows of a Hugging Face Hub dataset.

    Loads the ``train[:5]`` slice of ``dataset_id`` (remote code execution
    disabled) on the default executor so the blocking load does not run on
    the event loop.

    Returns:
        ``{"columns": [...], "rows": [{column: value, ...}, ...]}``.

    Raises:
        HTTPException: 500 carrying the underlying error text when the load
            fails for any reason.
    """
    import asyncio

    def _fetch_head():
        from datasets import load_dataset

        head = load_dataset(dataset_id, split="train[:5]", trust_remote_code=False)
        names = head.column_names
        records = []
        for idx in range(len(head)):
            # One plain dict per row, keyed by column name in column order.
            records.append({name: head[name][idx] for name in names})
        return head.column_names, records

    try:
        loop = asyncio.get_event_loop()
        columns, rows = await loop.run_in_executor(None, _fetch_head)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"columns": columns, "rows": rows}
def _parse_generated_samples(text: str) -> list[dict]:
    """Extract instruction/output samples from the first usable JSON array in *text*.

    Every ``[`` in *text* is tried as the start of a JSON document, so prose
    before or after the array, and ``]`` characters inside string values, do
    not break parsing. Only dict items whose ``instruction`` and ``output``
    are both strings are kept. Returns ``[]`` when no array yields a sample.
    """
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            payload, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, list):
            samples = [
                item
                for item in payload
                if isinstance(item, dict)
                and isinstance(item.get("instruction"), str)
                and isinstance(item.get("output"), str)
            ]
            if samples:
                return samples
        start = text.find("[", start + 1)
    return []


def _self_instruct_generate(intent: str, n: int, seeds: list[dict], hf_token: str) -> list[dict]:
    """Ask a hosted instruct model for *n* synthetic instruction/output pairs.

    Args:
        intent: Plain-English description of the fine-tuning use case.
        n: Number of examples to request; ``n <= 0`` returns ``[]`` without
            contacting the inference service.
        seeds: Example pairs (dicts with ``instruction`` and ``output``) shown
            to the model. Entries missing either key are ignored; when none
            are usable the built-in defaults for *intent* are used. At most
            five are included in the prompt.
        hf_token: Hugging Face token passed to ``InferenceClient``.

    Returns:
        The parsed samples, or ``[]`` when the reply holds no usable JSON
        array (the caller then falls back to template generation).
    """
    if n <= 0:
        return []

    from huggingface_hub import InferenceClient

    usable_seeds = [
        s for s in (seeds or []) if isinstance(s, dict) and "instruction" in s and "output" in s
    ]
    if not usable_seeds:
        usable_seeds = _default_seeds(intent)

    client = InferenceClient(token=hf_token)
    seed_str = "\n".join(
        f"- Instruction: {s['instruction']}\n Output: {s['output']}" for s in usable_seeds[:5]
    )
    prompt = (
        f"You are a dataset creator. The user wants to fine-tune a language model for: {intent}\n\n"
        f"Here are some example instruction/output pairs:\n{seed_str}\n\n"
        f"Generate {n} more diverse and high-quality examples in this JSON format:\n"
        f'[{{"instruction": "...", "output": "..."}}, ...]\n\n'
        f"Return ONLY the JSON array, no other text."
    )
    text = client.text_generation(
        prompt, model="mistralai/Mistral-7B-Instruct-v0.2", max_new_tokens=min(4096, n * 60)
    )
    # A non-greedy regex stopped at the first "]" (e.g. inside "Lists use []"),
    # so well-formed replies failed to parse; scan for a decodable array instead.
    return _parse_generated_samples(text)
def _template_generate(intent: str, n: int) -> list[dict]:
    """Build *n* offline instruction/output samples from canned templates.

    The template set and topic pool are chosen by keyword match on *intent*
    (health/medical, code/programming, otherwise generic). In the generic
    case the topics are the words of *intent* longer than three characters;
    when there are none (e.g. ``"AI"`` or an empty string) the whole intent,
    or ``"this topic"`` if it is blank, is used so sampling never runs on an
    empty pool.

    Args:
        intent: Plain-English description of the use case.
        n: Number of samples to produce; ``n <= 0`` yields ``[]``.

    Returns:
        A list of ``{"instruction": ..., "output": ...}`` dicts. Duplicates
        are possible because template and topic are drawn at random.
    """
    import random

    intent_lower = intent.lower()
    if any(k in intent_lower for k in ["health", "medical", "doctor", "diabetes", "nutrition"]):
        templates = [
            ("What is {topic}?", "It is a medical condition/concept related to {intent_short}."),
            (
                "How can I manage {topic}?",
                "Managing {topic} involves lifestyle changes, medication, and regular monitoring.",
            ),
            (
                "What foods should I avoid with {topic}?",
                "With {topic}, it is best to limit processed foods, refined sugars, and high-sodium items.",
            ),
            (
                "When should I see a doctor about {topic}?",
                "Seek medical advice if you experience persistent or worsening symptoms related to {topic}.",
            ),
            (
                "What are the early signs of {topic}?",
                "Early signs may include fatigue, discomfort, and changes in normal bodily functions.",
            ),
        ]
        topics = [
            "diabetes",
            "hypertension",
            "heart disease",
            "obesity",
            "cholesterol",
            "inflammation",
            "nutrition",
            "exercise recovery",
        ]
    elif any(k in intent_lower for k in ["code", "programming", "python", "developer", "software"]):
        templates = [
            (
                "How do I {topic} in Python?",
                "Here is a simple example:\n```python\n# {topic} example\nresult = None # implement here\n```",
            ),
            (
                "What is the difference between {topic} and its alternative?",
                "{topic} is commonly used for one scenario while its alternative suits another use case.",
            ),
            (
                "Debug this Python error: {topic}",
                "This error typically occurs when the variable is undefined or out of scope. Check your variable declarations.",
            ),
            (
                "Explain {topic} with an example.",
                "{topic} is a programming concept. Here is a simple example to illustrate it.",
            ),
            (
                "Write a function that {topic}.",
                "```python\ndef solution():\n # {topic}\n pass\n```",
            ),
        ]
        topics = [
            "sorts a list",
            "reads a file",
            "handles exceptions",
            "makes an API call",
            "parses JSON",
            "validates input",
            "formats strings",
            "uses decorators",
        ]
    else:
        templates = [
            ("What is {topic}?", "{topic} is an important aspect of {intent_short}."),
            (
                "How does {topic} work?",
                "{topic} works by following a structured process aligned with best practices.",
            ),
            (
                "What are the benefits of {topic}?",
                "The main benefits include efficiency, clarity, and improved outcomes.",
            ),
            (
                "Can you explain {topic} in simple terms?",
                "Simply put, {topic} is about achieving a specific goal in a structured way.",
            ),
            (
                "What should I know about {topic}?",
                "Key things to know: it requires preparation, practice, and continuous learning.",
            ),
        ]
        words = [w for w in intent.split() if len(w) > 3]
        # Short or empty intents produce no candidate words; random.choice([])
        # would raise IndexError, so fall back to the intent itself.
        # (The word list used to be repeated to pad it out; a uniform choice
        # over the repeated list is the same distribution, so it is not.)
        topics = words or [intent.strip() or "this topic"]

    intent_short = intent[:30]
    samples = []
    for _ in range(n):
        topic = random.choice(topics)
        tmpl_inst, tmpl_out = random.choice(templates)
        samples.append(
            {
                "instruction": tmpl_inst.format(topic=topic, intent_short=intent_short),
                "output": tmpl_out.format(topic=topic, intent_short=intent_short),
            }
        )

    return samples
def _get_job_status_from_redis(job_id: str) -> dict:
    """Fetch the status record for *job_id*, degrading to ``"unknown"`` on failure.

    The lookup is delegated to ``workers.status.get_job_status``, imported
    lazily inside the function. Any exception, whether from the import or the
    lookup itself, is swallowed on purpose so callers always receive a dict.

    Returns:
        Whatever ``get_job_status`` returns, or
        ``{"status": "unknown", "job_id": job_id}`` when it cannot be obtained.
    """
    status_record = {"status": "unknown", "job_id": job_id}
    try:
        import workers.status as status_backend

        status_record = status_backend.get_job_status(job_id)
    except Exception:
        # Best-effort by design: keep the "unknown" placeholder.
        pass
    return status_record
@router.get("/experiments")
async def list_experiments():
    """List every recorded run, newest first.

    Returns:
        ``{"runs": [...]}`` with one dict per row of the ``runs`` table,
        ordered by ``started_at`` descending. A missing database file is
        treated as "no runs yet" and yields an empty list.

    Raises:
        HTTPException: 500 with the underlying error text if the query fails.
    """
    db_path = _get_db_path()
    if not os.path.exists(db_path):
        return {"runs": []}
    try:
        conn = sqlite3.connect(db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute("SELECT * FROM runs ORDER BY started_at DESC").fetchall()
        finally:
            # Close even when the query raises, so the handle is not leaked.
            conn.close()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"runs": [dict(r) for r in rows]}


@router.delete("/experiments/{experiment_id}")
async def delete_experiment(experiment_id: str):
    """Delete the run whose ``id`` equals *experiment_id*.

    The operation is idempotent: deleting an id that is not present still
    reports ``{"status": "deleted"}``. A missing database file is handled the
    same way, without calling ``sqlite3.connect``, which would otherwise
    create an empty database as a side effect and then fail on the missing
    table.

    Raises:
        HTTPException: 500 with the underlying error text if the delete fails.
    """
    db_path = _get_db_path()
    if not os.path.exists(db_path):
        return {"status": "deleted"}
    try:
        conn = sqlite3.connect(db_path)
        try:
            conn.execute("DELETE FROM runs WHERE id = ?", (experiment_id,))
            conn.commit()
        finally:
            # Close even when the statement raises, so the handle is not leaked.
            conn.close()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"status": "deleted"}
def _resolve_job_dir(job_id: str) -> str:
    """Resolve *job_id* to an absolute directory strictly inside ``OUTPUT_DIR``.

    Args:
        job_id: Job identifier taken from the request path.

    Returns:
        The absolute path ``<OUTPUT_DIR>/<job_id>``. The path is not checked
        for existence; callers do that.

    Raises:
        HTTPException: 400 when the resolved path escapes ``OUTPUT_DIR`` or
            is ``OUTPUT_DIR`` itself.
    """
    base = os.path.abspath(OUTPUT_DIR)
    candidate = os.path.abspath(os.path.join(base, job_id))
    try:
        inside_base = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # commonpath raises for paths it cannot compare (e.g. different
        # drives on Windows); such a path is certainly not inside base.
        inside_base = False
    # An id such as "" or "." normalises to the output root; accepting it
    # would let the download/push routes operate on every job's files at once.
    if not inside_base or candidate == base:
        raise HTTPException(status_code=400, detail="Invalid job id")
    return candidate
@router.get("/jobs/{job_id}", response_model=JobStatus)
async def get_job(job_id: str):
    """Report the current state of a job.

    Reads the status record via ``_get_job_status_from_redis`` and maps it
    onto ``JobStatus``. Missing keys fall back to ``"unknown"`` status, zero
    progress, and empty strings. The record's ``error`` text populates both
    the ``message`` and ``error`` fields.
    """
    record = _get_job_status_from_redis(job_id)
    failure_text = record.get("error", "")
    fields = {
        "job_id": job_id,
        "status": record.get("status", "unknown"),
        "progress": record.get("progress", 0.0),
        "message": failure_text,
        "output_path": record.get("output_path", ""),
        "error": failure_text,
    }
    return JobStatus(**fields)
@router.post("/jobs/{job_id}/push_hub")
async def push_to_hub(job_id: str, req: PushHubRequest):
    """Upload a job's adapter directory to a private Hugging Face Hub model repo.

    The token comes from the request body, falling back to the ``HF_TOKEN``
    environment variable. The repository is created if it does not exist.
    Repo creation and upload are blocking network calls, so they run in a
    worker thread and other requests keep being served.

    Returns:
        ``{"status": "pushed", "repo_url": "https://huggingface.co/<repo_name>"}``.

    Raises:
        HTTPException: 400 when no token is available (or the job id is
            invalid), 404 when the adapter directory does not exist, 500 with
            the underlying error text when the Hub calls fail.
    """
    import asyncio

    token = req.hf_token or os.getenv("HF_TOKEN", "")
    if not token:
        raise HTTPException(status_code=400, detail="HF token required")

    adapter_dir = _resolve_job_dir(job_id)
    if not os.path.isdir(adapter_dir):
        raise HTTPException(status_code=404, detail="Adapter not found")

    def _upload() -> None:
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id=req.repo_name, repo_type="model", exist_ok=True, private=True)
        api.upload_folder(folder_path=adapter_dir, repo_id=req.repo_name, repo_type="model")

    try:
        # Keep the event loop free while the upload is in flight.
        await asyncio.to_thread(_upload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"status": "pushed", "repo_url": f"https://huggingface.co/{req.repo_name}"}
{"status": "merging", "job_id": job_id}


@router.get("/jobs/{job_id}/download-merged")
async def download_merged(job_id: str):
    """Return the job's ``merged`` model directory as an in-memory ZIP download.

    Raises:
        HTTPException: 404 when the ``merged`` sub-directory does not exist.
    """
    merged_dir = os.path.join(_resolve_job_dir(job_id), "merged")
    if not os.path.isdir(merged_dir):
        raise HTTPException(status_code=404, detail="Merged model not found — run merge first")

    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
        for dirpath, _dirnames, filenames in os.walk(merged_dir):
            for filename in filenames:
                source_path = os.path.join(dirpath, filename)
                # Entries are stored relative to the merged directory.
                bundle.write(source_path, os.path.relpath(source_path, merged_dir))
    archive.seek(0)

    disposition = f"attachment; filename=merged_{job_id[:8]}.zip"
    return StreamingResponse(
        archive,
        media_type="application/zip",
        headers={"Content-Disposition": disposition},
    )


@router.post("/jobs/{job_id}/export-gguf", status_code=202)
async def export_gguf(job_id: str, req: GgufRequest):
    """Queue a GGUF export of the already-merged model.

    Raises:
        HTTPException: 400 when no ``merged`` directory exists yet, 503 when
            the export task cannot be enqueued.
    """
    merged_dir = os.path.join(_resolve_job_dir(job_id), "merged")
    if not os.path.isdir(merged_dir):
        raise HTTPException(status_code=400, detail="Merge the model first before exporting GGUF")

    task_kwargs = {
        "job_id": job_id,
        "merged_model_path": merged_dir,
        "quant_type": req.quant_type,
    }
    try:
        from workers.merge_task import export_gguf_task

        export_gguf_task.apply_async(kwargs=task_kwargs, task_id=f"{job_id}-gguf")
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    return {"status": "exporting", "quant_type": req.quant_type}


@router.post("/jobs/{job_id}/push-github")
async def push_github(job_id: str, req: GitHubPushRequest):
    adapter_dir = _resolve_job_dir(job_id)
    if not os.path.isdir(adapter_dir):
        raise HTTPException(status_code=404, detail="Adapter not found")

    try:
        from workers.merge_task import push_github_task

        push_github_task.apply_async(
            kwargs={
                "job_id": job_id,
                "adapter_path": adapter_dir,
                "repo_url": req.repo_url,
                "github_token": req.github_token,
            },
            task_id=f"{job_id}-github",
        )
    except Exception as exc:
        raise
HTTPException(status_code=503, detail=str(exc)) from exc + + return {"status": "pushing"} + + +@router.post("/jobs/{job_id}/commentary") +async def get_commentary(job_id: str, req: CommentaryRequest): + """Return a plain-English sentence describing training progress.""" + epoch_frac = req.epoch / max(req.total_epochs, 1) + drop = req.loss_drop_pct + + quality = ( + "great" if drop > 40 else ("healthy" if drop > 20 else ("slow" if drop > 5 else "stalled")) + ) + phase = "early" if epoch_frac < 0.33 else ("middle" if epoch_frac < 0.67 else "final") + intent_frag = f" for your {req.intent}" if req.intent else "" + loss_verb = { + "great": "dropped significantly", + "healthy": "is decreasing steadily", + "slow": "is decreasing slowly", + "stalled": "has barely moved", + }[quality] + + templates = { + ( + "great", + "early", + ): f"Strong start! Loss {loss_verb} — the model is picking up patterns{intent_frag} quickly.", + ( + "great", + "middle", + ): f"Training is going well. Loss {loss_verb} and the model is solidifying its skills{intent_frag}.", + ( + "great", + "final", + ): f"Excellent run! Loss {loss_verb}. Your model looks ready{intent_frag}.", + ( + "healthy", + "early", + ): f"Good progress. Loss {loss_verb} — on track for a solid result{intent_frag}.", + ("healthy", "middle"): f"Training looks healthy. Loss {loss_verb}. Keep it running.", + ("healthy", "final"): f"Looking good in the final stretch. Loss {loss_verb}.", + ("slow", "early"): f"Loss {loss_verb} — a slow start is normal. Give it a few more epochs.", + ("slow", "middle"): f"Loss {loss_verb}. Consider a higher learning rate if this continues.", + ( + "slow", + "final", + ): f"Loss {loss_verb}. The model may need more data or more epochs next time.", + ( + "stalled", + "early", + ): f"Loss {loss_verb} yet. Try a higher learning rate or check your dataset.", + ( + "stalled", + "middle", + ): f"Loss {loss_verb}. 
Training may be stuck — check the learning rate.", + ( + "stalled", + "final", + ): f"Loss {loss_verb} much. Try more epochs or a larger learning rate next run.", + } + + commentary = templates.get( + (quality, phase), f"Training in progress. Current loss: {req.current_loss:.4f}." + ) + return {"commentary": commentary} + + +@router.get("/jobs/{job_id}/eval") +async def get_eval(job_id: str): + import redis.asyncio as aioredis + + r = aioredis.from_url(REDIS_URL) + try: + raw = await r.get(f"job:{job_id}:eval") + if not raw: + return {"status": "not_ready", "perplexity": None, "bleu": None} + data = json.loads(raw) + return {"status": "done", **data} + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) from exc + finally: + await r.aclose() + + +@router.post("/jobs/{job_id}/infer") +async def infer(job_id: str, req: InferRequest): + import torch + + state = _get_job_status_from_redis(job_id) + if state.get("status") != "done": + raise HTTPException(status_code=400, detail="Job not complete") + + adapter_dir = state.get("output_path") or _resolve_job_dir(job_id) + if not os.path.isdir(adapter_dir): + raise HTTPException(status_code=404, detail="Adapter directory not found") + + import threading + + global _INFER_CACHE, _INFER_CACHE_LOCK + if job_id not in _INFER_CACHE: + lock = _INFER_CACHE_LOCK.setdefault(job_id, threading.Lock()) + with lock: + # Double-checked: another thread may have loaded it while we waited + if job_id not in _INFER_CACHE: + config_path = os.path.join(adapter_dir, "adapter_config.json") + if not os.path.exists(config_path): + raise HTTPException(status_code=404, detail="adapter_config.json not found") + with open(config_path) as f: + adapter_cfg = json.load(f) + base_model_name = adapter_cfg.get("base_model_name_or_path", "") + if not base_model_name: + raise HTTPException(status_code=500, detail="base_model_name_or_path missing") + + try: + from peft import PeftModel + from transformers import 
AutoModelForCausalLM, AutoTokenizer + + base_model = AutoModelForCausalLM.from_pretrained( + base_model_name, device_map="auto", torch_dtype=torch.float16 + ) + tokenizer = AutoTokenizer.from_pretrained(base_model_name) + model = PeftModel.from_pretrained(base_model, adapter_dir) + model.eval() + _INFER_CACHE[job_id] = (model, tokenizer) + except Exception as exc: + raise HTTPException( + status_code=500, detail=f"Failed to load model: {exc}" + ) from exc + + model, tokenizer = _INFER_CACHE[job_id] + try: + inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + out = model.generate( + **inputs, + max_new_tokens=req.max_new_tokens, + temperature=req.temperature, + do_sample=True, + pad_token_id=tokenizer.eos_token_id, + ) + response = tokenizer.decode( + out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc + + return {"response": response} diff --git a/app/api/models_routes.py b/app/api/models_routes.py new file mode 100644 index 0000000..9da933f --- /dev/null +++ b/app/api/models_routes.py @@ -0,0 +1,40 @@ +"""Model listing and validation API routes.""" + +from __future__ import annotations + +import os + +from fastapi import APIRouter + +from app.api.deps import _SUPPORTED_MODELS +from app.api.schemas import ModelInfo, ModelValidateRequest + +router = APIRouter() + + +@router.get("/models", response_model=list[ModelInfo]) +async def list_models(): + return [ModelInfo(**m) for m in _SUPPORTED_MODELS] + + +@router.post("/models/validate") +async def validate_model(req: ModelValidateRequest): + """Validate that a model ID is loadable (HF Hub or local path).""" + import asyncio + + token = req.hf_token or os.getenv("HF_TOKEN") or None + + def _check(): + from transformers import AutoConfig + + cfg = AutoConfig.from_pretrained(req.model_id, token=token, trust_remote_code=False) + return cfg.model_type, 
getattr(cfg, "num_parameters", lambda: None)() + + try: + model_type, num_params = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, _check), timeout=20.0 + ) + param_str = f"{num_params / 1e9:.1f}B" if num_params else "unknown size" + return {"valid": True, "model_type": model_type, "num_params": param_str, "error": ""} + except Exception as exc: + return {"valid": False, "model_type": "", "num_params": "", "error": str(exc)} diff --git a/app/api/schemas.py b/app/api/schemas.py new file mode 100644 index 0000000..9b22b9b --- /dev/null +++ b/app/api/schemas.py @@ -0,0 +1,112 @@ +"""Pydantic schemas / response models for the TuneOS API.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +_VERSION = "0.2.0" + + +class HealthResponse(BaseModel): + status: str = "ok" + version: str = _VERSION + + +class GpuInfo(BaseModel): + available: bool + backend: str + name: str + detail: str = "" + + +class ModelInfo(BaseModel): + name: str + hf_id: str + notes: str = "" + + +class JobConfig(BaseModel): + model_id: str + model_source: str = "hub" + local_model_path: str = "" + hf_token: str = "" + dataset_path: str = "" + hub_dataset_id: str = "" + hub_dataset_split: str = "train" + instruction_col: str = "instruction" + output_col: str = "output" + technique: str = "qlora" + use_4bit: bool = True + lora_rank: int = Field(default=16, ge=1, le=256) + lora_alpha: int = Field(default=32, ge=1) + lora_dropout: float = Field(default=0.05, ge=0.0, le=0.5) + learning_rate: float = Field(default=2e-4, gt=0) + epochs: int = Field(default=3, ge=1, le=100) + batch_size: int = Field(default=4, ge=1) + max_seq_length: int = Field(default=512, ge=64) + gradient_accumulation_steps: int = Field(default=4, ge=1) + warmup_ratio: float = Field(default=0.03, ge=0.0, le=0.5) + lr_scheduler_type: str = "cosine" + bf16: bool = False + user_intent: str = "" + experiment_name: str = "" + experiment_id: str = "" + + +class JobStatus(BaseModel): + 
job_id: str + status: str + progress: float = 0.0 + message: str = "" + output_path: str = "" + error: str = "" + + +class JobCreated(BaseModel): + job_id: str + status: str = "queued" + + +class ModelValidateRequest(BaseModel): + model_id: str + hf_token: str = "" + + +class DatasetGenRequest(BaseModel): + user_intent: str + method: str = "self_instruct" + n_samples: int = Field(default=50, ge=5, le=500) + seed_examples: list[dict] = [] + hf_token: str = "" + + +class CommentaryRequest(BaseModel): + epoch: float + total_epochs: int + loss_drop_pct: float + current_loss: float + intent: str = "" + + +class PushHubRequest(BaseModel): + repo_name: str + hf_token: str = "" + + +class InferRequest(BaseModel): + prompt: str + max_new_tokens: int = 300 + temperature: float = 0.7 + + +class MergeRequest(BaseModel): + hf_token: str = "" + + +class GgufRequest(BaseModel): + quant_type: str = "Q4_K_M" + + +class GitHubPushRequest(BaseModel): + repo_url: str + github_token: str diff --git a/app/api/system.py b/app/api/system.py new file mode 100644 index 0000000..3462eac --- /dev/null +++ b/app/api/system.py @@ -0,0 +1,20 @@ +"""System-level API routes: /health and /gpu.""" + +from __future__ import annotations + +from fastapi import APIRouter + +from app.api.deps import _detect_gpu +from app.api.schemas import GpuInfo, HealthResponse + +router = APIRouter() + + +@router.get("/health", response_model=HealthResponse) +async def health(): + return HealthResponse() + + +@router.get("/gpu", response_model=GpuInfo) +async def gpu_info(): + return _detect_gpu() diff --git a/app/components/finetune/progress_bar.py b/app/components/finetune/progress_bar.py new file mode 100644 index 0000000..d4adb0a --- /dev/null +++ b/app/components/finetune/progress_bar.py @@ -0,0 +1,78 @@ +"""Fine-tune wizard progress bar and step dots.""" + +from __future__ import annotations + +import reflex as rx + +from app.state.finetune_state import FinetuneState +from app.styles import c + +_STEP_LABELS = 
["Model", "Intent", "Data", "Configure", "Train", "Results", "Deploy"] + + +def _step_dot(index: int) -> rx.Component: + step_num = index + 1 + is_done = FinetuneState.current_step > step_num + is_active = FinetuneState.current_step == step_num + return rx.vstack( + rx.box( + rx.cond( + is_done, + rx.icon("check", size=12, color="white"), + rx.text( + str(step_num), + font_size="0.72rem", + font_weight="600", + color=rx.cond(is_active, "white", c("text_muted")), + ), + ), + width="26px", + height="26px", + border_radius="50%", + background=rx.cond( + is_done, c("success"), rx.cond(is_active, c("accent"), c("bg_input")) + ), + border="2px solid", + border_color=rx.cond(is_active | is_done, c("accent"), c("border")), + display="flex", + align_items="center", + justify_content="center", + ), + rx.text( + _STEP_LABELS[index], + font_size="0.68rem", + color=rx.cond(is_active, c("text_primary"), c("text_muted")), + font_weight=rx.cond(is_active, "500", "400"), + ), + spacing="1", + align="center", + ) + + +def _progress_bar() -> rx.Component: + return rx.hstack( + *[ + rx.hstack( + _step_dot(i), + rx.box( + height="2px", + flex="1", + background=rx.cond( + FinetuneState.current_step > i + 1, c("accent"), c("border") + ), + min_width="20px", + ) + if i < len(_STEP_LABELS) - 1 + else rx.fragment(), + spacing="0", + align="center", + flex="1" if i < len(_STEP_LABELS) - 1 else "0", + ) + for i in range(len(_STEP_LABELS)) + ], + width="100%", + max_width="680px", + align="center", + justify="center", + margin_bottom="32px", + ) diff --git a/app/components/finetune/shared.py b/app/components/finetune/shared.py new file mode 100644 index 0000000..3b16261 --- /dev/null +++ b/app/components/finetune/shared.py @@ -0,0 +1,120 @@ +"""Shared UI helpers used across all fine-tune wizard steps.""" + +from __future__ import annotations + +import reflex as rx + +from app.state.finetune_state import FinetuneState +from app.styles import c + + +def _card(*children, padding: str = "20px", 
width: str = "100%", **props) -> rx.Component: + return rx.box( + *children, + background=c("bg_card"), + border="1px solid", + border_color=c("border"), + border_radius="12px", + padding=padding, + width=width, + **props, + ) + + +def _label(text: str) -> rx.Component: + return rx.text( + text, font_size="0.8rem", font_weight="500", color=c("text_secondary"), margin_bottom="6px" + ) + + +def _section_heading(text: str) -> rx.Component: + return rx.text( + text, font_size="1.05rem", font_weight="600", color=c("text_primary"), margin_bottom="16px" + ) + + +def _nav_buttons( + back_label: str = "← Back", + next_label: str = "Next →", + next_disabled: bool = False, + next_event=None, + show_back: bool = True, +) -> rx.Component: + return rx.hstack( + rx.button( + back_label, + on_click=FinetuneState.prev_step, + variant="soft", + color_scheme="gray", + size="2", + ) + if show_back + else rx.fragment(), + rx.spacer(), + rx.button( + next_label, + on_click=next_event or FinetuneState.next_step, + disabled=next_disabled, + size="3", + color_scheme="blue", + ), + width="100%", + padding_top="16px", + ) + + +def _badge_status(status: str) -> rx.Component: + color = rx.match( + status, + ("running", "blue"), + ("done", "green"), + ("failed", "red"), + "gray", + ) + return rx.badge(status.upper(), color_scheme=color, size="2") + + +def _preview_table(rows: list, label: str = "Preview") -> rx.Component: + return rx.vstack( + rx.text(label, font_size="0.78rem", font_weight="500", color=c("text_muted")), + rx.table.root( + rx.table.header( + rx.table.row( + rx.table.column_header_cell("Instruction"), + rx.table.column_header_cell("Output"), + ) + ), + rx.table.body( + rx.foreach( + rows, + lambda row: rx.table.row( + rx.table.cell( + rx.text( + row["instruction"], + font_size="0.78rem", + overflow="hidden", + text_overflow="ellipsis", + white_space="nowrap", + max_width="300px", + ) + ), + rx.table.cell( + rx.text( + row["output"], + font_size="0.78rem", + overflow="hidden", + 
text_overflow="ellipsis", + white_space="nowrap", + max_width="260px", + ) + ), + ), + ) + ), + width="100%", + variant="surface", + size="1", + ), + width="100%", + spacing="2", + ) diff --git a/app/components/finetune/step1_model.py b/app/components/finetune/step1_model.py new file mode 100644 index 0000000..b39e783 --- /dev/null +++ b/app/components/finetune/step1_model.py @@ -0,0 +1,330 @@ +"""Fine-tune wizard — Step 1: Model selection and training technique.""" + +from __future__ import annotations + +import reflex as rx + +from app.components.finetune.shared import _card, _label, _nav_buttons, _section_heading +from app.state.finetune_state import FinetuneState +from app.styles import c + +_MODELS = [ + { + "id": "mistralai/Mistral-7B-v0.1", + "name": "Mistral 7B", + "size": "7B params", + "notes": "Well-tested with QLoRA, great all-rounder", + "token_required": False, + }, + { + "id": "meta-llama/Meta-Llama-3-8B", + "name": "Llama 3 8B", + "size": "8B params", + "notes": "Strong general-purpose model", + "token_required": True, + }, + { + "id": "microsoft/Phi-3-mini-4k-instruct", + "name": "Phi-3 Mini", + "size": "3.8B params", + "notes": "Fast, runs on smaller GPUs", + "token_required": False, + }, + { + "id": "google/gemma-2b", + "name": "Gemma 2B", + "size": "2B params", + "notes": "Good for low-VRAM environments", + "token_required": False, + }, + { + "id": "EleutherAI/pythia-410m", + "name": "Pythia 410M", + "size": "410M params", + "notes": "Tiny model — great for testing pipelines fast", + "token_required": False, + }, + { + "id": "bigcode/starcoder2-3b", + "name": "StarCoder2 3B", + "size": "3B params", + "notes": "Excellent for code generation tasks", + "token_required": False, + }, +] + +_GGUF_QUANTS = ["Q4_K_M", "Q5_K_M", "Q8_0", "F16"] + + +def _model_card(m: dict) -> rx.Component: + is_selected = FinetuneState.selected_model_id == m["id"] + return rx.box( + rx.vstack( + rx.hstack( + rx.text(m["name"], font_size="0.92rem", font_weight="600", 
color=c("text_primary")), + rx.cond( + m["token_required"], + rx.badge("HF Token", color_scheme="orange", size="1"), + rx.fragment(), + ), + justify="between", + width="100%", + ), + rx.text(m["size"], font_size="0.78rem", color=c("text_secondary")), + rx.text(m["notes"], font_size="0.78rem", color=c("text_muted")), + spacing="1", + align_items="flex-start", + width="100%", + ), + background=rx.cond(is_selected, c("accent_soft"), c("bg_card")), + border="2px solid", + border_color=rx.cond(is_selected, c("accent"), c("border")), + border_radius="10px", + padding="14px", + cursor="pointer", + width="100%", + on_click=FinetuneState.select_model(m["id"], m["name"]), + _hover={"border_color": c("accent"), "background": c("accent_soft")}, + ) + + +def _source_tab(source: str, label: str, icon: str) -> rx.Component: + is_active = FinetuneState.model_source == source + return rx.button( + rx.hstack(rx.icon(icon, size=14), rx.text(label), spacing="2", align="center"), + on_click=FinetuneState.set_model_source(source), + variant=rx.cond(is_active, "solid", "soft"), + color_scheme="blue", + size="2", + ) + + +def _step1() -> rx.Component: + return rx.vstack( + _section_heading("Choose your model"), + rx.text( + "Pick from common models, paste any Hugging Face ID, load a local file, " + "or type any model string that Transformers accepts.", + font_size="0.86rem", + color=c("text_secondary"), + margin_bottom="16px", + ), + # Source switcher + rx.hstack( + _source_tab("hub", "HF Hub", "globe"), + _source_tab("custom_string", "Any Model ID", "terminal"), + _source_tab("local", "Local File", "folder-open"), + spacing="2", + margin_bottom="20px", + ), + # Hub tab + rx.cond( + FinetuneState.model_source == "hub", + rx.vstack( + rx.grid(*[_model_card(m) for m in _MODELS], columns="2", spacing="3", width="100%"), + # HF token field for gated models + rx.box(height="12px"), + _card( + rx.vstack( + _label("HF Token (required for gated models like Llama)"), + rx.input( + 
placeholder="hf_xxxxxxxxxxxxx", + type="password", + value=FinetuneState.hf_token, + on_change=FinetuneState.set_hf_token, + width="100%", + ), + spacing="1", + ) + ), + width="100%", + spacing="0", + ), + rx.fragment(), + ), + # Custom string tab + rx.cond( + FinetuneState.model_source == "custom_string", + _card( + rx.vstack( + _label( + "Model ID or path (any string AutoModelForCausalLM.from_pretrained() accepts)" + ), + rx.hstack( + rx.input( + placeholder='e.g. "EleutherAI/gpt-j-6b" or "/local/path/to/model"', + value=FinetuneState.custom_model_str, + on_change=FinetuneState.set_custom_model_str, + flex="1", + ), + rx.button( + rx.cond( + FinetuneState.is_validating_model, + rx.hstack( + rx.spinner(size="1"), rx.text("Checking..."), spacing="2" + ), + rx.text("Validate"), + ), + on_click=FinetuneState.validate_and_select_custom_model, + disabled=FinetuneState.is_validating_model, + color_scheme="blue", + size="2", + ), + spacing="2", + ), + rx.cond( + FinetuneState.model_url_error != "", + rx.callout(FinetuneState.model_url_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.selected_model_id != "", + rx.callout( + rx.hstack( + rx.icon("check-circle", size=14), + rx.text(f"Model ready: {FinetuneState.selected_model_id}"), + spacing="2", + ), + color_scheme="green", + size="1", + ), + rx.fragment(), + ), + _label("HF Token (for gated or private models)"), + rx.input( + placeholder="hf_xxxxxxxxxxxxx", + type="password", + value=FinetuneState.hf_token, + on_change=FinetuneState.set_hf_token, + width="100%", + ), + rx.text( + "Note: If you skip validation, any errors will appear when training starts.", + font_size="0.75rem", + color=c("text_muted"), + ), + spacing="2", + ) + ), + rx.fragment(), + ), + # Local file tab + rx.cond( + FinetuneState.model_source == "local", + _card( + rx.vstack( + _label( + "Upload your model (.safetensors, .bin, .gguf, or .zip of model directory)" + ), + rx.upload( + rx.vstack( + rx.icon("upload", 
size=28, color=c("text_muted")), + rx.text("Drag & drop or click to upload", color=c("text_secondary")), + rx.text( + "Supports: .safetensors, .bin, .gguf, .zip", + font_size="0.75rem", + color=c("text_muted"), + ), + spacing="2", + align="center", + ), + id="model_upload", + border=f"2px dashed {c('border')}", + border_radius="10px", + padding="32px", + width="100%", + cursor="pointer", + on_drop=FinetuneState.handle_local_model_upload( + rx.upload_files(upload_id="model_upload") + ), + ), + rx.cond( + FinetuneState.local_model_path != "", + rx.callout( + rx.text(f"Loaded: {FinetuneState.local_model_path}"), + color_scheme="green", + size="1", + ), + rx.fragment(), + ), + spacing="2", + ) + ), + rx.fragment(), + ), + # Technique selector (always visible) + rx.box(height="20px"), + _section_heading("Training technique"), + rx.flex( + *[ + rx.box( + rx.vstack( + rx.hstack( + rx.text( + label, + font_size="0.88rem", + font_weight="500", + color=rx.cond( + FinetuneState.selected_technique == tech, + c("accent"), + c("text_primary"), + ), + ), + rx.cond( + FinetuneState.selected_technique == tech, + rx.icon("check-circle", size=14, color=c("accent")), + rx.fragment(), + ), + rx.cond( + coming_soon, + rx.badge("Soon", color_scheme="gray", size="1"), + rx.fragment(), + ), + spacing="2", + align="center", + ), + rx.text(desc, font_size="0.76rem", color=c("text_muted")), + spacing="1", + align_items="flex-start", + ), + background=rx.cond( + FinetuneState.selected_technique == tech, + c("accent_soft"), + c("bg_input"), + ), + border="1px solid", + border_color=rx.cond( + FinetuneState.selected_technique == tech, + c("accent"), + c("border"), + ), + border_radius="8px", + padding="12px 14px", + cursor=rx.cond(coming_soon, "not-allowed", "pointer"), + opacity=rx.cond(coming_soon, "0.5", "1"), + on_click=rx.cond( + coming_soon, rx.prevent_default, FinetuneState.select_technique(tech) + ), + flex="1", + min_width="140px", + ) + for tech, label, desc, coming_soon in [ + 
("qlora", "QLoRA", "4-bit compressed. Runs on 12 GB+ GPU. Recommended.", False), + ("lora", "LoRA", "Float16. Needs ~16 GB GPU for 7B models.", False), + ("full", "Full Fine-tune", "All weights updated. Needs 80 GB+ GPU.", True), + ("dpo", "DPO", "Preference tuning for alignment.", True), + ] + ], + wrap="wrap", + gap="10px", + width="100%", + ), + _nav_buttons( + next_label="Next: Intent →", + next_disabled=~FinetuneState.can_go_to_intent, + show_back=False, + ), + spacing="0", + width="100%", + align_items="flex-start", + ) diff --git a/app/components/finetune/step2_intent.py b/app/components/finetune/step2_intent.py new file mode 100644 index 0000000..1d40d60 --- /dev/null +++ b/app/components/finetune/step2_intent.py @@ -0,0 +1,70 @@ +"""Fine-tune wizard — Step 2: Intent / use-case description.""" + +from __future__ import annotations + +import reflex as rx + +from app.components.finetune.shared import _card, _label, _nav_buttons, _section_heading +from app.state.finetune_state import FinetuneState +from app.styles import c + +_INTENT_IDEAS = [ + "Health chatbot for diabetes patients", + "Python code review assistant", + "Customer support for SaaS products", + "Legal document summarizer", + "Recipe recommendation assistant", + "Scientific paper Q&A bot", + "SQL query generator", + "Children's education tutor", +] + + +def _step2() -> rx.Component: + return rx.vstack( + _section_heading("What are you building?"), + rx.text( + "Describe your use-case in plain English. TuneOS uses this to generate starter data, " + "guide the training dashboard, and pre-fill the system prompt for testing.", + font_size="0.86rem", + color=c("text_secondary"), + margin_bottom="16px", + ), + _card( + rx.vstack( + _label("Your goal (1–3 sentences)"), + rx.text_area( + placeholder="e.g. 
A health chatbot that answers questions for people with Type 2 diabetes in simple language.", + value=FinetuneState.user_intent, + on_change=FinetuneState.set_user_intent, + rows="4", + width="100%", + resize="vertical", + ), + rx.text("Quick ideas:", font_size="0.76rem", color=c("text_muted")), + rx.flex( + *[ + rx.badge( + idea, + cursor="pointer", + on_click=FinetuneState.set_user_intent(idea), + color_scheme="blue", + variant="soft", + size="1", + ) + for idea in _INTENT_IDEAS + ], + wrap="wrap", + gap="6px", + ), + spacing="3", + ) + ), + _nav_buttons( + next_label="Next: Add Data →", + next_disabled=FinetuneState.user_intent == "", + ), + spacing="0", + width="100%", + align_items="flex-start", + ) diff --git a/app/components/finetune/step3_data.py b/app/components/finetune/step3_data.py new file mode 100644 index 0000000..d00ba91 --- /dev/null +++ b/app/components/finetune/step3_data.py @@ -0,0 +1,282 @@ +"""Fine-tune wizard — Step 3: Training data (upload / Hub / generate).""" + +from __future__ import annotations + +import reflex as rx + +from app.components.finetune.shared import ( + _card, + _label, + _nav_buttons, + _preview_table, + _section_heading, +) +from app.state.finetune_state import FinetuneState +from app.styles import c + + +def _data_mode_btn(mode: str, label: str, icon: str) -> rx.Component: + is_active = FinetuneState.data_source == mode + return rx.button( + rx.hstack(rx.icon(icon, size=14), rx.text(label), spacing="2", align="center"), + on_click=FinetuneState.set_data_source(mode), + variant=rx.cond(is_active, "solid", "soft"), + color_scheme="blue", + size="2", + ) + + +def _upload_panel() -> rx.Component: + return _card( + rx.vstack( + _label("Upload CSV, JSONL, or JSON array — any two columns work, you can remap them"), + rx.upload( + rx.vstack( + rx.icon("upload", size=28, color=c("text_muted")), + rx.text("Drag & drop or click to select a file", color=c("text_secondary")), + rx.text(".csv · .jsonl · .json", font_size="0.75rem", 
color=c("text_muted")), + spacing="2", + align="center", + ), + id="dataset_upload", + border=f"2px dashed {c('border')}", + border_radius="10px", + padding="24px", + width="100%", + cursor="pointer", + on_drop=FinetuneState.handle_dataset_upload( + rx.upload_files(upload_id="dataset_upload") + ), + ), + rx.cond( + FinetuneState.is_uploading, + rx.hstack( + rx.spinner(size="2"), rx.text("Uploading...", font_size="0.84rem"), spacing="2" + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.dataset_error != "", + rx.callout(FinetuneState.dataset_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.dataset_preview.length() > 0, + _preview_table(FinetuneState.dataset_preview, "File preview (first 5 rows)"), + rx.fragment(), + ), + spacing="3", + ) + ) + + +def _hub_dataset_panel() -> rx.Component: + return _card( + rx.vstack( + rx.cond( + FinetuneState.hub_dataset_id != "", + rx.vstack( + rx.hstack( + rx.icon("database", size=16, color=c("accent")), + rx.text( + FinetuneState.hub_dataset_id, font_weight="500", color=c("text_primary") + ), + spacing="2", + align="center", + ), + rx.hstack( + rx.vstack( + _label("Instruction column"), + rx.input( + value=FinetuneState.hub_dataset_instruction_col, + on_change=FinetuneState.set_hub_instruction_col, + size="2", + width="180px", + ), + spacing="1", + ), + rx.vstack( + _label("Output column"), + rx.input( + value=FinetuneState.hub_dataset_output_col, + on_change=FinetuneState.set_hub_output_col, + size="2", + width="180px", + ), + spacing="1", + ), + rx.button( + "Load preview", + size="2", + color_scheme="blue", + variant="soft", + on_click=FinetuneState.load_hub_dataset_preview, + align_self="flex-end", + ), + spacing="4", + wrap="wrap", + ), + rx.cond( + FinetuneState.is_loading_hub_preview, + rx.hstack( + rx.spinner(size="2"), + rx.text("Loading...", font_size="0.84rem"), + spacing="2", + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.hub_preview_error != "", + 
rx.callout(FinetuneState.hub_preview_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.hub_dataset_preview.length() > 0, + _preview_table(FinetuneState.hub_dataset_preview), + rx.fragment(), + ), + spacing="3", + width="100%", + ), + rx.vstack( + rx.text("No dataset selected yet.", color=c("text_muted"), font_size="0.86rem"), + rx.text( + 'Go to the Datasets tab and click "Use in Fine-tune" on any dataset.', + color=c("text_muted"), + font_size="0.82rem", + ), + rx.button( + "Browse Datasets →", + on_click=rx.redirect("/datasets"), + color_scheme="blue", + variant="soft", + size="2", + ), + spacing="3", + ), + ), + spacing="2", + ) + ) + + +def _generate_panel() -> rx.Component: + return _card( + rx.vstack( + rx.hstack( + rx.icon("sparkles", size=16, color=c("accent")), + rx.text( + "Generate synthetic training data", font_weight="500", color=c("text_primary") + ), + spacing="2", + align="center", + ), + rx.text( + "TuneOS will create instruction/output pairs tailored to your stated goal using " + "the Self-Instruct method (the same approach used to create Stanford Alpaca).", + font_size="0.82rem", + color=c("text_secondary"), + ), + rx.cond( + FinetuneState.user_intent != "", + rx.box( + rx.text( + f'Goal: "{FinetuneState.user_intent}"', + font_size="0.82rem", + color=c("text_muted"), + font_style="italic", + ), + background=c("bg_input"), + border_radius="6px", + padding="8px 12px", + ), + rx.fragment(), + ), + rx.hstack( + rx.vstack( + _label("Method"), + rx.select.root( + rx.select.trigger(width="200px"), + rx.select.content( + rx.select.item("Self-Instruct (recommended)", value="self_instruct"), + rx.select.item("Few-Shot Expansion", value="few_shot"), + rx.select.item("Template-Based (offline)", value="template"), + ), + value=FinetuneState.generation_method, + on_change=FinetuneState.set_generation_method, + ), + spacing="1", + ), + rx.vstack( + _label("Number of examples"), + rx.select.root( + 
rx.select.trigger(width="120px"), + rx.select.content( + rx.select.item("50", value="50"), + rx.select.item("100", value="100"), + rx.select.item("250", value="250"), + rx.select.item("500", value="500"), + ), + value=FinetuneState.generation_n.to_string(), + on_change=FinetuneState.set_generation_n, + ), + spacing="1", + ), + spacing="4", + wrap="wrap", + ), + rx.button( + rx.cond( + FinetuneState.is_generating, + rx.hstack(rx.spinner(size="2"), rx.text("Generating..."), spacing="2"), + rx.hstack( + rx.icon("sparkles", size=14), rx.text("Generate examples"), spacing="2" + ), + ), + on_click=FinetuneState.generate_starter_dataset, + disabled=FinetuneState.is_generating, + color_scheme="blue", + size="3", + ), + rx.cond( + FinetuneState.generation_status != "", + rx.text( + FinetuneState.generation_status, font_size="0.82rem", color=c("text_secondary") + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.generated_samples.length() > 0, + _preview_table(FinetuneState.generated_samples, "Generated examples preview"), + rx.fragment(), + ), + spacing="3", + ) + ) + + +def _step3() -> rx.Component: + return rx.vstack( + _section_heading("Add your training data"), + rx.hstack( + _data_mode_btn("upload", "Upload a file", "upload"), + _data_mode_btn("hub_dataset", "HF Hub dataset", "database"), + _data_mode_btn("generate", "Generate with AI", "sparkles"), + spacing="2", + margin_bottom="16px", + ), + rx.match( + FinetuneState.data_source, + ("upload", _upload_panel()), + ("hub_dataset", _hub_dataset_panel()), + ("generate", _generate_panel()), + _upload_panel(), + ), + _nav_buttons( + next_label="Next: Configure →", + next_disabled=~FinetuneState.can_go_to_configure, + ), + spacing="0", + width="100%", + align_items="flex-start", + ) diff --git a/app/components/finetune/step4_configure.py b/app/components/finetune/step4_configure.py new file mode 100644 index 0000000..a3165b7 --- /dev/null +++ b/app/components/finetune/step4_configure.py @@ -0,0 +1,309 @@ +"""Fine-tune 
wizard — Step 4: Hyperparameter configuration.""" + +from __future__ import annotations + +import reflex as rx + +from app.components.finetune.shared import _card, _label, _nav_buttons, _section_heading +from app.state.finetune_state import FinetuneState +from app.styles import c + +_LR_PRESETS = [ + ("1e-4", "Slow & careful"), + ("2e-4", "Balanced (recommended)"), + ("5e-4", "Fast learning"), +] + + +def _step4() -> rx.Component: + return rx.vstack( + rx.hstack( + _section_heading("Training configuration"), + rx.spacer(), + rx.hstack( + rx.text("Simple", font_size="0.82rem", color=c("text_secondary")), + rx.switch( + checked=FinetuneState.ui_mode == "advanced", + on_change=lambda v: FinetuneState.set_ui_mode(rx.cond(v, "advanced", "simple")), + size="2", + ), + rx.text("Advanced", font_size="0.82rem", color=c("text_secondary")), + spacing="2", + align="center", + ), + ), + # Simple mode + _card( + rx.vstack( + rx.grid( + rx.vstack( + _label("Epochs"), + rx.input( + value=FinetuneState.epochs.to_string(), + on_change=FinetuneState.set_epochs, + type="number", + width="100%", + ), + rx.text( + "One full pass through your dataset", + font_size="0.72rem", + color=c("text_muted"), + ), + spacing="1", + ), + rx.vstack( + _label("Learning rate"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(f"{lr} — {desc}", value=lr) + for lr, desc in _LR_PRESETS + ], + ), + value=FinetuneState.learning_rate, + on_change=FinetuneState.set_learning_rate, + ), + spacing="1", + ), + rx.vstack( + _label("Technique"), + rx.text( + FinetuneState.technique_label, + font_size="0.88rem", + font_weight="500", + color=c("accent"), + ), + rx.text("Change in Step 1", font_size="0.72rem", color=c("text_muted")), + spacing="1", + ), + columns="3", + spacing="4", + width="100%", + ), + spacing="0", + ) + ), + # Advanced mode + rx.cond( + FinetuneState.ui_mode == "advanced", + _card( + rx.vstack( + rx.text( + "Advanced hyperparameters", + 
font_size="0.88rem", + font_weight="600", + color=c("text_primary"), + margin_bottom="12px", + ), + rx.grid( + rx.vstack( + _label("LoRA rank (r)"), + rx.slider( + min=4, + max=128, + step=4, + default_value=[FinetuneState.lora_r], + on_value_commit=FinetuneState.set_lora_r, + ), + rx.text( + FinetuneState.lora_r, font_size="0.82rem", color=c("text_secondary") + ), + spacing="1", + ), + rx.vstack( + _label("LoRA alpha"), + rx.input( + value=FinetuneState.lora_alpha.to_string(), + on_change=FinetuneState.set_lora_alpha, + type="number", + width="100%", + ), + spacing="1", + ), + rx.vstack( + _label("LoRA dropout"), + rx.slider( + min=0.0, + max=0.3, + step=0.01, + default_value=[FinetuneState.lora_dropout], + on_value_commit=FinetuneState.set_lora_dropout, + ), + rx.text( + FinetuneState.lora_dropout, + font_size="0.82rem", + color=c("text_secondary"), + ), + spacing="1", + ), + rx.vstack( + _label("Batch size"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [1, 2, 4, 8, 16] + ], + ), + value=FinetuneState.batch_size.to_string(), + on_change=FinetuneState.set_batch_size, + ), + spacing="1", + ), + rx.vstack( + _label("Max sequence length"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [128, 256, 512, 1024, 2048] + ], + ), + value=FinetuneState.max_seq_length.to_string(), + on_change=FinetuneState.set_max_seq_length, + ), + spacing="1", + ), + rx.vstack( + _label("Gradient accumulation steps"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [1, 2, 4, 8, 16] + ], + ), + value=FinetuneState.gradient_accumulation_steps.to_string(), + on_change=FinetuneState.set_gradient_accumulation_steps, + ), + spacing="1", + ), + rx.vstack( + _label("LR scheduler"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( 
+ *[ + rx.select.item(v, value=v) + for v in [ + "cosine", + "linear", + "constant", + "cosine_with_restarts", + ] + ], + ), + value=FinetuneState.lr_scheduler, + on_change=FinetuneState.set_lr_scheduler, + ), + spacing="1", + ), + rx.vstack( + _label("BF16 mode (A100/H100 only)"), + rx.switch( + checked=FinetuneState.bf16, + on_change=FinetuneState.set_bf16, + size="2", + ), + rx.text( + "Better precision than FP16 on Ampere+ GPUs", + font_size="0.72rem", + color=c("text_muted"), + ), + spacing="1", + ), + rx.vstack( + _label("Experiment name"), + rx.input( + placeholder="my-run-1", + value=FinetuneState.experiment_name, + on_change=FinetuneState.set_experiment_name, + width="100%", + ), + spacing="1", + ), + columns="3", + spacing="4", + width="100%", + ), + spacing="0", + ) + ), + rx.fragment(), + ), + # Run summary + _card( + rx.vstack( + rx.text( + "Run summary", + font_size="0.82rem", + font_weight="600", + color=c("text_secondary"), + margin_bottom="8px", + ), + rx.grid( + rx.vstack( + rx.text("Model", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.effective_model_name, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Dataset", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.dataset_name, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Technique", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.technique_label, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Training", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.epochs.to_string() + + " epochs · lr=" + + FinetuneState.learning_rate + + " · batch=" + + FinetuneState.batch_size.to_string(), + font_size="0.82rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + columns="2", 
+ spacing="4", + width="100%", + ), + spacing="0", + ), + background=c("bg_input"), + ), + _nav_buttons( + next_label="Start Training →", + next_disabled=~FinetuneState.can_start_training, + next_event=FinetuneState.start_training, + ), + spacing="4", + width="100%", + align_items="flex-start", + ) diff --git a/app/components/loss_chart.py b/app/components/loss_chart.py index 4d127f0..7b48ec9 100644 --- a/app/components/loss_chart.py +++ b/app/components/loss_chart.py @@ -1,21 +1,43 @@ import reflex as rx -from app.state.job_state import JobState +from app.state.finetune_state import FinetuneState def loss_chart() -> rx.Component: - return rx.recharts.line_chart( + """Dual-series chart: training loss (blue) + learning rate (amber) + eval loss (green dashed).""" + return rx.recharts.composed_chart( rx.recharts.line( data_key="loss", - stroke="#8B5CF6", + stroke="#3b82f6", stroke_width=2, dot=False, + name="Train Loss", + y_axis_id="left", ), - rx.recharts.x_axis(data_key="step", label="Step"), - rx.recharts.y_axis(label="Loss"), - rx.recharts.cartesian_grid(stroke_dasharray="3 3"), + rx.recharts.line( + data_key="eval_loss", + stroke="#22c55e", + stroke_width=2, + stroke_dasharray="5 5", + dot=False, + name="Eval Loss", + y_axis_id="left", + ), + rx.recharts.line( + data_key="learning_rate", + stroke="#f59e0b", + stroke_width=1, + dot=False, + name="Learning Rate", + y_axis_id="right", + ), + rx.recharts.x_axis(data_key="step"), + rx.recharts.y_axis(y_axis_id="left", width=60), + rx.recharts.y_axis(y_axis_id="right", orientation="right", width=70), + rx.recharts.cartesian_grid(stroke_dasharray="3 3", opacity=0.3), + rx.recharts.legend(), rx.recharts.graphing_tooltip(), - data=JobState.loss_history, + data=FinetuneState.loss_history, width="100%", - height=300, + height=280, ) diff --git a/app/pages/datasets.py b/app/pages/datasets.py index 052892d..b2c0d46 100644 --- a/app/pages/datasets.py +++ b/app/pages/datasets.py @@ -1,181 +1,407 @@ -"""TuneOS Datasets 
discovery page.""" +"""TuneOS Datasets discovery page — search HF Hub, preview, use in fine-tune wizard.""" +from __future__ import annotations + +from typing import Any + +import httpx import reflex as rx -from app.state.app_state import AppState +from app.state.finetune_state import FinetuneState from app.styles import c +API_BASE = "http://localhost:8000" + CATEGORIES = ["All", "NLP", "Code", "Math", "Science", "Chat", "Instruction"] -SAMPLE_DATASETS = [ +# Curated starter cards (shown before search) +STARTER_DATASETS = [ { - "id": "alpaca", - "name": "tatsu-lab/alpaca", + "id": "tatsu-lab/alpaca", "short": "Alpaca", "category": "Instruction", - "rows": "52K rows", - "desc": "Stanford Alpaca instruction-following dataset generated with GPT-3.", + "rows": "52K", + "desc": "Stanford Alpaca instruction-following — the classic starting point.", "license": "CC BY NC 4.0", - "tags": ["instruction", "NLP"], }, { - "id": "dolly", - "name": "databricks/databricks-dolly-15k", + "id": "databricks/databricks-dolly-15k", "short": "Dolly 15K", "category": "Instruction", - "rows": "15K rows", + "rows": "15K", "desc": "High-quality human-generated instruction-response pairs.", "license": "CC BY SA 3.0", - "tags": ["instruction", "NLP"], }, { - "id": "openhermes", - "name": "teknium/OpenHermes-2.5", + "id": "teknium/OpenHermes-2.5", "short": "OpenHermes 2.5", "category": "Chat", - "rows": "1M rows", + "rows": "1M", "desc": "Large synthetic chat dataset for instruction tuning.", "license": "MIT", - "tags": ["chat", "NLP"], }, { - "id": "code_alpaca", - "name": "sahil2801/CodeAlpaca-20k", + "id": "sahil2801/CodeAlpaca-20k", "short": "CodeAlpaca", "category": "Code", - "rows": "20K rows", - "desc": "Code instruction-following dataset generated from GPT-3.", + "rows": "20K", + "desc": "Code generation instructions in Alpaca format.", "license": "Apache 2.0", - "tags": ["code"], }, { - "id": "math_instruct", - "name": "TIGER-Lab/MathInstruct", + "id": "TIGER-Lab/MathInstruct", 
"short": "MathInstruct", "category": "Math", - "rows": "262K rows", - "desc": "Math reasoning dataset with chain-of-thought solutions.", + "rows": "262K", + "desc": "Math reasoning with chain-of-thought solutions.", "license": "MIT", - "tags": ["math"], }, { - "id": "sciq", - "name": "allenai/sciq", + "id": "allenai/sciq", "short": "SciQ", "category": "Science", - "rows": "13.7K rows", + "rows": "13.7K", "desc": "Science exam questions with supporting evidence.", "license": "CC BY NC 3.0", - "tags": ["science", "QA"], }, { - "id": "sharegpt", - "name": "anon8231489123/ShareGPT_Vicuna_unfiltered", - "short": "ShareGPT", - "category": "Chat", - "rows": "90K rows", - "desc": "Multi-turn ChatGPT conversations from ShareGPT.", - "license": "Unknown", - "tags": ["chat", "multi-turn"], - }, - { - "id": "python_code", - "name": "iamtarun/python_code_instructions_18k_alpaca", + "id": "iamtarun/python_code_instructions_18k_alpaca", "short": "Python Code 18K", "category": "Code", - "rows": "18K rows", - "desc": "Python code generation instructions in Alpaca format.", + "rows": "18K", + "desc": "Python code generation instructions.", "license": "Apache 2.0", - "tags": ["code", "python"], }, { - "id": "ultrachat", - "name": "stingning/ultrachat", + "id": "stingning/ultrachat", "short": "UltraChat", "category": "Chat", - "rows": "1.5M rows", - "desc": "Large-scale multi-turn chat dataset for fine-tuning.", + "rows": "1.5M", + "desc": "Large-scale multi-turn chat dataset.", "license": "CC BY NC 4.0", - "tags": ["chat", "NLP"], + }, + { + "id": "WizardLM/WizardLM_evol_instruct_70k", + "short": "WizardLM 70K", + "category": "Instruction", + "rows": "70K", + "desc": "Evolved instruction dataset for complex tasks.", + "license": "Apache 2.0", }, ] +class DatasetState(rx.State): + search_query: str = "" + search_results: list[dict[str, Any]] = [] + is_searching: bool = False + selected_category: str = "All" + + # Preview panel + preview_dataset_id: str = "" + preview_columns: 
list[str] = [] + preview_rows: list[dict[str, Any]] = [] + is_loading_preview: bool = False + preview_error: str = "" + + @rx.event + def set_search_query(self, value: str): + self.search_query = value + + @rx.event + def set_category(self, cat: str): + self.selected_category = cat + + @rx.event(background=True) + async def search_datasets(self): + async with self: + self.is_searching = True + self.search_results = [] + + try: + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.get( + f"{API_BASE}/api/datasets/search", + params={"q": self.search_query}, + ) + if resp.status_code == 200: + async with self: + self.search_results = resp.json().get("results", []) + self.is_searching = False + else: + async with self: + self.is_searching = False + except Exception: + async with self: + self.is_searching = False + + @rx.event(background=True) + async def load_preview(self, dataset_id: str): + async with self: + self.preview_dataset_id = dataset_id + self.is_loading_preview = True + self.preview_error = "" + self.preview_rows = [] + self.preview_columns = [] + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.get(f"{API_BASE}/api/datasets/{dataset_id}/preview") + if resp.status_code == 200: + data = resp.json() + async with self: + self.preview_columns = data.get("columns", []) + self.preview_rows = data.get("rows", []) + self.is_loading_preview = False + else: + async with self: + self.preview_error = "Failed to load preview" + self.is_loading_preview = False + except Exception as exc: + async with self: + self.preview_error = str(exc) + self.is_loading_preview = False + + @rx.event + def use_in_finetune(self, dataset_id: str): + return [ + FinetuneState.set_hub_dataset_id(dataset_id), + rx.redirect("/finetune"), + ] + + +# ── Components ──────────────────────────────────────────────────── +def _card(*children, **props) -> rx.Component: + return rx.box( + *children, + background=c("bg_card"), + border="1px 
solid", + border_color=c("border"), + border_radius="12px", + **props, + ) + + def _dataset_card(ds: dict) -> rx.Component: + is_previewing = DatasetState.preview_dataset_id == ds["id"] return rx.box( rx.vstack( rx.hstack( - rx.badge(ds["category"], variant="soft", color_scheme="blue"), + rx.badge(ds["category"], variant="soft", color_scheme="blue", size="1"), rx.spacer(), - rx.text(ds["rows"], font_size="0.78rem", color=c("text_muted")), + rx.text(ds["rows"] + " rows", font_size="0.74rem", color=c("text_muted")), align="center", width="100%", ), - rx.text( - ds["short"], - font_size="0.95rem", - font_weight="600", - color=c("text_primary"), - ), + rx.text(ds["short"], font_size="0.92rem", font_weight="600", color=c("text_primary")), rx.text( ds["desc"], - font_size="0.83rem", + font_size="0.8rem", color=c("text_secondary"), line_height="1.4", overflow="hidden", - display="-webkit-box", - style={"-webkit-line-clamp": "2", "-webkit-box-orient": "vertical"}, + style={ + "-webkit-line-clamp": "2", + "-webkit-box-orient": "vertical", + "display": "-webkit-box", + }, ), rx.hstack( - rx.text(ds["license"], font_size="0.75rem", color=c("text_muted")), + rx.text(ds["license"], font_size="0.72rem", color=c("text_muted")), + rx.spacer(), + rx.hstack( + rx.button( + "Preview", + size="1", + variant="soft", + color_scheme="gray", + on_click=DatasetState.load_preview(ds["id"]), + ), + rx.button( + "Use in Fine-tune →", + size="1", + variant="solid", + color_scheme="blue", + on_click=DatasetState.use_in_finetune(ds["id"]), + ), + spacing="1", + ), + align="center", + width="100%", + ), + spacing="2", + align_items="flex-start", + width="100%", + ), + padding="16px", + background=rx.cond(is_previewing, c("accent_soft"), c("bg_card")), + border="1px solid", + border_color=rx.cond(is_previewing, c("accent"), c("border")), + border_radius="12px", + _hover={"border_color": c("border_strong")}, + transition="all 0.15s ease", + ) + + +def _search_result_card(ds: dict) -> rx.Component: 
+ return rx.box( + rx.vstack( + rx.hstack( + rx.text(ds["id"], font_size="0.88rem", font_weight="600", color=c("text_primary")), rx.spacer(), rx.button( - "Use", - on_click=AppState.set_hf_model(ds["name"]), + "Use in Fine-tune →", size="1", - variant="outline", - border_radius="999px", - cursor="pointer", - font_size="0.78rem", + variant="solid", + color_scheme="blue", + on_click=DatasetState.use_in_finetune(ds["id"]), ), align="center", width="100%", ), + rx.hstack( + *[ + rx.badge(tag, size="1", variant="soft", color_scheme="gray") + for tag in (ds.get("tags") or [])[:4] + ], + spacing="1", + wrap="wrap", + ), spacing="2", align_items="flex-start", width="100%", ), - padding="16px", + padding="14px", background=c("bg_card"), border="1px solid", border_color=c("border"), - border_radius="12px", - cursor="pointer", - _hover={"border_color": c("border_strong"), "background": c("hover")}, - transition="all 0.15s ease", + border_radius="10px", + ) + + +def _preview_panel() -> rx.Component: + return rx.cond( + DatasetState.preview_dataset_id != "", + rx.box( + rx.vstack( + rx.hstack( + rx.icon("table", size=16, color=c("accent")), + rx.text( + DatasetState.preview_dataset_id, + font_weight="600", + font_size="0.86rem", + color=c("text_primary"), + ), + rx.spacer(), + rx.button( + "Use in Fine-tune →", + size="2", + color_scheme="blue", + on_click=DatasetState.use_in_finetune(DatasetState.preview_dataset_id), + ), + spacing="2", + align="center", + width="100%", + ), + rx.cond( + DatasetState.is_loading_preview, + rx.hstack( + rx.spinner(size="2"), + rx.text("Loading preview...", font_size="0.82rem"), + spacing="2", + ), + rx.fragment(), + ), + rx.cond( + DatasetState.preview_error != "", + rx.callout(DatasetState.preview_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + DatasetState.preview_rows.length() > 0, + rx.vstack( + rx.text( + "Columns: " + DatasetState.preview_columns.join(", "), + font_size="0.74rem", + color=c("text_muted"), + ), + 
rx.table.root( + rx.table.header( + rx.table.row( + rx.foreach( + DatasetState.preview_columns, + lambda col: rx.table.column_header_cell( + rx.text(col, font_size="0.76rem") + ), + ) + ) + ), + rx.table.body( + rx.foreach( + DatasetState.preview_rows, + lambda row: rx.table.row( + rx.foreach( + DatasetState.preview_columns, + lambda col: rx.table.cell( + rx.text( + row[col].to_string(), + font_size="0.76rem", + overflow="hidden", + text_overflow="ellipsis", + white_space="nowrap", + max_width="220px", + ) + ), + ) + ), + ) + ), + variant="surface", + size="1", + width="100%", + ), + spacing="2", + ), + rx.fragment(), + ), + spacing="3", + ), + background=c("bg_card"), + border="1px solid", + border_color=c("border"), + border_radius="12px", + padding="16px", + margin_top="16px", + ), + rx.fragment(), ) def _category_item(label: str) -> rx.Component: + is_active = DatasetState.selected_category == label return rx.text( label, - font_size="0.88rem", - color=c("text_secondary"), + font_size="0.86rem", + color=rx.cond(is_active, c("accent"), c("text_secondary")), + font_weight=rx.cond(is_active, "600", "400"), padding_x="12px", padding_y="7px", border_radius="8px", cursor="pointer", + background=rx.cond(is_active, c("accent_soft"), "transparent"), + on_click=DatasetState.set_category(label), _hover={"background": c("hover"), "color": c("text_primary")}, ) def datasets_page() -> rx.Component: return rx.hstack( + # Category sidebar rx.box( rx.vstack( rx.text( "Categories", - font_size="0.78rem", + font_size="0.74rem", font_weight="500", color=c("text_muted"), padding_x="12px", @@ -195,46 +421,82 @@ def datasets_page() -> rx.Component: height="100%", padding_x="8px", ), + # Main area rx.box( rx.vstack( + # Header rx.hstack( rx.heading( - "Datasets", font_size="1.4rem", font_weight="600", color=c("text_primary") + "Datasets", font_size="1.3rem", font_weight="600", color=c("text_primary") ), rx.spacer(), - rx.button( - rx.icon("plus", size=15), - rx.text("Generate", 
font_size="0.88rem"), - variant="solid", - size="2", - background=rx.color_mode_cond(light="#171717", dark="#ededed"), - color=rx.color_mode_cond(light="#ffffff", dark="#171717"), - border_radius="8px", - cursor="pointer", - ), align="center", width="100%", ), - rx.input( - placeholder="Search datasets...", - size="2", + # Search bar + rx.hstack( + rx.input( + placeholder="Search HF Hub datasets — e.g. medical, code, chat...", + value=DatasetState.search_query, + on_change=DatasetState.set_search_query, + size="2", + flex="1", + ), + rx.button( + rx.cond( + DatasetState.is_searching, + rx.spinner(size="2"), + rx.hstack(rx.icon("search", size=14), rx.text("Search"), spacing="2"), + ), + on_click=DatasetState.search_datasets, + disabled=DatasetState.is_searching, + color_scheme="blue", + size="2", + ), + spacing="2", width="100%", - max_width="420px", - background=c("bg_input"), - border="1px solid", - border_color=c("border"), - border_radius="8px", + max_width="560px", ), - rx.grid( - *[_dataset_card(ds) for ds in SAMPLE_DATASETS], - columns="3", - spacing="4", - width="100%", + # Search results (when search has been triggered) + rx.cond( + DatasetState.search_results.length() > 0, + rx.vstack( + rx.text( + f"Search results ({DatasetState.search_results.length()})", + font_size="0.82rem", + color=c("text_muted"), + ), + rx.vstack( + rx.foreach(DatasetState.search_results, _search_result_card), + spacing="2", + width="100%", + ), + spacing="2", + width="100%", + ), + # Starter curated cards (default view) + rx.vstack( + rx.text( + "Curated datasets for fine-tuning", + font_size="0.82rem", + color=c("text_muted"), + ), + rx.grid( + *[_dataset_card(ds) for ds in STARTER_DATASETS], + columns="3", + spacing="3", + width="100%", + ), + spacing="2", + width="100%", + ), ), + # Preview panel (shown below list when a dataset is selected) + _preview_panel(), spacing="5", align_items="flex-start", width="100%", - padding="32px", + padding="28px", ), flex="1", height="100%", 
diff --git a/app/pages/finetune.py b/app/pages/finetune.py index b08eb98..ee8ebad 100644 --- a/app/pages/finetune.py +++ b/app/pages/finetune.py @@ -1,21 +1,21 @@ -"""TuneOS — Fine-tuning wizard page (/finetune).""" +"""TuneOS — Fine-tuning wizard (/finetune) — 7-step flow.""" from __future__ import annotations import reflex as rx from app.components.loss_chart import loss_chart +from app.state.experiment_state import ExperimentState from app.state.finetune_state import FinetuneState -from app.state.job_state import JobState from app.styles import c -# ── Supported models ───────────────────────────────────────────── +# ── Preset models ───────────────────────────────────────────────── _MODELS = [ { "id": "mistralai/Mistral-7B-v0.1", "name": "Mistral 7B", "size": "7B params", - "notes": "Primary target, well-tested with QLoRA", + "notes": "Well-tested with QLoRA, great all-rounder", "token_required": False, }, { @@ -39,16 +39,46 @@ "notes": "Good for low-VRAM environments", "token_required": False, }, + { + "id": "EleutherAI/pythia-410m", + "name": "Pythia 410M", + "size": "410M params", + "notes": "Tiny model — great for testing pipelines fast", + "token_required": False, + }, + { + "id": "bigcode/starcoder2-3b", + "name": "StarCoder2 3B", + "size": "3B params", + "notes": "Excellent for code generation tasks", + "token_required": False, + }, +] + +_STEP_LABELS = ["Model", "Intent", "Data", "Configure", "Train", "Results", "Deploy"] + +_INTENT_IDEAS = [ + "Health chatbot for diabetes patients", + "Python code review assistant", + "Customer support for SaaS products", + "Legal document summarizer", + "Recipe recommendation assistant", + "Scientific paper Q&A bot", + "SQL query generator", + "Children's education tutor", +] + +_LR_PRESETS = [ + ("1e-4", "Slow & careful"), + ("2e-4", "Balanced (recommended)"), + ("5e-4", "Fast learning"), ] +_GGUF_QUANTS = ["Q4_K_M", "Q5_K_M", "Q8_0", "F16"] + # ── Shared helpers ──────────────────────────────────────────────── -def 
_card( - *children, - padding: str = "20px", - width: str = "100%", - **props, -) -> rx.Component: +def _card(*children, padding: str = "20px", width: str = "100%", **props) -> rx.Component: return rx.box( *children, background=c("bg_card"), @@ -63,28 +93,58 @@ def _card( def _label(text: str) -> rx.Component: return rx.text( - text, - font_size="0.82rem", - font_weight="500", - color=c("text_secondary"), - margin_bottom="6px", + text, font_size="0.8rem", font_weight="500", color=c("text_secondary"), margin_bottom="6px" ) def _section_heading(text: str) -> rx.Component: return rx.text( - text, - font_size="1.05rem", - font_weight="600", - color=c("text_primary"), - margin_bottom="16px", + text, font_size="1.05rem", font_weight="600", color=c("text_primary"), margin_bottom="16px" + ) + + +def _nav_buttons( + back_label: str = "← Back", + next_label: str = "Next →", + next_disabled: bool = False, + next_event=None, + show_back: bool = True, +) -> rx.Component: + return rx.hstack( + rx.button( + back_label, + on_click=FinetuneState.prev_step, + variant="soft", + color_scheme="gray", + size="2", + ) + if show_back + else rx.fragment(), + rx.spacer(), + rx.button( + next_label, + on_click=next_event or FinetuneState.next_step, + disabled=next_disabled, + size="3", + color_scheme="blue", + ), + width="100%", + padding_top="16px", ) -# ── Progress bar ───────────────────────────────────────────────── -_STEP_LABELS = ["Model", "Dataset", "Configure", "Train", "Results"] +def _badge_status(status: str) -> rx.Component: + color = rx.match( + status, + ("running", "blue"), + ("done", "green"), + ("failed", "red"), + "gray", + ) + return rx.badge(status.upper(), color_scheme=color, size="2") +# ── Progress bar ────────────────────────────────────────────────── def _step_dot(index: int) -> rx.Component: step_num = index + 1 is_done = FinetuneState.current_step > step_num @@ -96,30 +156,26 @@ def _step_dot(index: int) -> rx.Component: rx.icon("check", size=12, color="white"), 
rx.text( str(step_num), - font_size="0.75rem", + font_size="0.72rem", font_weight="600", color=rx.cond(is_active, "white", c("text_muted")), ), ), - width="28px", - height="28px", + width="26px", + height="26px", border_radius="50%", background=rx.cond( - is_done, - c("success"), - rx.cond(is_active, c("accent"), c("bg_input")), + is_done, c("success"), rx.cond(is_active, c("accent"), c("bg_input")) ), border="2px solid", - border_color=rx.cond( - is_active | is_done, c("accent"), c("border") - ), + border_color=rx.cond(is_active | is_done, c("accent"), c("border")), display="flex", align_items="center", justify_content="center", ), rx.text( _STEP_LABELS[index], - font_size="0.75rem", + font_size="0.68rem", color=rx.cond(is_active, c("text_primary"), c("text_muted")), font_weight=rx.cond(is_active, "500", "400"), ), @@ -139,7 +195,7 @@ def _progress_bar() -> rx.Component: background=rx.cond( FinetuneState.current_step > i + 1, c("accent"), c("border") ), - min_width="40px", + min_width="20px", ) if i < len(_STEP_LABELS) - 1 else rx.fragment(), @@ -150,25 +206,20 @@ def _progress_bar() -> rx.Component: for i in range(len(_STEP_LABELS)) ], width="100%", - max_width="600px", + max_width="680px", align="center", justify="center", margin_bottom="32px", ) -# ── Step 1: Model + Technique ───────────────────────────────────── +# ── Step 1: Model Source ────────────────────────────────────────── def _model_card(m: dict) -> rx.Component: is_selected = FinetuneState.selected_model_id == m["id"] return rx.box( rx.vstack( rx.hstack( - rx.text( - m["name"], - font_size="0.95rem", - font_weight="600", - color=c("text_primary"), - ), + rx.text(m["name"], font_size="0.92rem", font_weight="600", color=c("text_primary")), rx.cond( m["token_required"], rx.badge("HF Token", color_scheme="orange", size="1"), @@ -177,8 +228,8 @@ def _model_card(m: dict) -> rx.Component: justify="between", width="100%", ), - rx.text(m["size"], font_size="0.8rem", color=c("text_secondary")), - 
rx.text(m["notes"], font_size="0.82rem", color=c("text_muted")), + rx.text(m["size"], font_size="0.78rem", color=c("text_secondary")), + rx.text(m["notes"], font_size="0.78rem", color=c("text_muted")), spacing="1", align_items="flex-start", width="100%", @@ -187,7 +238,7 @@ def _model_card(m: dict) -> rx.Component: border="2px solid", border_color=rx.cond(is_selected, c("accent"), c("border")), border_radius="10px", - padding="16px", + padding="14px", cursor="pointer", width="100%", on_click=FinetuneState.select_model(m["id"], m["name"]), @@ -195,479 +246,901 @@ def _model_card(m: dict) -> rx.Component: ) -def _technique_btn(technique: str, label: str, description: str, coming_soon: bool = False) -> rx.Component: - is_active = FinetuneState.selected_technique == technique - return rx.box( - rx.vstack( - rx.hstack( - rx.text( - label, - font_size="0.88rem", - font_weight="500", - color=rx.cond( - coming_soon, - c("text_muted"), - rx.cond(is_active, c("accent"), c("text_primary")), - ), - ), - rx.cond( - coming_soon, - rx.badge("Soon", color_scheme="gray", size="1"), - rx.cond( - is_active, - rx.icon("check-circle", size=14, color=c("accent")), - rx.fragment(), - ), - ), - spacing="2", - align="center", - ), - rx.text( - description, - font_size="0.78rem", - color=c("text_muted"), - ), - spacing="1", - align_items="flex-start", - ), - background=rx.cond( - is_active & ~coming_soon, c("accent_soft"), c("bg_input") - ), - border="1px solid", - border_color=rx.cond( - is_active & ~coming_soon, c("accent"), c("border") - ), - border_radius="8px", - padding="12px 14px", - cursor=rx.cond(coming_soon, "not-allowed", "pointer"), - opacity=rx.cond(coming_soon, "0.5", "1"), - on_click=rx.cond( - coming_soon, - rx.prevent_default, - FinetuneState.select_technique(technique), - ), - flex="1", - min_width="160px", +def _source_tab(source: str, label: str, icon: str) -> rx.Component: + is_active = FinetuneState.model_source == source + return rx.button( + rx.hstack(rx.icon(icon, 
size=14), rx.text(label), spacing="2", align="center"), + on_click=FinetuneState.set_model_source(source), + variant=rx.cond(is_active, "solid", "soft"), + color_scheme="blue", + size="2", ) def _step1() -> rx.Component: return rx.vstack( - _section_heading("Pick a base model"), + _section_heading("Choose your model"), rx.text( - "Choose the model you want to fine-tune. This is the starting point — your dataset will teach it a new skill.", - font_size="0.88rem", + "Pick from common models, paste any Hugging Face ID, load a local file, " + "or type any model string that Transformers accepts.", + font_size="0.86rem", color=c("text_secondary"), margin_bottom="16px", ), - rx.grid( - *[_model_card(m) for m in _MODELS], - columns="2", - spacing="3", - width="100%", - ), - rx.box(height="24px"), - _section_heading("Choose technique"), - rx.text( - "The technique determines how the model weights are updated during training.", - font_size="0.88rem", - color=c("text_secondary"), - margin_bottom="12px", - ), - rx.flex( - _technique_btn( - "qlora", - "QLoRA", - "Compressed mode. Works on 12 GB+ GPU.", - ), - _technique_btn( - "lora", - "LoRA", - "Float16 mode. Needs ~16 GB GPU.", - ), - _technique_btn( - "full", - "Full Fine-tune", - "All weights updated. 
Needs 80 GB+ GPU.", - coming_soon=True, - ), - _technique_btn( - "dpo", - "DPO", - "Preference tuning for alignment.", - coming_soon=True, - ), - wrap="wrap", - gap="10px", - width="100%", - ), - rx.box(height="24px"), + # Source switcher rx.hstack( - rx.button( - "Next: Upload Dataset →", - on_click=FinetuneState.next_step, - disabled=~FinetuneState.can_go_to_dataset, - size="3", - color_scheme="blue", - ), - justify="end", - width="100%", - ), - spacing="0", - width="100%", - align_items="flex-start", - ) - - -# ── Step 2: Dataset ─────────────────────────────────────────────── -def _preview_row(row: dict) -> rx.Component: - return rx.table.row( - rx.table.cell( - rx.text( - row["instruction"], - font_size="0.8rem", - color=c("text_primary"), - white_space="nowrap", - overflow="hidden", - text_overflow="ellipsis", - max_width="320px", - ) - ), - rx.table.cell( - rx.text( - row["output"], - font_size="0.8rem", - color=c("text_secondary"), - white_space="nowrap", - overflow="hidden", - text_overflow="ellipsis", - max_width="260px", - ) - ), - ) - - -def _step2() -> rx.Component: - return rx.vstack( - _section_heading("Upload your dataset"), - rx.text( - "Your dataset teaches the model the new skill. 
Each row needs an 'instruction' and an 'output'.", - font_size="0.88rem", - color=c("text_secondary"), - margin_bottom="16px", + _source_tab("hub", "HF Hub", "globe"), + _source_tab("custom_string", "Any Model ID", "terminal"), + _source_tab("local", "Local File", "folder-open"), + spacing="2", + margin_bottom="20px", ), - # Upload dropzone - _card( + # Hub tab + rx.cond( + FinetuneState.model_source == "hub", rx.vstack( - _label("Upload new file (.jsonl, .json, .csv)"), - rx.upload( + rx.grid(*[_model_card(m) for m in _MODELS], columns="2", spacing="3", width="100%"), + # HF token field for gated models + rx.box(height="12px"), + _card( rx.vstack( - rx.icon("upload", size=24, color=c("text_muted")), - rx.text( - "Drag & drop or click to select", - font_size="0.88rem", - color=c("text_secondary"), - ), - rx.text( - "Required columns: instruction, output", - font_size="0.78rem", - color=c("text_muted"), + _label("HF Token (required for gated models like Llama)"), + rx.input( + placeholder="hf_xxxxxxxxxxxxx", + type="password", + value=FinetuneState.hf_token, + on_change=FinetuneState.set_hf_token, + width="100%", ), - spacing="2", - align="center", - ), - id="finetune_dataset_upload", - multiple=False, - accept={ - "application/json": [".jsonl", ".json"], - "text/csv": [".csv"], - }, - max_files=1, - padding="24px", - border="2px dashed", - border_color=c("border"), - border_radius="8px", - width="100%", - cursor="pointer", - _hover={"border_color": c("accent")}, - ), - rx.button( - rx.cond( - FinetuneState.is_uploading, - rx.hstack(rx.spinner(size="2"), rx.text("Uploading..."), spacing="2"), - rx.text("Upload"), - ), - on_click=FinetuneState.handle_dataset_upload( - rx.upload_files(upload_id="finetune_dataset_upload") - ), - color_scheme="blue", - variant="soft", - size="2", - disabled=FinetuneState.is_uploading, + spacing="1", + ) ), - spacing="3", - align_items="flex-start", width="100%", - ) + spacing="0", + ), + rx.fragment(), ), - # Existing datasets + # 
Custom string tab rx.cond( - FinetuneState.existing_datasets.length() > 0, + FinetuneState.model_source == "custom_string", _card( rx.vstack( - _label("Or reuse an existing dataset"), - rx.flex( - rx.foreach( - FinetuneState.existing_datasets, - lambda f: rx.box( - rx.text(f, font_size="0.82rem", color=c("text_primary")), - background=rx.cond( - FinetuneState.dataset_filename == f, - c("accent_soft"), - c("bg_input"), - ), - border="1px solid", - border_color=rx.cond( - FinetuneState.dataset_filename == f, - c("accent"), - c("border"), + _label( + "Model ID or path (any string AutoModelForCausalLM.from_pretrained() accepts)" + ), + rx.hstack( + rx.input( + placeholder='e.g. "EleutherAI/gpt-j-6b" or "/local/path/to/model"', + value=FinetuneState.custom_model_str, + on_change=FinetuneState.set_custom_model_str, + flex="1", + ), + rx.button( + rx.cond( + FinetuneState.is_validating_model, + rx.hstack( + rx.spinner(size="1"), rx.text("Checking..."), spacing="2" ), - border_radius="6px", - padding="6px 12px", - cursor="pointer", - on_click=FinetuneState.select_existing_dataset(f), - _hover={"border_color": c("accent")}, + rx.text("Validate"), ), + on_click=FinetuneState.validate_and_select_custom_model, + disabled=FinetuneState.is_validating_model, + color_scheme="blue", + size="2", ), - wrap="wrap", - gap="8px", + spacing="2", + ), + rx.cond( + FinetuneState.model_url_error != "", + rx.callout(FinetuneState.model_url_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.selected_model_id != "", + rx.callout( + rx.hstack( + rx.icon("check-circle", size=14), + rx.text(f"Model ready: {FinetuneState.selected_model_id}"), + spacing="2", + ), + color_scheme="green", + size="1", + ), + rx.fragment(), + ), + _label("HF Token (for gated or private models)"), + rx.input( + placeholder="hf_xxxxxxxxxxxxx", + type="password", + value=FinetuneState.hf_token, + on_change=FinetuneState.set_hf_token, + width="100%", + ), + rx.text( + "Note: If you skip 
validation, any errors will appear when training starts.", + font_size="0.75rem", + color=c("text_muted"), ), spacing="2", - width="100%", - align_items="flex-start", ) ), rx.fragment(), ), - # Validation error - rx.cond( - FinetuneState.dataset_error != "", - rx.callout( - FinetuneState.dataset_error, - icon="triangle-alert", - color_scheme="red", - width="100%", - ), - rx.fragment(), - ), - # Preview table + # Local file tab rx.cond( - FinetuneState.dataset_preview.length() > 0, + FinetuneState.model_source == "local", _card( rx.vstack( - _label("Preview (first 5 rows)"), - rx.table.root( - rx.table.header( - rx.table.row( - rx.table.column_header_cell("instruction"), - rx.table.column_header_cell("output"), - ) + _label( + "Upload your model (.safetensors, .bin, .gguf, or .zip of model directory)" + ), + rx.upload( + rx.vstack( + rx.icon("upload", size=28, color=c("text_muted")), + rx.text("Drag & drop or click to upload", color=c("text_secondary")), + rx.text( + "Supports: .safetensors, .bin, .gguf, .zip", + font_size="0.75rem", + color=c("text_muted"), + ), + spacing="2", + align="center", ), - rx.table.body( - rx.foreach(FinetuneState.dataset_preview, _preview_row) + id="model_upload", + border=f"2px dashed {c('border')}", + border_radius="10px", + padding="32px", + width="100%", + cursor="pointer", + on_drop=FinetuneState.handle_local_model_upload( + rx.upload_files(upload_id="model_upload") ), ), + rx.cond( + FinetuneState.local_model_path != "", + rx.callout( + rx.text(f"Loaded: {FinetuneState.local_model_path}"), + color_scheme="green", + size="1", + ), + rx.fragment(), + ), spacing="2", - width="100%", - overflow_x="auto", ) ), rx.fragment(), ), - # Navigation - rx.box(height="8px"), - rx.hstack( - rx.button( - "← Back", - on_click=FinetuneState.prev_step, - variant="soft", - color_scheme="gray", - size="3", - ), - rx.button( - "Next: Configure →", - on_click=FinetuneState.next_step, - disabled=~FinetuneState.can_go_to_configure, - size="3", - 
color_scheme="blue", - ), - justify="between", - width="100%", - ), - spacing="4", - width="100%", - align_items="flex-start", - ) - - -# ── Step 3: Configure ───────────────────────────────────────────── -def _slider_field(label: str, hint: str, value: rx.Var, min_val: int, max_val: int, on_change) -> rx.Component: - return rx.vstack( - rx.hstack( - _label(label), - rx.text(value, font_size="0.82rem", font_weight="600", color=c("accent")), - justify="between", + # Technique selector (always visible) + rx.box(height="20px"), + _section_heading("Training technique"), + rx.flex( + *[ + rx.box( + rx.vstack( + rx.hstack( + rx.text( + label, + font_size="0.88rem", + font_weight="500", + color=rx.cond( + FinetuneState.selected_technique == tech, + c("accent"), + c("text_primary"), + ), + ), + rx.cond( + FinetuneState.selected_technique == tech, + rx.icon("check-circle", size=14, color=c("accent")), + rx.fragment(), + ), + rx.cond( + coming_soon, + rx.badge("Soon", color_scheme="gray", size="1"), + rx.fragment(), + ), + spacing="2", + align="center", + ), + rx.text(desc, font_size="0.76rem", color=c("text_muted")), + spacing="1", + align_items="flex-start", + ), + background=rx.cond( + FinetuneState.selected_technique == tech, + c("accent_soft"), + c("bg_input"), + ), + border="1px solid", + border_color=rx.cond( + FinetuneState.selected_technique == tech, + c("accent"), + c("border"), + ), + border_radius="8px", + padding="12px 14px", + cursor=rx.cond(coming_soon, "not-allowed", "pointer"), + opacity=rx.cond(coming_soon, "0.5", "1"), + on_click=rx.cond( + coming_soon, rx.prevent_default, FinetuneState.select_technique(tech) + ), + flex="1", + min_width="140px", + ) + for tech, label, desc, coming_soon in [ + ("qlora", "QLoRA", "4-bit compressed. Runs on 12 GB+ GPU. Recommended.", False), + ("lora", "LoRA", "Float16. Needs ~16 GB GPU for 7B models.", False), + ("full", "Full Fine-tune", "All weights updated. 
Needs 80 GB+ GPU.", True), + ("dpo", "DPO", "Preference tuning for alignment.", True), + ] + ], + wrap="wrap", + gap="10px", width="100%", ), - rx.slider( - default_value=value, - min=min_val, - max=max_val, - step=1, - on_change=on_change, - color_scheme="blue", - width="100%", + _nav_buttons( + next_label="Next: Intent →", + next_disabled=~FinetuneState.can_go_to_intent, + show_back=False, ), - rx.text(hint, font_size="0.75rem", color=c("text_muted")), - spacing="1", + spacing="0", width="100%", align_items="flex-start", ) -def _step3() -> rx.Component: +# ── Step 2: Intent ──────────────────────────────────────────────── +def _step2() -> rx.Component: return rx.vstack( - # Config pill - rx.hstack( - rx.badge( - FinetuneState.selected_model_name, - color_scheme="blue", - size="2", - ), - rx.badge( - FinetuneState.technique_label, - color_scheme="green", - size="2", - ), - spacing="2", - margin_bottom="8px", - ), - _section_heading("Configure training"), + _section_heading("What are you building?"), rx.text( - "The defaults work well for most cases. Adjust if you have specific needs.", - font_size="0.88rem", + "Describe your use-case in plain English. TuneOS uses this to generate starter data, " + "guide the training dashboard, and pre-fill the system prompt for testing.", + font_size="0.86rem", color=c("text_secondary"), margin_bottom="16px", ), - # LoRA params _card( rx.vstack( - _label("LoRA Parameters"), - _slider_field( - "Rank (r)", - "Controls adapter size. 16 is a good default.", - FinetuneState.lora_r, - 4, 64, - FinetuneState.set_lora_r, + _label("Your goal (1–3 sentences)"), + rx.text_area( + placeholder="e.g. A health chatbot that answers questions for people with Type 2 diabetes in simple language.", + value=FinetuneState.user_intent, + on_change=FinetuneState.set_user_intent, + rows="4", + width="100%", + resize="vertical", ), - _slider_field( - "Alpha", - "Scaling factor. 
Usually set to 2× rank.", - FinetuneState.lora_alpha, - 8, 128, - FinetuneState.set_lora_alpha, + rx.text("Quick ideas:", font_size="0.76rem", color=c("text_muted")), + rx.flex( + *[ + rx.badge( + idea, + cursor="pointer", + on_click=FinetuneState.set_user_intent(idea), + color_scheme="blue", + variant="soft", + size="1", + ) + for idea in _INTENT_IDEAS + ], + wrap="wrap", + gap="6px", ), - spacing="4", - width="100%", + spacing="3", ) ), - # Training params - _card( - rx.vstack( - _label("Training Parameters"), - rx.grid( - rx.vstack( - _label("Epochs"), - rx.input( - value=FinetuneState.epochs.to_string(), - on_change=FinetuneState.set_epochs, - type="number", - min="1", - max="20", - width="100%", - background=c("bg_input"), - border_color=c("border"), - color=c("text_primary"), - ), - rx.text("How many times to train on your full dataset.", font_size="0.75rem", color=c("text_muted")), - spacing="1", - align_items="flex-start", - ), - rx.vstack( - _label("Learning Rate"), - rx.select( - ["1e-4", "2e-4", "5e-4"], - value=FinetuneState.learning_rate, - on_change=FinetuneState.set_learning_rate, - width="100%", + _nav_buttons( + next_label="Next: Add Data →", + next_disabled=FinetuneState.user_intent == "", + ), + spacing="0", + width="100%", + align_items="flex-start", + ) + + +# ── Step 3: Data ────────────────────────────────────────────────── +def _preview_table(rows: list, label: str = "Preview") -> rx.Component: + return rx.vstack( + rx.text(label, font_size="0.78rem", font_weight="500", color=c("text_muted")), + rx.table.root( + rx.table.header( + rx.table.row( + rx.table.column_header_cell("Instruction"), + rx.table.column_header_cell("Output"), + ) + ), + rx.table.body( + rx.foreach( + rows, + lambda row: rx.table.row( + rx.table.cell( + rx.text( + row["instruction"], + font_size="0.78rem", + overflow="hidden", + text_overflow="ellipsis", + white_space="nowrap", + max_width="300px", + ) + ), + rx.table.cell( + rx.text( + row["output"], + 
font_size="0.78rem", + overflow="hidden", + text_overflow="ellipsis", + white_space="nowrap", + max_width="260px", + ) ), - rx.text("How fast the model adapts. 2e-4 is standard.", font_size="0.75rem", color=c("text_muted")), - spacing="1", - align_items="flex-start", ), + ) + ), + width="100%", + variant="surface", + size="1", + ), + width="100%", + spacing="2", + ) + + +def _data_mode_btn(mode: str, label: str, icon: str) -> rx.Component: + is_active = FinetuneState.data_source == mode + return rx.button( + rx.hstack(rx.icon(icon, size=14), rx.text(label), spacing="2", align="center"), + on_click=FinetuneState.set_data_source(mode), + variant=rx.cond(is_active, "solid", "soft"), + color_scheme="blue", + size="2", + ) + + +def _upload_panel() -> rx.Component: + return _card( + rx.vstack( + _label("Upload CSV, JSONL, or JSON array — any two columns work, you can remap them"), + rx.upload( + rx.vstack( + rx.icon("upload", size=28, color=c("text_muted")), + rx.text("Drag & drop or click to select a file", color=c("text_secondary")), + rx.text(".csv · .jsonl · .json", font_size="0.75rem", color=c("text_muted")), + spacing="2", + align="center", + ), + id="dataset_upload", + border=f"2px dashed {c('border')}", + border_radius="10px", + padding="24px", + width="100%", + cursor="pointer", + on_drop=FinetuneState.handle_dataset_upload( + rx.upload_files(upload_id="dataset_upload") + ), + ), + rx.cond( + FinetuneState.is_uploading, + rx.hstack( + rx.spinner(size="2"), rx.text("Uploading...", font_size="0.84rem"), spacing="2" + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.dataset_error != "", + rx.callout(FinetuneState.dataset_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.dataset_preview.length() > 0, + _preview_table(FinetuneState.dataset_preview, "File preview (first 5 rows)"), + rx.fragment(), + ), + spacing="3", + ) + ) + + +def _hub_dataset_panel() -> rx.Component: + return _card( + rx.vstack( + rx.cond( + 
FinetuneState.hub_dataset_id != "", + rx.vstack( + rx.hstack( + rx.icon("database", size=16, color=c("accent")), + rx.text( + FinetuneState.hub_dataset_id, font_weight="500", color=c("text_primary") + ), + spacing="2", + align="center", + ), + rx.hstack( + rx.vstack( + _label("Instruction column"), + rx.input( + value=FinetuneState.hub_dataset_instruction_col, + on_change=FinetuneState.set_hub_instruction_col, + size="2", + width="180px", + ), + spacing="1", + ), + rx.vstack( + _label("Output column"), + rx.input( + value=FinetuneState.hub_dataset_output_col, + on_change=FinetuneState.set_hub_output_col, + size="2", + width="180px", + ), + spacing="1", + ), + rx.button( + "Load preview", + size="2", + color_scheme="blue", + variant="soft", + on_click=FinetuneState.load_hub_dataset_preview, + align_self="flex-end", + ), + spacing="4", + wrap="wrap", + ), + rx.cond( + FinetuneState.is_loading_hub_preview, + rx.hstack( + rx.spinner(size="2"), + rx.text("Loading...", font_size="0.84rem"), + spacing="2", + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.hub_preview_error != "", + rx.callout(FinetuneState.hub_preview_error, color_scheme="red", size="1"), + rx.fragment(), + ), + rx.cond( + FinetuneState.hub_dataset_preview.length() > 0, + _preview_table(FinetuneState.hub_dataset_preview), + rx.fragment(), + ), + spacing="3", + width="100%", + ), + rx.vstack( + rx.text("No dataset selected yet.", color=c("text_muted"), font_size="0.86rem"), + rx.text( + 'Go to the Datasets tab and click "Use in Fine-tune" on any dataset.', + color=c("text_muted"), + font_size="0.82rem", + ), + rx.button( + "Browse Datasets →", + on_click=rx.redirect("/datasets"), + color_scheme="blue", + variant="soft", + size="2", + ), + spacing="3", + ), + ), + spacing="2", + ) + ) + + +def _generate_panel() -> rx.Component: + return _card( + rx.vstack( + rx.hstack( + rx.icon("sparkles", size=16, color=c("accent")), + rx.text( + "Generate synthetic training data", font_weight="500", 
color=c("text_primary") + ), + spacing="2", + align="center", + ), + rx.text( + "TuneOS will create instruction/output pairs tailored to your stated goal using " + "the Self-Instruct method (the same approach used to create Stanford Alpaca).", + font_size="0.82rem", + color=c("text_secondary"), + ), + rx.cond( + FinetuneState.user_intent != "", + rx.box( + rx.text( + f'Goal: "{FinetuneState.user_intent}"', + font_size="0.82rem", + color=c("text_muted"), + font_style="italic", + ), + background=c("bg_input"), + border_radius="6px", + padding="8px 12px", + ), + rx.fragment(), + ), + rx.hstack( + rx.vstack( + _label("Method"), + rx.select.root( + rx.select.trigger(width="200px"), + rx.select.content( + rx.select.item("Self-Instruct (recommended)", value="self_instruct"), + rx.select.item("Few-Shot Expansion", value="few_shot"), + rx.select.item("Template-Based (offline)", value="template"), + ), + value=FinetuneState.generation_method, + on_change=FinetuneState.set_generation_method, + ), + spacing="1", + ), + rx.vstack( + _label("Number of examples"), + rx.select.root( + rx.select.trigger(width="120px"), + rx.select.content( + rx.select.item("50", value="50"), + rx.select.item("100", value="100"), + rx.select.item("250", value="250"), + rx.select.item("500", value="500"), + ), + value=FinetuneState.generation_n.to_string(), + on_change=FinetuneState.set_generation_n, + ), + spacing="1", + ), + spacing="4", + wrap="wrap", + ), + rx.button( + rx.cond( + FinetuneState.is_generating, + rx.hstack(rx.spinner(size="2"), rx.text("Generating..."), spacing="2"), + rx.hstack( + rx.icon("sparkles", size=14), rx.text("Generate examples"), spacing="2" + ), + ), + on_click=FinetuneState.generate_starter_dataset, + disabled=FinetuneState.is_generating, + color_scheme="blue", + size="3", + ), + rx.cond( + FinetuneState.generation_status != "", + rx.text( + FinetuneState.generation_status, font_size="0.82rem", color=c("text_secondary") + ), + rx.fragment(), + ), + rx.cond( + 
FinetuneState.generated_samples.length() > 0, + _preview_table(FinetuneState.generated_samples, "Generated examples preview"), + rx.fragment(), + ), + spacing="3", + ) + ) + + +def _step3() -> rx.Component: + return rx.vstack( + _section_heading("Add your training data"), + rx.hstack( + _data_mode_btn("upload", "Upload a file", "upload"), + _data_mode_btn("hub_dataset", "HF Hub dataset", "database"), + _data_mode_btn("generate", "Generate with AI", "sparkles"), + spacing="2", + margin_bottom="16px", + ), + rx.match( + FinetuneState.data_source, + ("upload", _upload_panel()), + ("hub_dataset", _hub_dataset_panel()), + ("generate", _generate_panel()), + _upload_panel(), + ), + _nav_buttons( + next_label="Next: Configure →", + next_disabled=~FinetuneState.can_go_to_configure, + ), + spacing="0", + width="100%", + align_items="flex-start", + ) + + +# ── Step 4: Configure ───────────────────────────────────────────── +def _step4() -> rx.Component: + return rx.vstack( + rx.hstack( + _section_heading("Training configuration"), + rx.spacer(), + rx.hstack( + rx.text("Simple", font_size="0.82rem", color=c("text_secondary")), + rx.switch( + checked=FinetuneState.ui_mode == "advanced", + on_change=lambda v: FinetuneState.set_ui_mode(rx.cond(v, "advanced", "simple")), + size="2", + ), + rx.text("Advanced", font_size="0.82rem", color=c("text_secondary")), + spacing="2", + align="center", + ), + ), + # Simple mode + _card( + rx.vstack( + rx.grid( rx.vstack( - _label("Batch Size"), - rx.select( - ["1", "2", "4", "8"], - value=FinetuneState.batch_size.to_string(), - on_change=FinetuneState.set_batch_size, + _label("Epochs"), + rx.input( + value=FinetuneState.epochs.to_string(), + on_change=FinetuneState.set_epochs, + type="number", width="100%", ), - rx.text("Samples processed per step. 
Lower = less VRAM.", font_size="0.75rem", color=c("text_muted")), + rx.text( + "One full pass through your dataset", + font_size="0.72rem", + color=c("text_muted"), + ), spacing="1", - align_items="flex-start", ), rx.vstack( - _label("Max Sequence Length"), - rx.select( - ["256", "512", "1024", "2048"], - value=FinetuneState.max_seq_length.to_string(), - on_change=FinetuneState.set_max_seq_length, - width="100%", + _label("Learning rate"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(f"{lr} — {desc}", value=lr) + for lr, desc in _LR_PRESETS + ], + ), + value=FinetuneState.learning_rate, + on_change=FinetuneState.set_learning_rate, ), - rx.text("Max tokens per sample. 512 fits most use cases.", font_size="0.75rem", color=c("text_muted")), spacing="1", - align_items="flex-start", ), - columns="2", + rx.vstack( + _label("Technique"), + rx.text( + FinetuneState.technique_label, + font_size="0.88rem", + font_weight="500", + color=c("accent"), + ), + rx.text("Change in Step 1", font_size="0.72rem", color=c("text_muted")), + spacing="1", + ), + columns="3", spacing="4", width="100%", ), - spacing="3", - width="100%", + spacing="0", ) ), - # Navigation - rx.hstack( - rx.button("← Back", on_click=FinetuneState.prev_step, variant="soft", color_scheme="gray", size="3"), - rx.button( - rx.cond( - FinetuneState.is_starting, - rx.hstack(rx.spinner(size="2"), rx.text("Starting..."), spacing="2"), - rx.text("Start Training"), + # Advanced mode + rx.cond( + FinetuneState.ui_mode == "advanced", + _card( + rx.vstack( + rx.text( + "Advanced hyperparameters", + font_size="0.88rem", + font_weight="600", + color=c("text_primary"), + margin_bottom="12px", + ), + rx.grid( + rx.vstack( + _label("LoRA rank (r)"), + rx.slider( + min=4, + max=128, + step=4, + default_value=[FinetuneState.lora_r], + on_value_commit=FinetuneState.set_lora_r, + ), + rx.text( + FinetuneState.lora_r, font_size="0.82rem", color=c("text_secondary") + ), + 
spacing="1", + ), + rx.vstack( + _label("LoRA alpha"), + rx.input( + value=FinetuneState.lora_alpha.to_string(), + on_change=FinetuneState.set_lora_alpha, + type="number", + width="100%", + ), + spacing="1", + ), + rx.vstack( + _label("LoRA dropout"), + rx.slider( + min=0.0, + max=0.3, + step=0.01, + default_value=[FinetuneState.lora_dropout], + on_value_commit=FinetuneState.set_lora_dropout, + ), + rx.text( + FinetuneState.lora_dropout, + font_size="0.82rem", + color=c("text_secondary"), + ), + spacing="1", + ), + rx.vstack( + _label("Batch size"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [1, 2, 4, 8, 16] + ], + ), + value=FinetuneState.batch_size.to_string(), + on_change=FinetuneState.set_batch_size, + ), + spacing="1", + ), + rx.vstack( + _label("Max sequence length"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [128, 256, 512, 1024, 2048] + ], + ), + value=FinetuneState.max_seq_length.to_string(), + on_change=FinetuneState.set_max_seq_length, + ), + spacing="1", + ), + rx.vstack( + _label("Gradient accumulation steps"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(str(v), value=str(v)) + for v in [1, 2, 4, 8, 16] + ], + ), + value=FinetuneState.gradient_accumulation_steps.to_string(), + on_change=FinetuneState.set_gradient_accumulation_steps, + ), + spacing="1", + ), + rx.vstack( + _label("LR scheduler"), + rx.select.root( + rx.select.trigger(width="100%"), + rx.select.content( + *[ + rx.select.item(v, value=v) + for v in [ + "cosine", + "linear", + "constant", + "cosine_with_restarts", + ] + ], + ), + value=FinetuneState.lr_scheduler, + on_change=FinetuneState.set_lr_scheduler, + ), + spacing="1", + ), + rx.vstack( + _label("BF16 mode (A100/H100 only)"), + rx.switch( + checked=FinetuneState.bf16, + on_change=FinetuneState.set_bf16, + 
size="2", + ), + rx.text( + "Better precision than FP16 on Ampere+ GPUs", + font_size="0.72rem", + color=c("text_muted"), + ), + spacing="1", + ), + rx.vstack( + _label("Experiment name"), + rx.input( + placeholder="my-run-1", + value=FinetuneState.experiment_name, + on_change=FinetuneState.set_experiment_name, + width="100%", + ), + spacing="1", + ), + columns="3", + spacing="4", + width="100%", + ), + spacing="0", + ) + ), + rx.fragment(), + ), + # Run summary + _card( + rx.vstack( + rx.text( + "Run summary", + font_size="0.82rem", + font_weight="600", + color=c("text_secondary"), + margin_bottom="8px", ), - on_click=FinetuneState.start_training, - disabled=FinetuneState.is_starting | ~FinetuneState.can_start_training, - size="3", - color_scheme="blue", + rx.grid( + rx.vstack( + rx.text("Model", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.effective_model_name, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Dataset", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.dataset_name, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Technique", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.technique_label, + font_size="0.84rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + rx.vstack( + rx.text("Training", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.epochs.to_string() + + " epochs · lr=" + + FinetuneState.learning_rate + + " · batch=" + + FinetuneState.batch_size.to_string(), + font_size="0.82rem", + font_weight="500", + color=c("text_primary"), + ), + spacing="0", + ), + columns="2", + spacing="4", + width="100%", + ), + spacing="0", ), - justify="between", - width="100%", + background=c("bg_input"), ), - rx.cond( - FinetuneState.start_error != "", - rx.callout(FinetuneState.start_error, 
icon="triangle-alert", color_scheme="red"), - rx.fragment(), + _nav_buttons( + next_label="Start Training →", + next_disabled=~FinetuneState.can_start_training, + next_event=FinetuneState.start_training, ), spacing="4", width="100%", @@ -675,211 +1148,507 @@ def _step3() -> rx.Component: ) -# ── Step 4: Training progress ───────────────────────────────────── -def _status_badge() -> rx.Component: - return rx.match( - JobState.status, - ("running", rx.badge("Running", color_scheme="yellow", size="2")), - ("done", rx.badge("Complete", color_scheme="green", size="2")), - ("failed", rx.badge("Failed", color_scheme="red", size="2")), - ("cancelled", rx.badge("Cancelled", color_scheme="gray", size="2")), - rx.badge("Queued", color_scheme="blue", size="2"), +# ── Step 5: Training Dashboard ──────────────────────────────────── +def _metric_tile(label: str, value) -> rx.Component: + return _card( + rx.vstack( + rx.text( + label, + font_size="0.68rem", + font_weight="500", + color=c("text_muted"), + letter_spacing="0.05em", + ), + rx.text(value, font_size="1.5rem", font_weight="700", color=c("text_primary")), + spacing="1", + ), + padding="14px", ) -def _log_entry(entry: dict) -> rx.Component: - return rx.text( - rx.el.span(f"[step {entry['step']}]", color=c("text_muted")), - rx.el.span(f" loss={entry['loss']} epoch={entry['epoch']}"), - font_size="0.78rem", - font_family="monospace", - color=c("text_secondary"), - white_space="nowrap", +def _epoch_log_row(entry: dict) -> rx.Component: + return rx.hstack( + rx.text( + f"Epoch {entry['epoch']}", + font_size="0.78rem", + font_weight="500", + color=c("text_primary"), + width="60px", + ), + rx.text( + f"loss {entry['loss_start']} → {entry['loss_end']}", + font_size="0.78rem", + color=c("text_secondary"), + flex="1", + ), + rx.text( + rx.cond(entry["drop_pct"] > 0, f"↓{entry['drop_pct']}%", f"Δ{entry['drop_pct']}%"), + font_size="0.78rem", + color=rx.cond(entry["drop_pct"] > 10, c("success"), c("warning")), + width="60px", + 
text_align="right", + ), + width="100%", ) -def _step4() -> rx.Component: +def _step5() -> rx.Component: return rx.vstack( + # Header rx.hstack( - _section_heading("Training in progress"), - _status_badge(), - justify="between", - width="100%", + rx.vstack( + rx.text("Training", font_size="1.1rem", font_weight="700", color=c("text_primary")), + rx.text( + FinetuneState.effective_model_name, font_size="0.82rem", color=c("text_muted") + ), + spacing="0", + ), + rx.spacer(), + _badge_status(FinetuneState.training_status), align="center", ), - rx.text( - "Your model is learning. This typically takes 10–60 minutes depending on dataset size and GPU.", - font_size="0.88rem", - color=c("text_secondary"), - margin_bottom="8px", + # Start error + rx.cond( + FinetuneState.start_error != "", + rx.callout(FinetuneState.start_error, color_scheme="red"), + rx.fragment(), + ), + # Metric tiles + rx.grid( + _metric_tile("Epoch", FinetuneState.current_epoch_display), + _metric_tile("Steps", FinetuneState.current_total_steps_display), + _metric_tile("Elapsed", FinetuneState.elapsed_time_display), + _metric_tile("GPU Memory", FinetuneState.gpu_memory_display), + columns="4", + spacing="3", + width="100%", + ), + # Epoch progress bar + rx.vstack( + rx.hstack( + rx.text("Epoch progress", font_size="0.76rem", color=c("text_muted")), + rx.spacer(), + rx.text( + FinetuneState.epoch_progress_pct.to_string() + "%", + font_size="0.76rem", + color=c("text_secondary"), + ), + ), + rx.progress( + value=FinetuneState.epoch_progress_pct, max=100, width="100%", color_scheme="blue" + ), + width="100%", + spacing="1", + ), + # Loss + LR chart + _card(loss_chart()), + # AI Commentary + rx.cond( + FinetuneState.ai_commentary != "", + _card( + rx.hstack( + rx.icon("sparkles", size=16, color=c("accent")), + rx.text( + FinetuneState.ai_commentary, font_size="0.86rem", color=c("text_primary") + ), + spacing="2", + align="flex-start", + ) + ), + rx.fragment(), + ), + # Epoch log + rx.cond( + 
FinetuneState.epoch_log.length() > 0, + _card( + rx.vstack( + rx.text( + "Epoch log", + font_size="0.78rem", + font_weight="600", + color=c("text_secondary"), + margin_bottom="8px", + ), + rx.foreach(FinetuneState.epoch_log, _epoch_log_row), + spacing="2", + ) + ), + rx.fragment(), + ), + # Status message after done + rx.cond( + FinetuneState.training_status == "done", + rx.vstack( + rx.callout( + rx.hstack( + rx.icon("check-circle", size=16), + rx.text("Training complete! Advancing to results..."), + spacing="2", + ), + color_scheme="green", + ), + rx.button( + "View Results →", + on_click=FinetuneState.go_to_step(6), + color_scheme="green", + size="3", + ), + spacing="3", + ), + rx.fragment(), + ), + rx.cond( + FinetuneState.training_status == "failed", + rx.callout( + rx.vstack( + rx.text("Training failed", font_weight="600"), + rx.text(FinetuneState.error_msg, font_size="0.82rem"), + ), + color_scheme="red", + ), + rx.fragment(), ), - # Loss chart + spacing="4", + width="100%", + align_items="flex-start", + ) + + +# ── Step 6: Results ─────────────────────────────────────────────── +def _step6() -> rx.Component: + return rx.vstack( + _section_heading("Results & Evaluation"), + # Eval metrics _card( rx.vstack( - _label("Loss curve"), + rx.hstack( + rx.text( + "Evaluation metrics", + font_size="0.9rem", + font_weight="600", + color=c("text_primary"), + ), + rx.spacer(), + rx.cond( + FinetuneState.eval_status == "idle", + rx.button( + "Run evaluation", + on_click=FinetuneState.run_eval, + size="2", + color_scheme="blue", + variant="soft", + ), + rx.badge(FinetuneState.eval_status, color_scheme="blue", size="1"), + ), + align="center", + ), rx.cond( - JobState.loss_history.length() > 0, - loss_chart(), + FinetuneState.eval_status == "done", + rx.grid( + rx.vstack( + rx.text("Perplexity", font_size="0.72rem", color=c("text_muted")), + rx.text( + FinetuneState.eval_perplexity.to_string(), + font_size="1.8rem", + font_weight="700", + color=c("accent"), + ), + 
rx.text("Lower is better", font_size="0.7rem", color=c("text_muted")), + spacing="0", + ), + rx.vstack( + rx.text("What it means", font_size="0.72rem", color=c("text_muted")), + rx.text( + rx.cond( + FinetuneState.eval_perplexity < 10, + "Excellent — model learned the domain well", + rx.cond( + FinetuneState.eval_perplexity < 30, + "Good — decent task alignment", + "Try more epochs or a larger dataset", + ), + ), + font_size="0.84rem", + color=c("text_secondary"), + ), + spacing="1", + ), + columns="2", + spacing="4", + ), + rx.fragment(), + ), + spacing="3", + ) + ), + # Inference tester + _card( + rx.vstack( + rx.text( + "Test your model", + font_size="0.9rem", + font_weight="600", + color=c("text_primary"), + margin_bottom="8px", + ), + rx.cond( + FinetuneState.user_intent != "", + rx.text( + f"System context: {FinetuneState.user_intent}", + font_size="0.76rem", + color=c("text_muted"), + font_style="italic", + ), + rx.fragment(), + ), + # Chat history + rx.cond( + FinetuneState.test_chat_history.length() > 0, rx.box( - rx.text("Waiting for first step...", font_size="0.85rem", color=c("text_muted")), - height="120px", - display="flex", - align_items="center", - justify_content="center", + rx.foreach( + FinetuneState.test_chat_history, + lambda msg: rx.box( + rx.text( + msg["content"], + font_size="0.84rem", + color=rx.cond( + msg["role"] == "user", + c("text_primary"), + c("text_secondary"), + ), + padding="8px 12px", + background=rx.cond( + msg["role"] == "user", + c("accent_soft"), + c("bg_input"), + ), + border_radius="8px", + align_self=rx.cond( + msg["role"] == "user", "flex-end", "flex-start" + ), + max_width="80%", + ), + display="flex", + flex_direction=rx.cond(msg["role"] == "user", "row-reverse", "row"), + width="100%", + margin_bottom="6px", + ), + ), + width="100%", + max_height="300px", + overflow_y="auto", + padding="8px", + border="1px solid", + border_color=c("border"), + border_radius="8px", ), + rx.fragment(), ), - spacing="2", - width="100%", 
- ) - ), - # Live log stream - _card( - rx.vstack( - _label("Live log"), - rx.box( - rx.cond( - JobState.loss_history.length() > 0, - rx.vstack( - rx.foreach(JobState.loss_history, _log_entry), - spacing="0", - width="100%", - align_items="flex-start", + rx.hstack( + rx.input( + placeholder="Type a test message...", + value=FinetuneState.chat_input, + on_change=FinetuneState.set_chat_input, + on_key_down=lambda e: rx.cond( + e == "Enter", + FinetuneState.send_test_chat, + rx.prevent_default, ), - rx.text("No logs yet...", font_size="0.78rem", color=c("text_muted"), font_family="monospace"), + flex="1", ), - height="160px", - overflow_y="auto", - width="100%", - padding="12px", - background=c("bg_input"), - border_radius="6px", - border="1px solid", - border_color=c("border"), - ), - spacing="2", - width="100%", - ) - ), - # Action row - rx.hstack( - rx.button( - "Stop Training", - on_click=FinetuneState.prev_step, - variant="soft", - color_scheme="red", - size="3", - ), - rx.cond( - JobState.status == "done", - rx.button( - "View Results →", - on_click=[ - FinetuneState.go_to_step(5), - FinetuneState.run_eval, - ], - size="3", - color_scheme="green", + rx.button( + rx.cond( + FinetuneState.chat_loading, + rx.spinner(size="2"), + rx.icon("send", size=16), + ), + on_click=FinetuneState.send_test_chat, + disabled=FinetuneState.chat_loading, + color_scheme="blue", + size="2", + ), + spacing="2", ), rx.cond( - JobState.status == "failed", - rx.callout( - JobState.error_msg, - icon="triangle-alert", - color_scheme="red", - ), + FinetuneState.chat_error != "", + rx.callout(FinetuneState.chat_error, color_scheme="red", size="1"), rx.fragment(), ), + spacing="3", + ) + ), + # Experiment comparison + rx.cond( + ExperimentState.completed_runs.length() > 1, + _card( + rx.vstack( + rx.text( + "Past runs with this setup", + font_size="0.9rem", + font_weight="600", + color=c("text_primary"), + margin_bottom="8px", + ), + rx.table.root( + rx.table.header( + rx.table.row( + 
rx.table.column_header_cell("Name"), + rx.table.column_header_cell("Model"), + rx.table.column_header_cell("Epochs"), + rx.table.column_header_cell("Final Loss"), + rx.table.column_header_cell("Perplexity"), + ) + ), + rx.table.body( + rx.foreach( + ExperimentState.completed_runs, + lambda r: rx.table.row( + rx.table.cell(rx.text(r.name, font_size="0.8rem")), + rx.table.cell( + rx.text(r.model_id.split("/")[-1], font_size="0.8rem") + ), + rx.table.cell( + rx.text(r.epochs.to_string(), font_size="0.8rem") + ), + rx.table.cell( + rx.text(r.final_loss.to_string(), font_size="0.8rem") + ), + rx.table.cell( + rx.text(r.perplexity.to_string(), font_size="0.8rem") + ), + ), + ) + ), + variant="surface", + size="1", + width="100%", + ), + spacing="2", + ) ), - justify="between", - width="100%", - align="center", + rx.fragment(), ), + _nav_buttons(next_label="Next: Deploy →"), spacing="4", width="100%", align_items="flex-start", ) -# ── Step 5: Results ─────────────────────────────────────────────── -def _result_card(title: str, icon: str, *children) -> rx.Component: - return _card( +# ── Step 7: Deploy ──────────────────────────────────────────────── +def _deploy_target_row( + target_key: str, + label: str, + description: str, + icon: str, + is_checked, +) -> rx.Component: + return rx.hstack( + rx.checkbox( + checked=is_checked, + on_change=lambda _: FinetuneState.toggle_deploy_target(target_key), + size="2", + ), rx.vstack( - rx.hstack( - rx.icon(icon, size=18, color=c("accent")), - rx.text(title, font_size="0.95rem", font_weight="600", color=c("text_primary")), - spacing="2", - align="center", - ), - rx.divider(margin_y="10px"), - *children, - spacing="3", - width="100%", - align_items="flex-start", - ) + rx.text(label, font_weight="500", font_size="0.88rem", color=c("text_primary")), + rx.text(description, font_size="0.76rem", color=c("text_muted")), + spacing="0", + ), + spacing="3", + align="flex-start", + padding="10px 0", + border_bottom="1px solid", + 
border_color=c("border"), + width="100%", ) -def _step5() -> rx.Component: +def _step7() -> rx.Component: return rx.vstack( - rx.hstack( - rx.icon("party-popper", size=22, color=c("success")), - _section_heading("Training complete!"), - spacing="2", - align="center", - margin_bottom="4px", - ), + _section_heading("Deploy your model"), rx.text( - "Your fine-tuned adapter is ready. Choose what to do with it below.", - font_size="0.88rem", + "Choose how you want to export or share your fine-tuned adapter.", + font_size="0.86rem", color=c("text_secondary"), margin_bottom="16px", ), - rx.grid( - # Card 1: Download - _result_card( - "Download Adapter", - "download", - rx.text( - "Download the LoRA adapter weights as a zip file (~50 MB). Use them locally with the base model.", - font_size="0.82rem", - color=c("text_secondary"), + # Target selector + _card( + rx.vstack( + _deploy_target_row( + "adapter", + "Download adapter", + "Zip the LoRA adapter files (~100 MB) — works with PEFT/Transformers", + "download", + FinetuneState.deploy_adapter, ), - rx.button( - rx.hstack(rx.icon("download", size=16), rx.text("Download adapter.zip"), spacing="2"), - on_click=FinetuneState.download_adapter, - color_scheme="blue", - variant="soft", - size="2", + _deploy_target_row( + "merged", + "Download merged model", + "Merge adapter into base model → full standalone safetensors (~14 GB for 7B)", + "layers", + FinetuneState.deploy_merged, ), - ), - # Card 2: Push to HF Hub - _result_card( - "Push to Hugging Face Hub", - "upload-cloud", - rx.text( - "Publish your adapter to your HF account as a private repo.", - font_size="0.82rem", - color=c("text_secondary"), + _deploy_target_row( + "hub", + "Push to Hugging Face Hub", + "Upload adapter to a public or private HF repository", + "globe", + FinetuneState.deploy_hub, + ), + _deploy_target_row( + "gguf", + "Export as GGUF", + "Convert to GGUF format for use with Ollama or llama.cpp (CPU inference)", + "cpu", + FinetuneState.deploy_gguf, ), + 
_deploy_target_row( + "github", + "Push to GitHub", + "Commit adapter files to a GitHub repository using Git LFS", + "github", + FinetuneState.deploy_github, + ), + spacing="0", + ) + ), + # HF Hub fields + rx.cond( + FinetuneState.deploy_hub, + _card( rx.vstack( - rx.input( - placeholder="username/my-adapter", - value=FinetuneState.hf_repo_name, - on_change=FinetuneState.set_hf_repo_name, - width="100%", - background=c("bg_input"), - border_color=c("border"), + rx.text( + "Hugging Face Hub", + font_weight="600", + font_size="0.88rem", color=c("text_primary"), ), - rx.input( - placeholder="HF token (hf_...)", - value=FinetuneState.hf_token_input, - on_change=FinetuneState.set_hf_token_input, - type="password", + rx.grid( + rx.vstack( + _label("HF Token"), + rx.input( + type="password", + placeholder="hf_xxxxxxxxxxxxx", + value=FinetuneState.hf_token_input, + on_change=FinetuneState.set_hf_token_input, + width="100%", + ), + spacing="1", + ), + rx.vstack( + _label("Repository name (e.g. myuser/my-chatbot-lora)"), + rx.input( + placeholder="username/repo-name", + value=FinetuneState.hf_repo_name, + on_change=FinetuneState.set_hf_repo_name, + width="100%", + ), + spacing="1", + ), + columns="2", + spacing="3", width="100%", - background=c("bg_input"), - border_color=c("border"), - color=c("text_primary"), ), rx.button( rx.cond( @@ -888,140 +1657,216 @@ def _step5() -> rx.Component: rx.text("Push to Hub"), ), on_click=FinetuneState.push_to_hub, - disabled=(FinetuneState.push_status == "pushing") | (FinetuneState.hf_repo_name == ""), + disabled=FinetuneState.push_status == "pushing", color_scheme="blue", - variant="soft", size="2", ), rx.cond( FinetuneState.push_status == "done", - rx.hstack( - rx.icon("check-circle", size=14, color=c("success")), - rx.text( - FinetuneState.push_repo_url, - font_size="0.78rem", - color=c("success"), + rx.callout( + rx.hstack( + rx.icon("check-circle", size=14), + rx.text(f"Pushed to {FinetuneState.push_repo_url}"), + spacing="2", ), - 
spacing="1", - ), - rx.cond( - FinetuneState.push_error != "", - rx.text(FinetuneState.push_error, font_size="0.78rem", color=c("error")), - rx.fragment(), + color_scheme="green", + size="1", ), + rx.fragment(), ), - spacing="2", - width="100%", - ), + rx.cond( + FinetuneState.push_error != "", + rx.callout(FinetuneState.push_error, color_scheme="red", size="1"), + rx.fragment(), + ), + spacing="3", + ) ), - # Card 3: Evaluation - _result_card( - "Evaluation Metrics", - "chart-bar", - rx.match( - FinetuneState.eval_status, - ("running", rx.hstack(rx.spinner(size="2"), rx.text("Computing perplexity...", font_size="0.82rem", color=c("text_secondary")), spacing="2")), - ("done", + rx.fragment(), + ), + # GGUF fields + rx.cond( + FinetuneState.deploy_gguf, + _card( + rx.vstack( + rx.text( + "GGUF Export", + font_weight="600", + font_size="0.88rem", + color=c("text_primary"), + ), + rx.callout( + "GGUF export requires the model to be merged first. " + "Enable 'Download merged model' above to trigger the merge step.", + color_scheme="amber", + size="1", + ), + rx.hstack( rx.vstack( - rx.hstack( - rx.text("Perplexity", font_size="0.82rem", color=c("text_secondary")), - rx.text( - FinetuneState.eval_perplexity.to_string(), - font_size="1.4rem", - font_weight="700", - color=c("accent"), + _label("Quantization"), + rx.select.root( + rx.select.trigger(width="160px"), + rx.select.content( + *[rx.select.item(q, value=q) for q in _GGUF_QUANTS], ), - justify="between", - width="100%", - ), - rx.text( - "Lower is better. 
Under 10 = good domain fit.", - font_size="0.75rem", - color=c("text_muted"), + value=FinetuneState.gguf_quantization, + on_change=FinetuneState.set_gguf_quantization, ), spacing="1", - width="100%", - ) + ), + rx.button( + rx.cond( + FinetuneState.gguf_status == "exporting", + rx.hstack( + rx.spinner(size="2"), rx.text("Exporting..."), spacing="2" + ), + rx.text("Export GGUF"), + ), + on_click=FinetuneState.start_gguf_export, + disabled=FinetuneState.gguf_status == "exporting", + color_scheme="blue", + size="2", + align_self="flex-end", + ), + spacing="3", + align="flex-end", ), - ("not_ready", rx.text("Metrics not available for this job.", font_size="0.82rem", color=c("text_muted"))), - rx.text("Waiting for eval results...", font_size="0.82rem", color=c("text_muted")), - ), + spacing="3", + ) ), - # Card 4: Test in chat - _result_card( - "Test in Chat", - "message-circle", - rx.text( - "Try a prompt to see how your fine-tuned model responds. First call loads the model (~60s).", - font_size="0.82rem", - color=c("text_secondary"), - ), + rx.fragment(), + ), + # GitHub fields + rx.cond( + FinetuneState.deploy_github, + _card( rx.vstack( - rx.text_area( - placeholder="Enter a prompt to test your model...", - value=FinetuneState.chat_input, - on_change=FinetuneState.set_chat_input, - width="100%", - rows="3", - background=c("bg_input"), - border_color=c("border"), + rx.text( + "GitHub Push", + font_weight="600", + font_size="0.88rem", color=c("text_primary"), - resize="vertical", + ), + rx.grid( + rx.vstack( + _label("Repository URL (HTTPS)"), + rx.input( + placeholder="https://github.com/user/repo", + value=FinetuneState.github_repo_url, + on_change=FinetuneState.set_github_repo_url, + width="100%", + ), + spacing="1", + ), + rx.vstack( + _label("GitHub Token (needs repo scope)"), + rx.input( + type="password", + placeholder="ghp_xxxxxxxxxxxxx", + value=FinetuneState.github_token, + on_change=FinetuneState.set_github_token, + width="100%", + ), + spacing="1", + ), + 
columns="2", + spacing="3", + width="100%", ), rx.button( rx.cond( - FinetuneState.chat_loading, - rx.hstack(rx.spinner(size="2"), rx.text("Generating..."), spacing="2"), - rx.text("Generate"), + FinetuneState.github_push_status == "pushing", + rx.hstack(rx.spinner(size="2"), rx.text("Pushing..."), spacing="2"), + rx.text("Push to GitHub"), ), - on_click=FinetuneState.send_test_chat, - disabled=FinetuneState.chat_loading | (FinetuneState.chat_input == ""), + on_click=FinetuneState.push_to_github, + disabled=FinetuneState.github_push_status == "pushing", color_scheme="blue", - variant="soft", size="2", ), rx.cond( - FinetuneState.chat_response != "", - rx.box( - rx.text( - FinetuneState.chat_response, - font_size="0.85rem", - color=c("text_primary"), - white_space="pre-wrap", - ), - padding="12px", - background=c("bg_input"), - border_radius="6px", - border="1px solid", - border_color=c("border"), - width="100%", - ), - rx.cond( - FinetuneState.chat_error != "", - rx.text(FinetuneState.chat_error, font_size="0.78rem", color=c("error")), - rx.fragment(), + FinetuneState.github_push_status == "done", + rx.callout( + f"Pushed to {FinetuneState.github_repo_url}", + color_scheme="green", + size="1", ), + rx.fragment(), ), - spacing="2", - width="100%", - ), + spacing="3", + ) ), - columns="2", - spacing="4", - width="100%", + rx.fragment(), ), - # Start another - rx.box(height="8px"), + # Quick actions (always visible) rx.hstack( rx.button( - "Fine-tune another model", - on_click=FinetuneState.go_to_step(1), + rx.hstack(rx.icon("download", size=14), rx.text("Download adapter"), spacing="2"), + on_click=FinetuneState.download_adapter, + color_scheme="blue", variant="soft", - color_scheme="gray", - size="3", + size="2", ), - justify="start", - width="100%", + rx.cond( + FinetuneState.deploy_merged, + rx.button( + rx.cond( + FinetuneState.merge_status == "merging", + rx.hstack(rx.spinner(size="2"), rx.text("Merging..."), spacing="2"), + rx.hstack( + rx.icon("layers", 
size=14), rx.text("Merge & download"), spacing="2" + ), + ), + on_click=FinetuneState.start_merge, + disabled=FinetuneState.merge_status == "merging", + color_scheme="blue", + variant="soft", + size="2", + ), + rx.fragment(), + ), + spacing="3", + wrap="wrap", + ), + # Deploy log + rx.cond( + FinetuneState.deploy_log != "", + _card( + rx.vstack( + rx.text( + "Activity log", + font_size="0.78rem", + font_weight="600", + color=c("text_secondary"), + ), + rx.box( + rx.text( + FinetuneState.deploy_log, + font_size="0.76rem", + color=c("text_secondary"), + font_family="monospace", + white_space="pre-wrap", + ), + background=c("bg_input"), + border_radius="8px", + padding="12px", + max_height="200px", + overflow_y="auto", + width="100%", + ), + spacing="2", + ) + ), + rx.fragment(), + ), + # Start over + rx.box(height="8px"), + rx.button( + "Start a new fine-tune →", + on_click=rx.redirect("/finetune"), + color_scheme="gray", + variant="soft", + size="2", ), spacing="4", width="100%", @@ -1029,7 +1874,7 @@ def _step5() -> rx.Component: ) -# ── Page root ──────────────────────────────────────────────────── +# ── Page root ───────────────────────────────────────────────────── def finetune_page() -> rx.Component: return rx.box( rx.vstack( @@ -1039,15 +1884,16 @@ def finetune_page() -> rx.Component: rx.text( "Fine-tune a model", font_size="1.25rem", - font_weight="600", + font_weight="700", color=c("text_primary"), ), - spacing="3", + spacing="2", align="center", - margin_bottom="8px", + margin_bottom="24px", ), + # Progress bar _progress_bar(), - # Step body + # Step content rx.match( FinetuneState.current_step, (1, _step1()), @@ -1055,15 +1901,16 @@ def finetune_page() -> rx.Component: (3, _step3()), (4, _step4()), (5, _step5()), - rx.text("Unknown step", color=c("text_muted")), + (6, _step6()), + (7, _step7()), + rx.text("Invalid step", color=c("text_muted")), ), spacing="0", width="100%", - max_width="900px", + max_width="760px", align_items="flex-start", + padding="32px 
24px", ), - padding="40px", - min_height="100vh", - background=c("bg_primary"), - on_mount=FinetuneState.load_existing_datasets, + width="100%", + on_mount=ExperimentState.load_runs, ) diff --git a/app/pages/results.py b/app/pages/results.py index 2f7e5d2..063e324 100644 --- a/app/pages/results.py +++ b/app/pages/results.py @@ -1,34 +1,5 @@ import reflex as rx -from app.state.job_state import JobState - def results_page() -> rx.Component: - return rx.container( - rx.vstack( - rx.heading("Training Completed", size="8"), - rx.text("Your LoRA adapter is ready."), - rx.card( - rx.vstack( - rx.text("Output Path:"), - rx.code(JobState.output_path), - rx.hstack( - rx.button("Download Adapter", color_scheme="blue"), - rx.button("Merge and Export Full Model", color_scheme="blue"), - spacing="4", - ), - ), - padding="2em", - width="100%", - ), - rx.button( - "Train Another Model", - on_click=rx.redirect("/"), - color_scheme="gray", - margin_top="2em", - ), - spacing="4", - padding="2em", - align_items="center", - ) - ) + return rx.box(on_mount=rx.redirect("/finetune")) diff --git a/app/pages/training.py b/app/pages/training.py index b00a743..2e909a7 100644 --- a/app/pages/training.py +++ b/app/pages/training.py @@ -1,36 +1,5 @@ import reflex as rx -from app.components.loss_chart import loss_chart -from app.state.job_state import JobState - def training_page() -> rx.Component: - return rx.container( - rx.vstack( - rx.heading("Training in progress", size="6"), - rx.badge( - JobState.status, - color_scheme=rx.cond( - JobState.status == "done", - "green", - rx.cond(JobState.status == "failed", "red", "yellow"), - ), - ), - loss_chart(), - rx.cond( - JobState.status == "done", - rx.button( - "View results", - on_click=rx.redirect("/results"), - color_scheme="green", - ), - ), - rx.cond( - JobState.status == "failed", - rx.callout(JobState.error_msg, color_scheme="red"), - ), - spacing="4", - padding="2em", - align_items="center", - ) - ) + return 
rx.box(on_mount=rx.redirect("/finetune")) diff --git a/app/state/experiment_state.py b/app/state/experiment_state.py new file mode 100644 index 0000000..01ea554 --- /dev/null +++ b/app/state/experiment_state.py @@ -0,0 +1,170 @@ +"""Experiment tracking — persists all fine-tuning runs across sessions in SQLite.""" + +from __future__ import annotations + +import json +import os +import sqlite3 +from typing import Any + +import reflex as rx +from pydantic import BaseModel + +_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +DB_PATH = os.getenv("EXPERIMENT_DB", os.path.join(_PROJECT_ROOT, "storage", "experiments.db")) + + +def _get_conn() -> sqlite3.Connection: + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +def _init_db(): + with _get_conn() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + id TEXT PRIMARY KEY, + name TEXT, + model_id TEXT, + model_source TEXT, + technique TEXT, + epochs INTEGER, + learning_rate TEXT, + lora_r INTEGER, + batch_size INTEGER, + dataset_name TEXT, + user_intent TEXT, + final_loss REAL, + perplexity REAL, + started_at TEXT, + finished_at TEXT, + status TEXT, + output_path TEXT, + loss_history TEXT + ) + """) + + +_init_db() + + +class ExperimentRun(BaseModel): + id: str = "" + name: str = "" + model_id: str = "" + model_source: str = "hub" + technique: str = "qlora" + epochs: int = 3 + learning_rate: str = "2e-4" + lora_r: int = 16 + batch_size: int = 4 + dataset_name: str = "" + user_intent: str = "" + final_loss: float = 0.0 + perplexity: float = 0.0 + started_at: str = "" + finished_at: str = "" + status: str = "unknown" + output_path: str = "" + + +class ExperimentState(rx.State): + runs: list[ExperimentRun] = [] + selected_run_ids: list[str] = [] + is_loading: bool = False + + @rx.var + def selected_runs(self) -> list[ExperimentRun]: + ids = set(self.selected_run_ids) + return [r for 
r in self.runs if r.id in ids] + + @rx.var + def completed_runs(self) -> list[ExperimentRun]: + return [r for r in self.runs if r.status == "done"] + + @rx.event + def load_runs(self): + try: + with _get_conn() as conn: + rows = conn.execute("SELECT * FROM runs ORDER BY started_at DESC").fetchall() + self.runs = [ + ExperimentRun( + id=r["id"], + name=r["name"] or "", + model_id=r["model_id"] or "", + model_source=r["model_source"] or "hub", + technique=r["technique"] or "qlora", + epochs=r["epochs"] or 3, + learning_rate=r["learning_rate"] or "2e-4", + lora_r=r["lora_r"] or 16, + batch_size=r["batch_size"] or 4, + dataset_name=r["dataset_name"] or "", + user_intent=r["user_intent"] or "", + final_loss=r["final_loss"] or 0.0, + perplexity=r["perplexity"] or 0.0, + started_at=r["started_at"] or "", + finished_at=r["finished_at"] or "", + status=r["status"] or "unknown", + output_path=r["output_path"] or "", + ) + for r in rows + ] + except Exception: + self.runs = [] + + @rx.event + def toggle_run_selection(self, run_id: str): + if run_id in self.selected_run_ids: + self.selected_run_ids = [i for i in self.selected_run_ids if i != run_id] + else: + self.selected_run_ids = [*self.selected_run_ids, run_id] + + @rx.event + def delete_run(self, run_id: str): + try: + with _get_conn() as conn: + conn.execute("DELETE FROM runs WHERE id = ?", (run_id,)) + self.runs = [r for r in self.runs if r.id != run_id] + self.selected_run_ids = [i for i in self.selected_run_ids if i != run_id] + except Exception: + pass + + +def save_experiment_run(run_data: dict[str, Any]): + """Called from FinetuneState._save_experiment_record() — writes to SQLite.""" + try: + _init_db() + with _get_conn() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO runs + (id, name, model_id, model_source, technique, epochs, learning_rate, + lora_r, batch_size, dataset_name, user_intent, final_loss, perplexity, + started_at, finished_at, status, output_path, loss_history) + VALUES (?, ?, ?, ?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_data.get("id", ""), + run_data.get("name", ""), + run_data.get("model_id", ""), + run_data.get("model_source", "hub"), + run_data.get("technique", "qlora"), + run_data.get("epochs", 3), + run_data.get("learning_rate", "2e-4"), + run_data.get("lora_r", 16), + run_data.get("batch_size", 4), + run_data.get("dataset_name", ""), + run_data.get("user_intent", ""), + run_data.get("final_loss", 0.0), + run_data.get("perplexity", 0.0), + run_data.get("started_at", ""), + run_data.get("finished_at", ""), + run_data.get("status", "unknown"), + run_data.get("output_path", ""), + json.dumps(run_data.get("loss_history", [])), + ), + ) + except Exception: + pass diff --git a/app/state/finetune_state.py b/app/state/finetune_state.py index 0913b31..e673b2d 100644 --- a/app/state/finetune_state.py +++ b/app/state/finetune_state.py @@ -1,106 +1,373 @@ -"""Wizard state for the /finetune dedicated flow.""" +"""Wizard state for the /finetune dedicated flow — single source of truth.""" from __future__ import annotations import json import os +import uuid +from datetime import datetime, timezone from typing import Any import httpx import reflex as rx -from app.state.job_state import JobState +from app.state.experiment_state import ExperimentState, save_experiment_run DATASET_DIR = os.getenv("DATASET_DIR", "./storage/datasets") +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") +API_BASE = os.getenv("API_BASE", "http://localhost:8000") class FinetuneState(rx.State): # ── Step tracking ───────────────────────────────────────────── - current_step: int = 1 # 1–5 + current_step: int = 1 # 1–7 - # ── Step 1: Model + Technique ───────────────────────────────── + # ── Step 1: Model source ────────────────────────────────────── + model_source: str = "hub" # "hub" | "local" | "custom_string" selected_model_id: str = "" selected_model_name: str = "" + custom_model_str: str = "" + local_model_path: str = "" + model_url_error: str = "" + 
is_validating_model: bool = False + hf_token: str = "" # for gated models selected_technique: str = "qlora" # "qlora" | "lora" - # ── Step 2: Dataset ─────────────────────────────────────────── + # ── Step 2: Intent ──────────────────────────────────────────── + user_intent: str = "" + + # ── Step 3: Data ────────────────────────────────────────────── + data_source: str = "upload" # "upload" | "hub_dataset" | "generate" dataset_path: str = "" dataset_filename: str = "" dataset_preview: list[dict[str, Any]] = [] dataset_error: str = "" is_uploading: bool = False existing_datasets: list[str] = [] - - # ── Step 3: Hyperparameters ─────────────────────────────────── + # Hub dataset + hub_dataset_id: str = "" + hub_dataset_split: str = "train" + hub_dataset_instruction_col: str = "instruction" + hub_dataset_output_col: str = "output" + hub_dataset_preview: list[dict[str, Any]] = [] + hub_dataset_columns: list[str] = [] + is_loading_hub_preview: bool = False + hub_preview_error: str = "" + # Synthetic generation + is_generating: bool = False + generation_method: str = "self_instruct" # "self_instruct" | "few_shot" | "template" + generation_n: int = 50 + generation_status: str = "" + generated_samples: list[dict[str, Any]] = [] + generation_diversity_score: float = 0.0 + seed_examples: list[dict[str, Any]] = [] + + # ── Step 4: Configure ───────────────────────────────────────── + ui_mode: str = "simple" # "simple" | "advanced" lora_r: int = 16 lora_alpha: int = 32 + lora_dropout: float = 0.05 epochs: int = 3 learning_rate: str = "2e-4" batch_size: int = 4 max_seq_length: int = 512 + gradient_accumulation_steps: int = 4 + warmup_ratio: float = 0.03 + lr_scheduler: str = "cosine" + bf16: bool = False + experiment_name: str = "" - # ── Step 4: Training ────────────────────────────────────────── + # ── Step 5: Training dashboard ──────────────────────────────── job_id: str = "" is_starting: bool = False start_error: str = "" training_start_time: str = "" - - # ── Step 5: 
Results ─────────────────────────────────────────── - hf_token_input: str = "" - hf_repo_name: str = "" - push_status: str = "idle" # idle | pushing | done | error - push_error: str = "" - push_repo_url: str = "" + training_status: str = "idle" # idle | running | done | failed + current_epoch: float = 0.0 + total_steps: int = 0 + elapsed_seconds: int = 0 + gpu_memory_used_gb: float = 0.0 + ai_commentary: str = "" + output_path: str = "" + error_msg: str = "" + loss_history: list[dict[str, Any]] = [] + epoch_log: list[dict[str, Any]] = [] # one entry per completed epoch + + # Experiment tracking + experiment_id: str = "" + + # ── Step 6: Results ─────────────────────────────────────────── eval_perplexity: float = 0.0 + eval_bleu: float = 0.0 eval_status: str = "idle" # idle | running | done | error | not_ready + test_chat_history: list[dict[str, Any]] = [] chat_input: str = "" - chat_response: str = "" chat_loading: bool = False chat_error: str = "" + # ── Step 7: Deploy ──────────────────────────────────────────── + deploy_adapter: bool = True + deploy_merged: bool = False + deploy_hub: bool = False + deploy_gguf: bool = False + deploy_github: bool = False + hf_token_input: str = "" + hf_repo_name: str = "" + push_status: str = "idle" + push_error: str = "" + push_repo_url: str = "" + gguf_quantization: str = "Q4_K_M" + github_repo_url: str = "" + github_token: str = "" + merge_status: str = "idle" + deploy_log: str = "" + gguf_status: str = "idle" + github_push_status: str = "idle" + # ── Computed vars ───────────────────────────────────────────── @rx.var - def can_go_to_dataset(self) -> bool: - return bool(self.selected_model_id) + def can_go_to_intent(self) -> bool: + return bool(self.effective_model_id) + + @rx.var + def can_go_to_data(self) -> bool: + return bool(self.user_intent) @rx.var def can_go_to_configure(self) -> bool: - return bool(self.dataset_path) and not bool(self.dataset_error) + has_data = ( + ( + self.data_source == "upload" + and 
bool(self.dataset_path) + and not bool(self.dataset_error) + ) + or (self.data_source == "hub_dataset" and bool(self.hub_dataset_id)) + or (self.data_source == "generate" and bool(self.dataset_path)) + ) + return has_data @rx.var def can_start_training(self) -> bool: - return self.can_go_to_configure and bool(self.selected_model_id) + return self.can_go_to_configure and bool(self.effective_model_id) + + @rx.var + def effective_model_id(self) -> str: + if self.model_source == "hub" and self.selected_model_id: + return self.selected_model_id + if self.model_source == "local" and self.local_model_path: + return self.local_model_path + if self.model_source == "custom_string" and self.custom_model_str: + return self.custom_model_str + return "" + + @rx.var + def effective_model_name(self) -> str: + if self.model_source == "hub": + return self.selected_model_name or self.selected_model_id + if self.model_source == "local": + return os.path.basename(self.local_model_path) if self.local_model_path else "" + return self.custom_model_str @rx.var def technique_label(self) -> str: return "QLoRA" if self.selected_technique == "qlora" else "LoRA" @rx.var - def technique_description(self) -> str: - if self.selected_technique == "qlora": - return "Trains a small adapter in compressed mode. Works on 12 GB+ GPU. Recommended." - return "Trains a small adapter in float16. Needs ~16 GB GPU for 7B models." 
+ def elapsed_time_display(self) -> str: + s = self.elapsed_seconds + m, sec = divmod(s, 60) + h, m = divmod(m, 60) + if h: + return f"{h}h {m}m" + if m: + return f"{m}m {sec}s" + return f"{sec}s" + + @rx.var + def current_total_steps_display(self) -> str: + current = len(self.loss_history) + total = self.total_steps + if total: + return f"{current} / {total}" + return str(current) + + @rx.var + def gpu_memory_display(self) -> str: + if self.gpu_memory_used_gb == 0.0: + return "—" + return f"{self.gpu_memory_used_gb:.1f} GB" + + @rx.var + def epoch_progress_pct(self) -> float: + if self.epochs == 0: + return 0.0 + return min(100.0, round((self.current_epoch / self.epochs) * 100, 1)) + + @rx.var + def current_epoch_display(self) -> str: + return f"{min(int(self.current_epoch) + 1, self.epochs)} / {self.epochs}" + + @rx.var + def dataset_name(self) -> str: + if self.data_source == "hub_dataset": + return self.hub_dataset_id + if self.dataset_filename: + return self.dataset_filename + return "Unknown dataset" + + @rx.var + def config_summary_lr(self) -> str: + lr_map = {"1e-4": "Slow & careful", "2e-4": "Balanced", "5e-4": "Fast learning"} + return lr_map.get(self.learning_rate, self.learning_rate) # ── Step 1 events ───────────────────────────────────────────── @rx.event def select_model(self, model_id: str, model_name: str): self.selected_model_id = model_id self.selected_model_name = model_name + self.model_source = "hub" + self.custom_model_str = "" + self.model_url_error = "" @rx.event def select_technique(self, technique: str): self.selected_technique = technique + @rx.event + def set_model_source(self, source: str): + self.model_source = source + self.model_url_error = "" + + @rx.event + def set_custom_model_str(self, value: str): + self.custom_model_str = value + self.model_url_error = "" + + @rx.event + def set_hf_token(self, value: str): + self.hf_token = value + + @rx.event(background=True) + async def validate_and_select_custom_model(self): + model_str = 
self.custom_model_str.strip() + if not model_str: + return + async with self: + self.is_validating_model = True + self.model_url_error = "" + + try: + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{API_BASE}/api/models/validate", + json={"model_id": model_str, "hf_token": self.hf_token}, + ) + data = resp.json() + if data.get("valid"): + async with self: + self.selected_model_id = model_str + self.selected_model_name = data.get("model_type", model_str) + self.is_validating_model = False + else: + async with self: + self.model_url_error = data.get("error", "Model not found or inaccessible.") + self.is_validating_model = False + except Exception as exc: + async with self: + self.model_url_error = f"Validation failed: {exc}" + self.is_validating_model = False + + async def handle_local_model_upload(self, files: list[rx.UploadFile]): + self.is_validating_model = True + self.model_url_error = "" + yield + + if not files: + self.is_validating_model = False + return + + file = files[0] + data = await file.read() + dest_dir = os.path.join("./storage/models", uuid.uuid4().hex) + os.makedirs(dest_dir, exist_ok=True) + safe_name = os.path.basename(file.filename) + dest_path = os.path.join(dest_dir, safe_name) + with open(dest_path, "wb") as f: + f.write(data) + + if safe_name.endswith(".zip"): + import zipfile + + with zipfile.ZipFile(dest_path) as archive: + archive.extractall(dest_dir) + self.local_model_path = dest_dir + else: + self.local_model_path = dest_path + self.model_source = "local" + self.is_validating_model = False + # ── Step 2 events ───────────────────────────────────────────── + @rx.event + def set_user_intent(self, value: str): + self.user_intent = value + + # ── Step 3 events ───────────────────────────────────────────── + @rx.event + def set_data_source(self, source: str): + self.data_source = source + + @rx.event + def set_hub_dataset_id(self, dataset_id: str): + self.hub_dataset_id = dataset_id + self.data_source 
= "hub_dataset" + + @rx.event + def set_hub_instruction_col(self, value: str): + self.hub_dataset_instruction_col = value + + @rx.event + def set_hub_output_col(self, value: str): + self.hub_dataset_output_col = value + + @rx.event(background=True) + async def load_hub_dataset_preview(self): + if not self.hub_dataset_id: + return + async with self: + self.is_loading_hub_preview = True + self.hub_preview_error = "" + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.get(f"{API_BASE}/api/datasets/{self.hub_dataset_id}/preview") + if resp.status_code == 200: + data = resp.json() + async with self: + self.hub_dataset_columns = data.get("columns", []) + self.hub_dataset_preview = data.get("rows", []) + self.is_loading_hub_preview = False + # Auto-detect instruction/output columns + cols = data.get("columns", []) + if "instruction" in cols: + self.hub_dataset_instruction_col = "instruction" + if "output" in cols: + self.hub_dataset_output_col = "output" + else: + async with self: + self.hub_preview_error = resp.json().get("detail", "Failed to load preview") + self.is_loading_hub_preview = False + except Exception as exc: + async with self: + self.hub_preview_error = str(exc) + self.is_loading_hub_preview = False + @rx.event def load_existing_datasets(self): if not os.path.exists(DATASET_DIR): self.existing_datasets = [] return self.existing_datasets = [ - f - for f in os.listdir(DATASET_DIR) - if os.path.isfile(os.path.join(DATASET_DIR, f)) + f for f in os.listdir(DATASET_DIR) if os.path.isfile(os.path.join(DATASET_DIR, f)) ] @rx.event @@ -122,9 +389,8 @@ async def handle_dataset_upload(self, files: list[rx.UploadFile]): file = files[0] data = await file.read() - os.makedirs(DATASET_DIR, exist_ok=True) - out_path = os.path.join(DATASET_DIR, file.filename) + out_path = os.path.join(DATASET_DIR, os.path.basename(file.filename)) with open(out_path, "wb") as f: f.write(data) @@ -132,17 +398,21 @@ async def handle_dataset_upload(self, files: 
list[rx.UploadFile]): self.dataset_filename = file.filename self.is_uploading = False - # Refresh existing list yield FinetuneState.load_existing_datasets() self._validate_dataset_at(out_path) def _validate_dataset_at(self, path: str): - """Read up to 10 rows, check for required columns, populate preview.""" import pandas as pd try: if path.endswith(".csv"): df = pd.read_csv(path, nrows=10) + elif path.endswith(".json") and not path.endswith(".jsonl"): + import json as _json + + with open(path) as fh: + raw = _json.load(fh) + df = pd.DataFrame(raw if isinstance(raw, list) else [raw]) else: rows = [] with open(path) as fh: @@ -155,26 +425,86 @@ def _validate_dataset_at(self, path: str): break df = pd.DataFrame(rows) - required = {"instruction", "output"} - missing = required - set(df.columns) - if missing: - self.dataset_error = ( - f"Missing columns: {', '.join(sorted(missing))}. " - "File must contain 'instruction' and 'output' fields." - ) + # Accept any two columns — user can remap, but prefer instruction/output + if len(df.columns) < 2: + self.dataset_error = "Dataset must have at least 2 columns." 
self.dataset_preview = [] - else: - self.dataset_error = "" - self.dataset_preview = ( - df[["instruction", "output"]] - .head(5) - .fillna("") - .to_dict("records") - ) + return + + # Use instruction/output if present, else first two columns + inst_col = "instruction" if "instruction" in df.columns else df.columns[0] + out_col = "output" if "output" in df.columns else df.columns[1] + + self.dataset_error = "" + self.dataset_preview = ( + df[[inst_col, out_col]] + .head(5) + .fillna("") + .rename(columns={inst_col: "instruction", out_col: "output"}) + .to_dict("records") + ) except Exception as exc: self.dataset_error = f"Could not read file: {exc}" self.dataset_preview = [] + @rx.event + def set_generation_method(self, method: str): + self.generation_method = method + + @rx.event + def set_generation_n(self, value: str): + try: + self.generation_n = int(value) + except ValueError: + pass + + @rx.event(background=True) + async def generate_starter_dataset(self): + async with self: + self.is_generating = True + self.generation_status = "Generating data..." 
+ self.generated_samples = [] + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post( + f"{API_BASE}/api/datasets/generate", + json={ + "user_intent": self.user_intent, + "method": self.generation_method, + "n_samples": self.generation_n, + "seed_examples": self.seed_examples, + "hf_token": self.hf_token, + }, + ) + if resp.status_code == 200: + data = resp.json() + samples = data.get("samples", []) + stats = data.get("stats", {}) + async with self: + self.dataset_path = data.get("dataset_path", "") + self.dataset_filename = os.path.basename(self.dataset_path) + self.generated_samples = samples[:5] # preview + self.dataset_preview = samples[:5] + n = stats.get("final_count", len(samples)) + div = stats.get("diversity_score", 0) + self.generation_diversity_score = div + self.generation_status = f"Generated {n} examples" + ( + f" · diversity {div:.2f}" if div else "" + ) + self.is_generating = False + self.data_source = "generate" + else: + async with self: + self.generation_status = ( + f"Generation failed: {resp.json().get('detail', 'Unknown error')}" + ) + self.is_generating = False + except Exception as exc: + async with self: + self.generation_status = f"Generation failed: {exc}" + self.is_generating = False + # ── Navigation ──────────────────────────────────────────────── @rx.event def go_to_step(self, step: int): @@ -182,14 +512,18 @@ def go_to_step(self, step: int): @rx.event def next_step(self): - self.current_step += 1 + self.current_step = min(7, self.current_step + 1) @rx.event def prev_step(self): if self.current_step > 1: self.current_step -= 1 - # ── Step 3 setters ──────────────────────────────────────────── + # ── Step 4 setters ──────────────────────────────────────────── + @rx.event + def set_ui_mode(self, mode: str): + self.ui_mode = mode + @rx.event def set_lora_r(self, value: list[float]): self.lora_r = int(value[0]) @@ -198,10 +532,14 @@ def set_lora_r(self, value: list[float]): def set_lora_alpha(self, 
value: list[float]): self.lora_alpha = int(value[0]) + @rx.event + def set_lora_dropout(self, value: list[float]): + self.lora_dropout = round(value[0], 2) + @rx.event def set_epochs(self, value: str): try: - self.epochs = max(1, min(20, int(value))) + self.epochs = max(1, min(50, int(value))) except ValueError: pass @@ -211,162 +549,311 @@ def set_learning_rate(self, value: str): @rx.event def set_batch_size(self, value: str): - self.batch_size = int(value) + try: + self.batch_size = int(value) + except ValueError: + pass @rx.event def set_max_seq_length(self, value: str): - self.max_seq_length = int(value) + try: + self.max_seq_length = int(value) + except ValueError: + pass - # ── Step 4: Start training ──────────────────────────────────── @rx.event - def start_training(self): + def set_gradient_accumulation_steps(self, value: str): + try: + self.gradient_accumulation_steps = int(value) + except ValueError: + pass + + @rx.event + def set_warmup_ratio(self, value: list[float]): + self.warmup_ratio = round(value[0], 2) + + @rx.event + def set_lr_scheduler(self, value: str): + self.lr_scheduler = value + + @rx.event + def set_bf16(self, value: bool): + self.bf16 = value + + @rx.event + def set_experiment_name(self, value: str): + self.experiment_name = value + + # ── Step 5: Start training ──────────────────────────────────── + @rx.event(background=True) + async def start_training(self): if not self.can_start_training: return - import uuid - from datetime import datetime + exp_id = str(uuid.uuid4()) + exp_name = ( + self.experiment_name + or f"{self.effective_model_name}-{datetime.now().strftime('%m%d-%H%M')}" + ) - job_id = str(uuid.uuid4()) - self.job_id = job_id - self.is_starting = True - self.start_error = "" - self.training_start_time = datetime.utcnow().isoformat() + async with self: + self.is_starting = True + self.start_error = "" + self.experiment_id = exp_id + self.experiment_name = exp_name + self.training_start_time = 
datetime.now(timezone.utc).isoformat() + self.loss_history = [] + self.epoch_log = [] + self.ai_commentary = "" + self.training_status = "idle" use_4bit = self.selected_technique == "qlora" - model_cfg = { - "model_name": self.selected_model_id, + payload = { + "model_id": self.effective_model_id, + "model_source": self.model_source, + "local_model_path": self.local_model_path, + "hf_token": self.hf_token, + "dataset_path": self.dataset_path if self.data_source != "hub_dataset" else "", + "hub_dataset_id": self.hub_dataset_id if self.data_source == "hub_dataset" else "", + "hub_dataset_split": self.hub_dataset_split, + "instruction_col": self.hub_dataset_instruction_col, + "output_col": self.hub_dataset_output_col, + "technique": self.selected_technique, "use_4bit": use_4bit, - "use_8bit": False, - "trust_remote_code": False, - "max_seq_length": self.max_seq_length, - } - lora_cfg = { - "r": self.lora_r, + "lora_rank": self.lora_r, "lora_alpha": self.lora_alpha, - "lora_dropout": 0.05, - "bias": "none", - "task_type": "CAUSAL_LM", - "target_modules": ["q_proj", "v_proj"], - } - train_cfg = { - "output_dir": os.getenv("OUTPUT_DIR", "./outputs"), - "num_train_epochs": self.epochs, - "per_device_train_batch_size": self.batch_size, - "gradient_accumulation_steps": 4, + "lora_dropout": self.lora_dropout, "learning_rate": float(self.learning_rate), - "fp16": True, - "bf16": False, - "logging_steps": 1, - "save_steps": 100, - "warmup_ratio": 0.03, - "lr_scheduler_type": "cosine", - "optim": "paged_adamw_32bit", - "max_grad_norm": 0.3, + "epochs": self.epochs, + "batch_size": self.batch_size, + "max_seq_length": self.max_seq_length, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "warmup_ratio": self.warmup_ratio, + "lr_scheduler_type": self.lr_scheduler, + "bf16": self.bf16, + "user_intent": self.user_intent, + "experiment_name": exp_name, + "experiment_id": exp_id, } try: - from workers.train_task import run_finetune - - run_finetune.delay( - 
job_id=job_id, - model_cfg=model_cfg, - lora_cfg=lora_cfg, - train_cfg=train_cfg, - dataset_path=self.dataset_path, - ) + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post(f"{API_BASE}/api/jobs", json=payload) + if resp.status_code in (200, 201): + job_id = resp.json()["job_id"] + async with self: + self.job_id = job_id + self.is_starting = False + self.current_step = 5 + self.training_status = "running" + await self._poll_job_loop(job_id) + else: + async with self: + self.start_error = resp.json().get("detail", "Failed to start training job") + self.is_starting = False except Exception as exc: - self.start_error = str(exc) - self.is_starting = False - return + async with self: + self.start_error = str(exc) + self.is_starting = False - self.is_starting = False - self.current_step = 4 - return JobState.poll_job(job_id) + async def _poll_job_loop(self, job_id: str): + import redis.asyncio as aioredis - # ── Step 5: Post-training actions ───────────────────────────── - @rx.event - def download_adapter(self): - return rx.redirect(f"/api/jobs/{self.job_id}/download") + r = aioredis.from_url(REDIS_URL) + pubsub = r.pubsub() + await pubsub.subscribe(f"job:{job_id}:progress") - @rx.event(background=True) - async def push_to_hub(self): - async with self: - self.push_status = "pushing" - self.push_error = "" + prev_epoch = 0.0 + epoch_start_loss: float | None = None + + async for message in pubsub.listen(): + if message["type"] != "message": + continue + data = json.loads(message["data"]) + + current_loss = data.get("loss", 0) + current_epoch = data.get("epoch", 0) + + async with self: + self.loss_history.append( + { + "step": data.get("step", 0), + "loss": current_loss, + "epoch": current_epoch, + "learning_rate": data.get("learning_rate", 0), + "eval_loss": data.get("eval_loss"), + } + ) + self.current_epoch = current_epoch + self.total_steps = data.get("total_steps", 0) + self.elapsed_seconds = data.get("elapsed_seconds", 0) + 
self.gpu_memory_used_gb = data.get("gpu_memory_used_gb", 0.0) + + # Detect epoch boundary and log a summary + if int(current_epoch) > int(prev_epoch): + if epoch_start_loss is not None and self.loss_history: + drop_pct = round( + (epoch_start_loss - current_loss) / max(epoch_start_loss, 1e-9) * 100, 1 + ) + async with self: + self.epoch_log.append( + { + "epoch": int(prev_epoch) + 1, + "loss_start": round(epoch_start_loss, 4), + "loss_end": round(current_loss, 4), + "drop_pct": drop_pct, + "elapsed_seconds": self.elapsed_seconds, + } + ) + # Refresh AI commentary + await self._refresh_commentary(current_loss, drop_pct, int(current_epoch)) + epoch_start_loss = current_loss + elif epoch_start_loss is None: + epoch_start_loss = current_loss + + prev_epoch = current_epoch + + if data.get("status") in ("done", "failed"): + async with self: + self.training_status = data.get("status", "done") + break + # After loop: fetch final status + output_path from the REST API try: - async with httpx.AsyncClient(timeout=120.0) as client: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{API_BASE}/api/jobs/{job_id}") + if resp.status_code == 200: + sdata = resp.json() + async with self: + self.output_path = sdata.get("output_path", "") + self.error_msg = sdata.get("error", "") + self.training_status = sdata.get("status", self.training_status) + except Exception: + pass + + await pubsub.unsubscribe() + await r.aclose() + + # Persist experiment record + await self._save_experiment_record() + + # Auto-advance to Results step if training succeeded + if self.training_status == "done": + async with self: + self.current_step = 6 + # Trigger eval + await self._auto_eval() + + async def _refresh_commentary(self, current_loss: float, drop_pct: float, epoch: int): + try: + async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.post( - f"http://localhost:8000/api/jobs/{self.job_id}/push_hub", + f"{API_BASE}/api/jobs/{self.job_id}/commentary", 
json={ - "repo_name": self.hf_repo_name, - "hf_token": self.hf_token_input, + "epoch": epoch, + "total_epochs": self.epochs, + "loss_drop_pct": drop_pct, + "current_loss": current_loss, + "intent": self.user_intent, }, ) if resp.status_code == 200: - data = resp.json() async with self: - self.push_status = "done" - self.push_repo_url = data.get("repo_url", "") - else: - async with self: - self.push_status = "error" - self.push_error = resp.json().get("detail", "Push failed") - except Exception as exc: - async with self: - self.push_status = "error" - self.push_error = str(exc) - - @rx.event(background=True) - async def run_eval(self): - async with self: - self.eval_status = "running" + self.ai_commentary = resp.json().get("commentary", "") + except Exception: + pass - import asyncio + async def _auto_eval(self): + for _ in range(30): + import asyncio - for _ in range(60): # Poll for up to 60 seconds try: async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.get( - f"http://localhost:8000/api/jobs/{self.job_id}/eval" - ) + resp = await client.get(f"{API_BASE}/api/jobs/{self.job_id}/eval") data = resp.json() if data.get("status") == "done": - ppl = data.get("perplexity") async with self: self.eval_status = "done" + ppl = data.get("perplexity") self.eval_perplexity = float(ppl) if ppl is not None else 0.0 return - elif data.get("status") == "not_ready": - await asyncio.sleep(2) - else: - break except Exception: - await asyncio.sleep(2) + pass + await asyncio.sleep(3) async with self: self.eval_status = "not_ready" + async def _save_experiment_record(self): + final_loss = self.loss_history[-1]["loss"] if self.loss_history else 0.0 + save_experiment_run( + { + "id": self.experiment_id, + "name": self.experiment_name, + "model_id": self.effective_model_id, + "model_source": self.model_source, + "technique": self.selected_technique, + "epochs": self.epochs, + "learning_rate": self.learning_rate, + "lora_r": self.lora_r, + "batch_size": 
self.batch_size, + "dataset_name": self.dataset_name, + "user_intent": self.user_intent, + "final_loss": final_loss, + "perplexity": self.eval_perplexity, + "started_at": self.training_start_time, + "finished_at": datetime.now(timezone.utc).isoformat(), + "status": self.training_status, + "output_path": self.output_path, + "loss_history": self.loss_history, + } + ) + async with self: + pass + return ExperimentState.load_runs() + + # ── Step 6: Results ─────────────────────────────────────────── + @rx.event(background=True) + async def run_eval(self): + async with self: + self.eval_status = "running" + await self._auto_eval() + + @rx.event + def set_chat_input(self, value: str): + self.chat_input = value + @rx.event(background=True) async def send_test_chat(self): - prompt = self.chat_input - if not prompt.strip(): + prompt = self.chat_input.strip() + if not prompt: return + system = self.user_intent or "" + full_prompt = f"[System: {system}]\n\n{prompt}" if system else prompt + async with self: self.chat_loading = True - self.chat_response = "" self.chat_error = "" + self.test_chat_history = [ + *self.test_chat_history, + {"role": "user", "content": prompt}, + ] try: async with httpx.AsyncClient(timeout=180.0) as client: resp = await client.post( - f"http://localhost:8000/api/jobs/{self.job_id}/infer", - json={"prompt": prompt, "max_new_tokens": 200, "temperature": 0.7}, + f"{API_BASE}/api/jobs/{self.job_id}/infer", + json={"prompt": full_prompt, "max_new_tokens": 300, "temperature": 0.7}, ) if resp.status_code == 200: + response = resp.json().get("response", "") async with self: - self.chat_response = resp.json().get("response", "") + self.test_chat_history = [ + *self.test_chat_history, + {"role": "assistant", "content": response}, + ] + self.chat_input = "" self.chat_loading = False else: async with self: @@ -377,9 +864,19 @@ async def send_test_chat(self): self.chat_error = str(exc) self.chat_loading = False + # ── Step 7: Deploy 
──────────────────────────────────────────── @rx.event - def set_chat_input(self, value: str): - self.chat_input = value + def toggle_deploy_target(self, target: str): + targets = { + "adapter": "deploy_adapter", + "merged": "deploy_merged", + "hub": "deploy_hub", + "gguf": "deploy_gguf", + "github": "deploy_github", + } + if target in targets: + attr = targets[target] + setattr(self, attr, not getattr(self, attr)) @rx.event def set_hf_repo_name(self, value: str): @@ -388,3 +885,117 @@ def set_hf_repo_name(self, value: str): @rx.event def set_hf_token_input(self, value: str): self.hf_token_input = value + + @rx.event + def set_gguf_quantization(self, value: str): + self.gguf_quantization = value + + @rx.event + def set_github_repo_url(self, value: str): + self.github_repo_url = value + + @rx.event + def set_github_token(self, value: str): + self.github_token = value + + @rx.event + def download_adapter(self): + return rx.redirect(f"{API_BASE}/api/jobs/{self.job_id}/download") + + @rx.event(background=True) + async def push_to_hub(self): + async with self: + self.push_status = "pushing" + self.push_error = "" + + try: + async with httpx.AsyncClient(timeout=180.0) as client: + resp = await client.post( + f"{API_BASE}/api/jobs/{self.job_id}/push_hub", + json={"repo_name": self.hf_repo_name, "hf_token": self.hf_token_input}, + ) + if resp.status_code == 200: + async with self: + self.push_status = "done" + self.push_repo_url = resp.json().get("repo_url", "") + else: + async with self: + self.push_status = "error" + self.push_error = resp.json().get("detail", "Push failed") + except Exception as exc: + async with self: + self.push_status = "error" + self.push_error = str(exc) + + @rx.event(background=True) + async def start_merge(self): + async with self: + self.merge_status = "merging" + self.deploy_log = "Starting model merge..." 
+ + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{API_BASE}/api/jobs/{self.job_id}/merge", + json={"hf_token": self.hf_token_input}, + ) + if resp.status_code in (200, 202): + async with self: + self.deploy_log += "\nMerge job submitted. This may take 5–15 minutes." + else: + async with self: + self.merge_status = "error" + self.deploy_log += f"\nMerge failed: {resp.json().get('detail', 'Unknown')}" + except Exception as exc: + async with self: + self.merge_status = "error" + self.deploy_log += f"\nMerge error: {exc}" + + @rx.event(background=True) + async def start_gguf_export(self): + async with self: + self.gguf_status = "exporting" + self.deploy_log += "\nStarting GGUF export..." + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{API_BASE}/api/jobs/{self.job_id}/export-gguf", + json={"quant_type": self.gguf_quantization}, + ) + if resp.status_code in (200, 202): + async with self: + self.deploy_log += "\nGGUF export job submitted." + else: + async with self: + self.gguf_status = "error" + self.deploy_log += f"\nGGUF export failed: {resp.json().get('detail', '')}" + except Exception as exc: + async with self: + self.gguf_status = "error" + self.deploy_log += f"\nGGUF export error: {exc}" + + @rx.event(background=True) + async def push_to_github(self): + async with self: + self.github_push_status = "pushing" + self.deploy_log += "\nPushing adapter to GitHub..." 
+ + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + f"{API_BASE}/api/jobs/{self.job_id}/push-github", + json={"repo_url": self.github_repo_url, "github_token": self.github_token}, + ) + if resp.status_code == 200: + async with self: + self.github_push_status = "done" + self.deploy_log += f"\nPushed to {self.github_repo_url}" + else: + async with self: + self.github_push_status = "error" + self.deploy_log += f"\nGitHub push failed: {resp.json().get('detail', '')}" + except Exception as exc: + async with self: + self.github_push_status = "error" + self.deploy_log += f"\nGitHub push error: {exc}" diff --git a/app/state/job_state.py b/app/state/job_state.py index ce40c15..072912a 100644 --- a/app/state/job_state.py +++ b/app/state/job_state.py @@ -1,58 +1,14 @@ -import asyncio -import json -import os -from typing import Any +""" +Deprecated — all job state now lives in FinetuneState. +This stub exists only to prevent import errors in legacy pages. +""" -import redis import reflex as rx -REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") - class JobState(rx.State): job_id: str = "" - status: str = "idle" # idle | running | done | failed - loss_history: list[dict[str, Any]] = [] + status: str = "idle" + loss_history: list = [] output_path: str = "" error_msg: str = "" - - @rx.event(background=True) - async def poll_job(self, job_id: str): - """ - Subscribe to Redis pub/sub channel for live loss updates. - Updates state after every logged step. 
- """ - async with self: - self.job_id = job_id - self.status = "running" - self.loss_history = [] - - r = redis.from_url(REDIS_URL) - pubsub = r.pubsub() - pubsub.subscribe(f"job:{job_id}:loss") - status_key = f"job:{job_id}:status" - - while True: - msg = pubsub.get_message(ignore_subscribe_messages=True) - if msg: - data = json.loads(msg["data"]) - async with self: - self.loss_history.append( - { - "step": data["step"], - "loss": data["loss"], - "epoch": data["epoch"], - } - ) - # Check if job finished - status_raw = r.get(status_key) - if status_raw: - status_data = json.loads(status_raw) - if status_data["status"] in ("done", "failed"): - async with self: - self.status = status_data["status"] - self.output_path = status_data.get("output_path", "") - self.error_msg = status_data.get("error", "") - break - - await asyncio.sleep(1) diff --git a/app/state/model_state.py b/app/state/model_state.py index 5dd4bde..b42ee7c 100644 --- a/app/state/model_state.py +++ b/app/state/model_state.py @@ -1,26 +1,10 @@ -import os -import uuid - import reflex as rx -def _get_run_finetune(): - from workers.train_task import run_finetune - - return run_finetune - - -from app.state.job_state import JobState # noqa: E402 - - class ModelState(rx.State): model_name: str = "mistralai/Mistral-7B-v0.1" - - # LoRA parameters lora_r: list[int] = [16] lora_alpha: list[int] = [32] - - # Training parameters epochs: int = 3 learning_rate: str = "2e-4" dataset_path: str = "" @@ -48,56 +32,3 @@ def set_epochs(self, value: str): @rx.event def set_learning_rate(self, value: str): self.learning_rate = value - - @rx.event - def start_training(self): - if not self.dataset_path: - # Need to handle no dataset selected - return rx.window_alert("Please upload and select a dataset first.") - - job_id = str(uuid.uuid4()) - - # Prepare configs to match our backend structure - model_cfg = { - "model_name": self.model_name, - "use_4bit": True, - "use_8bit": False, - "trust_remote_code": False, - 
"max_seq_length": 512, - } - - lora_cfg = { - "r": self.lora_r[0], - "lora_alpha": self.lora_alpha[0], - "lora_dropout": 0.05, - "bias": "none", - "task_type": "CAUSAL_LM", - "target_modules": ["q_proj", "v_proj"], - } - - train_cfg = { - "output_dir": os.getenv("OUTPUT_DIR", "./outputs"), - "num_train_epochs": self.epochs, - "per_device_train_batch_size": 4, - "gradient_accumulation_steps": 4, - "learning_rate": float(self.learning_rate), - "fp16": True, - "bf16": False, - "logging_steps": 1, # Set lower for demonstration of UI streaming - "save_steps": 100, - "warmup_ratio": 0.03, - "lr_scheduler_type": "cosine", - "optim": "paged_adamw_32bit", - "max_grad_norm": 0.3, - } - - # Kick off Celery task - _get_run_finetune().delay( - job_id=job_id, - model_cfg=model_cfg, - lora_cfg=lora_cfg, - train_cfg=train_cfg, - dataset_path=self.dataset_path, - ) - - return [JobState.poll_job(job_id), rx.redirect("/training")] diff --git a/hf_spaces/entrypoint.sh b/hf_spaces/entrypoint.sh new file mode 100644 index 0000000..0ed8d03 --- /dev/null +++ b/hf_spaces/entrypoint.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# TuneOS HF Spaces entrypoint — starts Redis, Celery worker, FastAPI, Reflex UI, and nginx. +set -e + +echo "[tuneos] Starting Redis..." +redis-server --daemonize yes --loglevel notice + +until redis-cli ping | grep -q PONG; do + echo "[tuneos] Waiting for Redis..." + sleep 1 +done +echo "[tuneos] Redis is up." + +echo "[tuneos] Starting Celery worker..." +celery -A workers.celery_app worker --loglevel=info --concurrency=1 -Q celery & + +echo "[tuneos] Starting Reflex app (UI + API)..." +reflex run --env prod --backend-port 8000 & + +echo "[tuneos] Waiting for Reflex to be ready..." +until curl -sf http://127.0.0.1:8000/api/health > /dev/null 2>&1; do + sleep 2 +done +echo "[tuneos] Reflex is up." + +echo "[tuneos] Starting nginx on port 7860..." 
+nginx -g "daemon off;" diff --git a/hf_spaces/nginx.conf b/hf_spaces/nginx.conf new file mode 100644 index 0000000..cc1799c --- /dev/null +++ b/hf_spaces/nginx.conf @@ -0,0 +1,22 @@ +server { + listen 7860; + + location / { + proxy_pass http://127.0.0.1:3000; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_read_timeout 86400; + } + + location /api/ { + proxy_pass http://127.0.0.1:8000; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_read_timeout 120; + } +} diff --git a/poetry.lock b/poetry.lock index 9a0acec..61885e7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -635,12 +635,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main", "desktop"] -markers = "platform_system == \"Windows\"" +groups = ["main", "desktop", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", desktop = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "cuda-bindings" @@ -843,7 +843,7 @@ version = "1.3.1" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["main", "desktop"] +groups = ["main", "desktop", "dev"] markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"}, @@ -1326,6 +1326,18 @@ files = [ [package.extras] 
all = ["mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -2224,7 +2236,7 @@ version = "26.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" -groups = ["main", "desktop"] +groups = ["main", "desktop", "dev"] files = [ {file = "packaging-26.2-py3-none-any.whl", hash = "sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e"}, {file = "packaging-26.2.tar.gz", hash = "sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661"}, @@ -2475,6 +2487,22 @@ files = [ {file = "platformdirs-4.9.6.tar.gz", hash = "sha256:3bfa75b0ad0db84096ae777218481852c0ebc6c727b3168c1b9e0118e458cf0a"}, ] +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -2877,7 +2905,7 @@ version = "2.20.0" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176"}, {file = "pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f"}, @@ -3045,6 +3073,30 @@ files = [ {file = "pyqt6_webengine_qt6-6.11.1-py3-none-win_arm64.whl", hash = "sha256:21aaa6c7c9a91076936baa4a1c02dd5a0cbd4c75238d5fd6a216736f654cf89a"}, ] +[[package]] +name = "pytest" +version = "9.0.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9"}, + {file = "pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1", markers = "python_version < \"3.11\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3702,6 +3754,34 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "ruff" +version = "0.15.15" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.15.15-py3-none-linux_armv6l.whl", hash = "sha256:cf93e5388f412e1b108b1f8b34a6e036b70fe8aff89393befad96fe48670311b"}, + {file = "ruff-0.15.15-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ac5a646d1f6a7dadd5d50842dae2c1f9862ac887ef5d1b1375e02def791fde6e"}, + {file = "ruff-0.15.15-py3-none-macosx_11_0_arm64.whl", hash = "sha256:77d955a431430c66f72dd94e379ad38a16daea3d25094872ac4edf9e797be530"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7614ee79c69788cf6cedd568069ade9cecc22a1ad20494efe8d0c9ebb4b622d4"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3cdb1679e06a1f6b47bc384714ae96f6e2fb65ca441eb78c43d2ca554176ce1f"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2728b93d7b23a603ea2c0ac6eb73d760bd38ec9de35f35fb41e18f7a3fee7622"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be582fcc0db438902c7792b08d6ddf6c9b9e21addaa10092c2c741cfb09e5a45"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7aa77465b8ecaf1a27bea098d696f7fed5e1eccbd10b321b682d6de586ae5627"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48decfa11d740de4889de623be1463308346312f2409a56e24aa280c86162dc4"}, + {file = "ruff-0.15.15-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:a5015088452ca0081387063649ec67f06d3d1d6b8b936a1f836b5e9657ecd48c"}, + {file = "ruff-0.15.15-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f5294aab6356c81600fcdea3a62bb1b924dfd5e91767c12318d3f68f86af57cd"}, + {file = "ruff-0.15.15-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:db5bd4d802415cca656dc1616070b725952d6ae95eb5d4831e49fbd94a38f75f"}, + {file = "ruff-0.15.15-py3-none-musllinux_1_2_i686.whl", hash = 
"sha256:587a6278ed42059191c1a466e490bd7930fb50bd2e255398bc29616c895a61cb"}, + {file = "ruff-0.15.15-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:df0c1c084f5f4be9812f61518a45c440d3c30d69ce4bf6c5270e66d38338f02a"}, + {file = "ruff-0.15.15-py3-none-win32.whl", hash = "sha256:29428ea79694afbe756d45fd59b36f22b6b020dc0443cf7de0173046236964b9"}, + {file = "ruff-0.15.15-py3-none-win_amd64.whl", hash = "sha256:8df0323902e15e24bc4bf246da830573d3cf3352bd0b9a164eab335d111ff4a4"}, + {file = "ruff-0.15.15-py3-none-win_arm64.whl", hash = "sha256:3c8ceca6792f38196b8f589bc92eccd03eef286602da92e5dc05cc42ef6441b7"}, + {file = "ruff-0.15.15.tar.gz", hash = "sha256:b8dff018130b46d8e5bf0f926ef6b60cf871d6d5ae45fc9334e09632daa741d6"}, +] + [[package]] name = "safetensors" version = "0.7.0" @@ -3892,6 +3972,64 @@ dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["datasets", "numpy", "pytest", "pytest-asyncio", "requests", "ruff", "ty"] +[[package]] +name = "tomli" +version = "2.4.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version == \"3.10\"" +files = [ + {file = "tomli-2.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f8f0fc26ec2cc2b965b7a3b87cd19c5c6b8c5e5f436b984e85f486d652285c30"}, + {file = "tomli-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4ab97e64ccda8756376892c53a72bd1f964e519c77236368527f758fbc36a53a"}, + {file = "tomli-2.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96481a5786729fd470164b47cdb3e0e58062a496f455ee41b4403be77cb5a076"}, + {file = "tomli-2.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a881ab208c0baf688221f8cecc5401bd291d67e38a1ac884d6736cbcd8247e9"}, + {file = "tomli-2.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47149d5bd38761ac8be13a84864bf0b7b70bc051806bc3669ab1cbc56216b23c"}, + {file = 
"tomli-2.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ec9bfaf3ad2df51ace80688143a6a4ebc09a248f6ff781a9945e51937008fcbc"}, + {file = "tomli-2.4.1-cp311-cp311-win32.whl", hash = "sha256:ff2983983d34813c1aeb0fa89091e76c3a22889ee83ab27c5eeb45100560c049"}, + {file = "tomli-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:5ee18d9ebdb417e384b58fe414e8d6af9f4e7a0ae761519fb50f721de398dd4e"}, + {file = "tomli-2.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:c2541745709bad0264b7d4705ad453b76ccd191e64aa6f0fc66b69a293a45ece"}, + {file = "tomli-2.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c742f741d58a28940ce01d58f0ab2ea3ced8b12402f162f4d534dfe18ba1cd6a"}, + {file = "tomli-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f86fd587c4ed9dd76f318225e7d9b29cfc5a9d43de44e5754db8d1128487085"}, + {file = "tomli-2.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ff18e6a727ee0ab0388507b89d1bc6a22b138d1e2fa56d1ad494586d61d2eae9"}, + {file = "tomli-2.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:136443dbd7e1dee43c68ac2694fde36b2849865fa258d39bf822c10e8068eac5"}, + {file = "tomli-2.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e262d41726bc187e69af7825504c933b6794dc3fbd5945e41a79bb14c31f585"}, + {file = "tomli-2.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5cb41aa38891e073ee49d55fbc7839cfdb2bc0e600add13874d048c94aadddd1"}, + {file = "tomli-2.4.1-cp312-cp312-win32.whl", hash = "sha256:da25dc3563bff5965356133435b757a795a17b17d01dbc0f42fb32447ddfd917"}, + {file = "tomli-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:52c8ef851d9a240f11a88c003eacb03c31fc1c9c4ec64a99a0f922b93874fda9"}, + {file = "tomli-2.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:f758f1b9299d059cc3f6546ae2af89670cb1c4d48ea29c3cacc4fe7de3058257"}, + {file = "tomli-2.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:36d2bd2ad5fb9eaddba5226aa02c8ec3fa4f192631e347b3ed28186d43be6b54"}, + {file = "tomli-2.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:eb0dc4e38e6a1fd579e5d50369aa2e10acfc9cace504579b2faabb478e76941a"}, + {file = "tomli-2.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7f2c7f2b9ca6bdeef8f0fa897f8e05085923eb091721675170254cbc5b02897"}, + {file = "tomli-2.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3c6818a1a86dd6dca7ddcaaf76947d5ba31aecc28cb1b67009a5877c9a64f3f"}, + {file = "tomli-2.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d312ef37c91508b0ab2cee7da26ec0b3ed2f03ce12bd87a588d771ae15dcf82d"}, + {file = "tomli-2.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:51529d40e3ca50046d7606fa99ce3956a617f9b36380da3b7f0dd3dd28e68cb5"}, + {file = "tomli-2.4.1-cp313-cp313-win32.whl", hash = "sha256:2190f2e9dd7508d2a90ded5ed369255980a1bcdd58e52f7fe24b8162bf9fedbd"}, + {file = "tomli-2.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:8d65a2fbf9d2f8352685bc1364177ee3923d6baf5e7f43ea4959d7d8bc326a36"}, + {file = "tomli-2.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:4b605484e43cdc43f0954ddae319fb75f04cc10dd80d830540060ee7cd0243cd"}, + {file = "tomli-2.4.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:fd0409a3653af6c147209d267a0e4243f0ae46b011aa978b1080359fddc9b6cf"}, + {file = "tomli-2.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:a120733b01c45e9a0c34aeef92bf0cf1d56cfe81ed9d47d562f9ed591a9828ac"}, + {file = "tomli-2.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:559db847dc486944896521f68d8190be1c9e719fced785720d2216fe7022b662"}, + {file = "tomli-2.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01f520d4f53ef97964a240a035ec2a869fe1a37dde002b57ebc4417a27ccd853"}, + {file = 
"tomli-2.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7f94b27a62cfad8496c8d2513e1a222dd446f095fca8987fceef261225538a15"}, + {file = "tomli-2.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ede3e6487c5ef5d28634ba3f31f989030ad6af71edfb0055cbbd14189ff240ba"}, + {file = "tomli-2.4.1-cp314-cp314-win32.whl", hash = "sha256:3d48a93ee1c9b79c04bb38772ee1b64dcf18ff43085896ea460ca8dec96f35f6"}, + {file = "tomli-2.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:88dceee75c2c63af144e456745e10101eb67361050196b0b6af5d717254dddf7"}, + {file = "tomli-2.4.1-cp314-cp314-win_arm64.whl", hash = "sha256:b8c198f8c1805dc42708689ed6864951fd2494f924149d3e4bce7710f8eb5232"}, + {file = "tomli-2.4.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:d4d8fe59808a54658fcc0160ecfb1b30f9089906c50b23bcb4c69eddc19ec2b4"}, + {file = "tomli-2.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7008df2e7655c495dd12d2a4ad038ff878d4ca4b81fccaf82b714e07eae4402c"}, + {file = "tomli-2.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d8591993e228b0c930c4bb0db464bdad97b3289fb981255d6c9a41aedc84b2d"}, + {file = "tomli-2.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:734e20b57ba95624ecf1841e72b53f6e186355e216e5412de414e3c51e5e3c41"}, + {file = "tomli-2.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8a650c2dbafa08d42e51ba0b62740dae4ecb9338eefa093aa5c78ceb546fcd5c"}, + {file = "tomli-2.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:504aa796fe0569bb43171066009ead363de03675276d2d121ac1a4572397870f"}, + {file = "tomli-2.4.1-cp314-cp314t-win32.whl", hash = "sha256:b1d22e6e9387bf4739fbe23bfa80e93f6b0373a7f1b96c6227c32bef95a4d7a8"}, + {file = "tomli-2.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2c1c351919aca02858f740c6d33adea0c5deea37f9ecca1cc1ef9e884a619d26"}, + {file = "tomli-2.4.1-cp314-cp314t-win_arm64.whl", hash = 
"sha256:eab21f45c7f66c13f2a9e0e1535309cee140182a9cdae1e041d02e47291e8396"}, + {file = "tomli-2.4.1-py3-none-any.whl", hash = "sha256:0d85819802132122da43cb86656f8d1f8c6587d54ae7dcaf30e90533028b49fe"}, + {file = "tomli-2.4.1.tar.gz", hash = "sha256:7c7e1a961a0b2f2472c1ac5b69affa0ae1132c39adcb67aba98568702b9cc23f"}, +] + [[package]] name = "torch" version = "2.12.0" @@ -4117,11 +4255,12 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main", "desktop"] +groups = ["main", "desktop", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] +markers = {dev = "python_version == \"3.10\""} [[package]] name = "typing-inspection" @@ -4787,4 +4926,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.15" -content-hash = "5f90f3b418db02d61a97335c516940014206537887f6633c5bfe68570f979606" +content-hash = "e1ff53b7edd628496f045c9123f342b3f2bfe21f85197be659c448d458dfe5e4" diff --git a/rxconfig.py b/rxconfig.py index 6db70dc..b72ccf7 100644 --- a/rxconfig.py +++ b/rxconfig.py @@ -1,5 +1,7 @@ import reflex as rx +from reflex_base.plugins.sitemap import SitemapPlugin config = rx.Config( app_name="app", + disable_plugins=[SitemapPlugin], ) diff --git a/tests/test_api.py b/tests/test_api.py index b3ba04f..a36d38d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -127,9 +127,10 @@ def test_create_job_missing_model_id_fails(): assert resp.status_code == 422 -def test_create_job_missing_dataset_path_fails(): +def test_create_job_without_dataset_path_succeeds(): + # dataset_path is optional — hub dataset jobs omit it resp = client.post("/jobs", json={"model_id": "google/gemma-2b"}) - assert resp.status_code == 422 + assert 
resp.status_code == 201 def test_get_job_status_returns_200(): diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index be04054..1d295fd 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -8,8 +8,12 @@ import pytest -# Mock heavy ML deps before importing the module -sys.modules.setdefault("torch", MagicMock()) +# Mock heavy ML deps before importing the module. +# torch sub-modules must be registered individually so Python's import +# machinery doesn't try to traverse the MagicMock as a real package. +_torch_mock = MagicMock() +for _mod in ["torch", "torch.utils", "torch.utils.data", "torch.nn", "torch.cuda"]: + sys.modules.setdefault(_mod, _torch_mock) sys.modules.setdefault("transformers", MagicMock()) sys.modules.setdefault("evaluate", MagicMock()) diff --git a/tests/test_workers.py b/tests/test_workers.py index d2cf200..76586b2 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -144,7 +144,7 @@ def _base_configs(self): def test_sets_running_status_on_start(self): model_cfg, lora_cfg, train_cfg = self._base_configs() mock_redis = MagicMock() - mock_finetune = MagicMock(return_value="/outputs/job1") + mock_finetune = MagicMock(return_value=("/outputs/job1", MagicMock(), MagicMock())) with ( patch("workers.train_task.redis.from_url", return_value=mock_redis), @@ -161,7 +161,7 @@ def test_sets_running_status_on_start(self): def test_sets_done_status_on_success(self): model_cfg, lora_cfg, train_cfg = self._base_configs() mock_redis = MagicMock() - mock_finetune = MagicMock(return_value="/outputs/job2") + mock_finetune = MagicMock(return_value=("/outputs/job2", MagicMock(), MagicMock())) with ( patch("workers.train_task.redis.from_url", return_value=mock_redis), diff --git a/trainer/callbacks.py b/trainer/callbacks.py index cb960cb..853f437 100644 --- a/trainer/callbacks.py +++ b/trainer/callbacks.py @@ -1,7 +1,9 @@ import json import os +import time import redis +import torch from transformers import TrainerCallback REDIS_URL 
= os.getenv("REDIS_URL", "redis://localhost:6379/0") @@ -9,21 +11,47 @@ class RedisLossCallback(TrainerCallback): """ - Publishes training loss to Redis channel 'job::loss' - after every logging step so the frontend can stream it live. + Publishes training progress to Redis channel 'job::progress' + after every logging step so the frontend can stream live metrics. """ def __init__(self, job_id: str): self.job_id = job_id self.redis = redis.from_url(REDIS_URL) - self.channel = f"job:{job_id}:loss" + self.channel = f"job:{job_id}:progress" + self._start_time: float = 0.0 + + def on_train_begin(self, args, state, control, **kwargs): + self._start_time = time.time() def on_log(self, args, state, control, logs=None, **kwargs): - if logs and "loss" in logs: - payload = { - "step": state.global_step, - "loss": round(logs["loss"], 4), - "epoch": round(state.epoch, 2) if state.epoch else 0, - "learning_rate": logs.get("learning_rate", 0), - } - self.redis.publish(self.channel, json.dumps(payload)) + if not logs or "loss" not in logs: + return + + gpu_mem = 0.0 + try: + if torch.cuda.is_available(): + gpu_mem = round(torch.cuda.memory_allocated() / 1e9, 2) + except Exception: + pass + + payload = { + "step": state.global_step, + "loss": round(logs["loss"], 4), + "epoch": round(state.epoch, 2) if state.epoch else 0, + "learning_rate": logs.get("learning_rate", 0), + "eval_loss": logs.get("eval_loss"), + "total_steps": state.max_steps or 0, + "elapsed_seconds": int(time.time() - self._start_time), + "gpu_memory_used_gb": gpu_mem, + "status": "running", + } + self.redis.publish(self.channel, json.dumps(payload)) + + def on_train_end(self, args, state, control, **kwargs): + payload = { + "status": "done", + "step": state.global_step, + "elapsed_seconds": int(time.time() - self._start_time), + } + self.redis.publish(self.channel, json.dumps(payload)) diff --git a/trainer/config.py b/trainer/config.py index 5f0ffec..9428b63 100644 --- a/trainer/config.py +++ b/trainer/config.py @@ 
-8,6 +8,9 @@ class ModelConfig: use_8bit: bool = False trust_remote_code: bool = False max_seq_length: int = 512 + hf_token: str = "" + local_model_path: str = "" + model_source: str = "hub" # "hub" | "local" | "custom_string" @dataclass diff --git a/trainer/dataset.py b/trainer/dataset.py index dbb6133..5fc5d7a 100644 --- a/trainer/dataset.py +++ b/trainer/dataset.py @@ -9,10 +9,12 @@ {output}""" -def format_prompt(row: dict) -> str: +def format_prompt( + row: dict, instruction_col: str = "instruction", output_col: str = "output" +) -> str: return PROMPT_TEMPLATE.format( - instruction=row.get("instruction", ""), - output=row.get("output", ""), + instruction=row.get(instruction_col, row.get("instruction", "")), + output=row.get(output_col, row.get("output", "")), ) @@ -20,26 +22,46 @@ def load_and_tokenize( file_path: str, tokenizer: PreTrainedTokenizer, max_seq_length: int = 512, + hub_dataset_id: str = "", + hub_split: str = "train", + instruction_col: str = "instruction", + output_col: str = "output", ) -> Dataset: """ - Load JSONL or CSV, format as instruction prompts, tokenize. + Load from a local file or HF Hub dataset, apply column mapping, + format as instruction prompts, and tokenize. 
""" - if file_path.endswith(".csv"): + if hub_dataset_id: + raw = load_dataset(hub_dataset_id, split=hub_split, trust_remote_code=False) + elif file_path.endswith(".csv"): df = pd.read_csv(file_path) raw = Dataset.from_pandas(df) + elif file_path.endswith(".json") and not file_path.endswith(".jsonl"): + import json + + with open(file_path) as f: + data = json.load(f) + raw = Dataset.from_list(data if isinstance(data, list) else [data]) else: raw = load_dataset("json", data_files=file_path, split="train") - def tokenize(batch): - prompts = [format_prompt(row) for row in batch] # adjust if batched - return tokenizer( - prompts, - truncation=True, - max_length=max_seq_length, - padding="max_length", + # Validate columns exist before renaming + if instruction_col not in raw.column_names: + raise ValueError( + f"Instruction column '{instruction_col}' not found. " + f"Available columns: {raw.column_names}" + ) + if output_col not in raw.column_names: + raise ValueError( + f"Output column '{output_col}' not found. Available columns: {raw.column_names}" ) - # Map with formatted prompts first + # Normalise column names so format_prompt always sees "instruction" / "output" + if instruction_col != "instruction": + raw = raw.rename_column(instruction_col, "instruction") + if output_col != "output" and output_col in raw.column_names: + raw = raw.rename_column(output_col, "output") + raw = raw.map(lambda x: {"text": format_prompt(x)}) tokenized = raw.map( lambda x: tokenizer( diff --git a/trainer/finetune.py b/trainer/finetune.py index 1674b0c..75e0875 100644 --- a/trainer/finetune.py +++ b/trainer/finetune.py @@ -16,20 +16,32 @@ def finetune( train_cfg: TrainingConfig, dataset_path: str, job_id: str, + hub_dataset_id: str = "", + hub_split: str = "train", + instruction_col: str = "instruction", + output_col: str = "output", ) -> str: """ Full fine-tuning pipeline: - 1. Load QLoRA model - 2. Load + tokenize dataset + 1. Load model (any source: HF Hub, local, custom string) + 2. 
Load + tokenize dataset (local file or HF Hub dataset) 3. Train with SFTTrainer 4. Save adapter weights - Returns path to saved adapter. + Returns (output_path, model, tokenizer). """ - # 1. Prepare QLoRA model + # 1. Prepare model model, tokenizer = prepare_qlora_model(model_cfg, lora_cfg) # 2. Load dataset - dataset = load_and_tokenize(dataset_path, tokenizer, model_cfg.max_seq_length) + dataset = load_and_tokenize( + dataset_path, + tokenizer, + model_cfg.max_seq_length, + hub_dataset_id=hub_dataset_id, + hub_split=hub_split, + instruction_col=instruction_col, + output_col=output_col, + ) # 3. Training arguments output_path = os.path.join(train_cfg.output_dir, job_id) diff --git a/trainer/loader.py b/trainer/loader.py index 07c92e5..d2d2b5d 100644 --- a/trainer/loader.py +++ b/trainer/loader.py @@ -1,3 +1,5 @@ +import os + import torch from transformers import ( AutoModelForCausalLM, @@ -8,36 +10,51 @@ from trainer.config import ModelConfig +def _resolve_model_path(cfg: ModelConfig) -> str: + """Return the model identifier to pass to from_pretrained.""" + if cfg.model_source == "local" and cfg.local_model_path: + return cfg.local_model_path + return cfg.model_name + + def load_model_and_tokenizer(cfg: ModelConfig): """ - Load model with optional 4-bit or 8-bit quantization. + Load any Transformers-compatible model with optional 4-bit/8-bit quantization. + Supports HF Hub IDs, local paths, and any string from_pretrained accepts. Returns (model, tokenizer). 
""" - bnb_config = None + model_path = _resolve_model_path(cfg) + token = cfg.hf_token or os.getenv("HF_TOKEN") or None + local_only = cfg.model_source == "local" and os.path.exists(model_path) + bnb_config = None if cfg.use_4bit: bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, # QLoRA double quantization + bnb_4bit_use_double_quant=True, ) elif cfg.use_8bit: bnb_config = BitsAndBytesConfig(load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained( - cfg.model_name, + model_path, quantization_config=bnb_config, device_map="auto", trust_remote_code=cfg.trust_remote_code, torch_dtype=torch.float16, + token=token, + local_files_only=local_only, ) model.config.use_cache = False model.config.pretraining_tp = 1 tokenizer = AutoTokenizer.from_pretrained( - cfg.model_name, + model_path, trust_remote_code=cfg.trust_remote_code, + token=token, + local_files_only=local_only, ) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" diff --git a/trainer/merge.py b/trainer/merge.py new file mode 100644 index 0000000..2fa2ffc --- /dev/null +++ b/trainer/merge.py @@ -0,0 +1,99 @@ +"""Merge a LoRA adapter into its base model and optionally export as GGUF.""" + +import os +import shutil + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def merge_adapter( + base_model_id: str, + adapter_path: str, + output_path: str, + hf_token: str = "", +) -> str: + """ + Load base model in full precision, apply the LoRA adapter, merge weights, + and save the resulting standalone model to output_path. + + Returns output_path on success. 
+ """ + from peft import PeftModel + + token = hf_token or os.getenv("HF_TOKEN") or None + is_local = os.path.exists(base_model_id) + + model = AutoModelForCausalLM.from_pretrained( + base_model_id, + torch_dtype=torch.float16, + device_map="auto", + token=token, + local_files_only=is_local, + ) + model = PeftModel.from_pretrained(model, adapter_path) + model = model.merge_and_unload() + + os.makedirs(output_path, exist_ok=True) + model.save_pretrained(output_path) + + tokenizer = AutoTokenizer.from_pretrained( + base_model_id, + token=token, + local_files_only=is_local, + ) + tokenizer.save_pretrained(output_path) + + return output_path + + +def export_gguf( + merged_model_path: str, + output_dir: str, + quant_type: str = "Q4_K_M", +) -> str: + """ + Convert a merged safetensors model to GGUF using llama.cpp's convert script. + Requires llama-cpp-python or the llama.cpp binary to be installed. + + Returns path to the .gguf file on success, raises RuntimeError otherwise. + """ + import subprocess + import sys + + # Try to find llama.cpp convert script via llama-cpp-python package + try: + import llama_cpp + + convert_script = os.path.join(os.path.dirname(llama_cpp.__file__), "convert_hf_to_gguf.py") + except ImportError: + convert_script = shutil.which("convert_hf_to_gguf.py") or "" + + if not convert_script or not os.path.exists(convert_script): + raise RuntimeError( + "llama.cpp convert_hf_to_gguf.py not found. " + "Install llama-cpp-python or add llama.cpp to PATH." 
+ ) + + os.makedirs(output_dir, exist_ok=True) + gguf_path = os.path.join(output_dir, f"model-{quant_type.lower()}.gguf") + + # Step 1: Convert to f16 GGUF + f16_path = os.path.join(output_dir, "model-f16.gguf") + subprocess.run( + [sys.executable, convert_script, merged_model_path, "--outfile", f16_path], + check=True, + ) + + # Step 2: Quantize (requires llama-quantize binary) + quantize_bin = shutil.which("llama-quantize") or shutil.which("quantize") + if quantize_bin: + subprocess.run( + [quantize_bin, f16_path, gguf_path, quant_type], + check=True, + ) + os.remove(f16_path) + return gguf_path + else: + # Return f16 if quantize binary not available + return f16_path diff --git a/workers/merge_task.py b/workers/merge_task.py new file mode 100644 index 0000000..9e3bf5b --- /dev/null +++ b/workers/merge_task.py @@ -0,0 +1,172 @@ +"""Celery tasks for post-training operations: merge, GGUF export, GitHub push.""" + +import json +import os + +import redis + +from workers.celery_app import celery_app + +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + +try: + import spaces + + _gpu_decorator = spaces.GPU +except ImportError: + _gpu_decorator = lambda fn: fn # noqa: E731 + + +def _publish_deploy_log(r: redis.Redis, job_id: str, message: str): + r.publish(f"job:{job_id}:deploy", json.dumps({"message": message})) + + +@_gpu_decorator +def _run_merge_impl( + job_id: str, base_model_id: str, adapter_path: str, output_path: str, hf_token: str = "" +): + from trainer.merge import merge_adapter + + return merge_adapter(base_model_id, adapter_path, output_path, hf_token) + + +@celery_app.task(bind=True, name="workers.merge_task.merge_adapter") +def merge_adapter_task( + self, job_id: str, base_model_id: str, adapter_path: str, hf_token: str = "" +): + r = redis.from_url(REDIS_URL) + merged_key = f"job:{job_id}:merged" + output_path = os.path.join(os.getenv("OUTPUT_DIR", "./outputs"), job_id, "merged") + + try: + _publish_deploy_log(r, job_id, "Starting model 
merge — this may take several minutes...") + merged_path = _run_merge_impl(job_id, base_model_id, adapter_path, output_path, hf_token) + r.set(merged_key, json.dumps({"status": "done", "merged_path": merged_path})) + _publish_deploy_log(r, job_id, f"Merge complete. Saved to {merged_path}") + return merged_path + except Exception as e: + r.set(merged_key, json.dumps({"status": "failed", "error": str(e)})) + _publish_deploy_log(r, job_id, f"Merge failed: {e}") + raise + + +@celery_app.task(bind=True, name="workers.merge_task.export_gguf") +def export_gguf_task(self, job_id: str, merged_model_path: str, quant_type: str = "Q4_K_M"): + r = redis.from_url(REDIS_URL) + gguf_key = f"job:{job_id}:gguf" + output_dir = os.path.join(os.getenv("OUTPUT_DIR", "./outputs"), job_id, "gguf") + + try: + _publish_deploy_log(r, job_id, f"Exporting GGUF with quantization {quant_type}...") + from trainer.merge import export_gguf + + gguf_path = export_gguf(merged_model_path, output_dir, quant_type) + r.set(gguf_key, json.dumps({"status": "done", "gguf_path": gguf_path})) + _publish_deploy_log(r, job_id, f"GGUF export complete: {os.path.basename(gguf_path)}") + return gguf_path + except Exception as e: + r.set(gguf_key, json.dumps({"status": "failed", "error": str(e)})) + _publish_deploy_log(r, job_id, f"GGUF export failed: {e}") + raise + + +@celery_app.task(bind=True, name="workers.merge_task.push_github") +def push_github_task( + self, + job_id: str, + adapter_path: str, + repo_url: str, + github_token: str, + commit_message: str = "Add fine-tuned LoRA adapter", +): + import subprocess + import tempfile + + r = redis.from_url(REDIS_URL) + + # Only allow GitHub HTTPS remotes — reject arbitrary hosts + if not repo_url.startswith("https://github.com/"): + raise ValueError(f"Only https://github.com/ remotes are supported, got: {repo_url}") + + try: + _publish_deploy_log(r, job_id, "Pushing adapter to GitHub...") + + # Provide the token via a credential helper so it never appears in + # 
command args, process listings, or error output. + import stat + import textwrap + + with tempfile.TemporaryDirectory() as tmp: + # Write a one-shot credential helper script + helper_path = os.path.join(tmp, "git-credential-tuneos") + helper_script = textwrap.dedent(f"""\ + #!/bin/sh + echo username=x-token + echo password={github_token} + """) + with open(helper_path, "w") as fh: + fh.write(helper_script) + os.chmod(helper_path, stat.S_IRWXU) + + clone_env = { + **os.environ, + "GIT_ASKPASS": helper_path, + "GIT_TERMINAL_PROMPT": "0", + } + + repo_dir = os.path.join(tmp, "repo") + subprocess.run( + ["git", "clone", repo_url, repo_dir], check=True, capture_output=True, env=clone_env + ) + + # Copy adapter files into the cloned repo + import shutil + + dest = os.path.join(repo_dir, "adapter") + shutil.copytree(adapter_path, dest, dirs_exist_ok=True) + + # Set up LFS for large files + subprocess.run( + ["git", "-C", repo_dir, "lfs", "install"], check=True, capture_output=True + ) + subprocess.run( + ["git", "-C", repo_dir, "lfs", "track", "*.safetensors"], + check=True, + capture_output=True, + ) + subprocess.run( + ["git", "-C", repo_dir, "add", ".gitattributes"], check=True, capture_output=True + ) + subprocess.run( + ["git", "-C", repo_dir, "add", "adapter/"], check=True, capture_output=True + ) + subprocess.run( + ["git", "-C", repo_dir, "commit", "-m", commit_message], + check=True, + capture_output=True, + env={ + **clone_env, + "GIT_AUTHOR_NAME": "TuneOS", + "GIT_AUTHOR_EMAIL": "tuneos@bot.local", + "GIT_COMMITTER_NAME": "TuneOS", + "GIT_COMMITTER_EMAIL": "tuneos@bot.local", + }, + ) + subprocess.run( + ["git", "-C", repo_dir, "push"], + check=True, + capture_output=True, + env=clone_env, + ) + + _publish_deploy_log(r, job_id, f"Pushed adapter to {repo_url}") + except subprocess.CalledProcessError as e: + # Strip any token that might appear in error output before logging + raw_msg = e.stderr.decode() if e.stderr else str(e) + safe_msg = 
raw_msg.replace(github_token, "***") if github_token else raw_msg + _publish_deploy_log(r, job_id, f"GitHub push failed: {safe_msg}") + raise RuntimeError(safe_msg) from e + except Exception as e: + safe_e = str(e).replace(github_token, "***") if github_token else str(e) + _publish_deploy_log(r, job_id, f"GitHub push failed: {safe_e}") + raise diff --git a/workers/train_task.py b/workers/train_task.py index ce0889b..67cbccf 100644 --- a/workers/train_task.py +++ b/workers/train_task.py @@ -10,9 +10,26 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") +try: + import spaces + _gpu_decorator = spaces.GPU +except ImportError: + _gpu_decorator = lambda fn: fn # noqa: E731 + + +@_gpu_decorator def _run_finetune_impl( - task_self, job_id: str, model_cfg: dict, lora_cfg: dict, train_cfg: dict, dataset_path: str + task_self, + job_id: str, + model_cfg: dict, + lora_cfg: dict, + train_cfg: dict, + dataset_path: str, + hub_dataset_id: str = "", + hub_split: str = "train", + instruction_col: str = "instruction", + output_col: str = "output", ): """Core logic, separated so it can be unit-tested without a live Celery broker.""" r = redis.from_url(REDIS_URL) @@ -21,12 +38,17 @@ def _run_finetune_impl( try: r.set(status_key, json.dumps({"status": "running", "job_id": job_id})) + cfg = ModelConfig(**model_cfg) output_path, model, tokenizer = finetune( - model_cfg=ModelConfig(**model_cfg), + model_cfg=cfg, lora_cfg=LoraConfig(**lora_cfg), train_cfg=TrainingConfig(**train_cfg), dataset_path=dataset_path, job_id=job_id, + hub_dataset_id=hub_dataset_id, + hub_split=hub_split, + instruction_col=instruction_col, + output_col=output_col, ) # Evaluate on a 20% random sample of the training data @@ -35,7 +57,13 @@ def _run_finetune_impl( from trainer.evaluate import evaluate_model full_dataset = load_and_tokenize( - dataset_path, tokenizer, ModelConfig(**model_cfg).max_seq_length + dataset_path, + tokenizer, + cfg.max_seq_length, + hub_dataset_id=hub_dataset_id, + 
hub_split=hub_split, + instruction_col=instruction_col, + output_col=output_col, ) n_eval = max(1, int(0.2 * len(full_dataset))) eval_sample = full_dataset.shuffle(seed=42).select(range(n_eval)) @@ -72,8 +100,28 @@ def _run_finetune_impl( raise -@celery_app.task(bind=True, name="workers.train_task.run_finetune") +@celery_app.task(bind=True, name="workers.train_task.run_finetune", time_limit=7200) def run_finetune( - self, job_id: str, model_cfg: dict, lora_cfg: dict, train_cfg: dict, dataset_path: str + self, + job_id: str, + model_cfg: dict, + lora_cfg: dict, + train_cfg: dict, + dataset_path: str, + hub_dataset_id: str = "", + hub_split: str = "train", + instruction_col: str = "instruction", + output_col: str = "output", ): - return _run_finetune_impl(self, job_id, model_cfg, lora_cfg, train_cfg, dataset_path) + return _run_finetune_impl( + self, + job_id, + model_cfg, + lora_cfg, + train_cfg, + dataset_path, + hub_dataset_id=hub_dataset_id, + hub_split=hub_split, + instruction_col=instruction_col, + output_col=output_col, + )