Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,29 @@

from __future__ import annotations

import io
import json
import os
import platform
import subprocess
import uuid
import zipfile

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

# ── Router ───────────────────────────────────────────────────────
# FastAPI application object; every REST endpoint in this module is
# registered on it through the @app_api.<method> decorators.
app_api = FastAPI(title="TuneOS API")

# ── Constants ────────────────────────────────────────────────────
_VERSION = "0.1.0"
# Redis connection URL, overridable through the REDIS_URL environment variable.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# Root directory for job outputs, overridable through OUTPUT_DIR. Endpoints
# look for a job's adapter files under OUTPUT_DIR/<job_id>.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")

# Process-level inference cache: job_id -> (model, tokenizer)
# Max 1 entry — evicted when a different job_id is requested.
# The cache lives in this process only, so each server process keeps its own copy.
_INFER_CACHE: dict = {}

_SUPPORTED_MODELS: list[dict[str, str]] = [
{
Expand Down Expand Up @@ -246,3 +257,143 @@ async def cancel_job(job_id: str) -> JobStatus:
except Exception:
pass # Best-effort cancellation
return JobStatus(job_id=job_id, status="cancelled")


# ── Post-training endpoints ───────────────────────────────────────

@app_api.get("/jobs/{job_id}/download")
async def download_adapter(job_id: str):
    """Zip a job's adapter weights directory and stream it to the client.

    Only regular files located directly inside ``OUTPUT_DIR/<job_id>`` are
    archived; sub-directories are skipped. The archive is assembled fully in
    memory before the response is returned.

    Args:
        job_id: Identifier of the training job; used as the directory name
            under ``OUTPUT_DIR``.

    Returns:
        A ``StreamingResponse`` carrying the ZIP archive as an attachment
        named ``adapter_<first 8 chars of job_id>.zip``.

    Raises:
        HTTPException: 404 when ``job_id`` is not a single path component or
            when no adapter directory exists for it.
    """
    # job_id comes straight from the URL. Anything that is not exactly one
    # path component ("", ".", "..", or a value containing a separator) would
    # make the join below resolve outside the per-job directory, so treat it
    # the same as an unknown job.
    if not job_id or job_id in {".", ".."} or os.path.basename(job_id) != job_id:
        raise HTTPException(status_code=404, detail="Adapter not found")

    adapter_dir = os.path.join(OUTPUT_DIR, job_id)
    if not os.path.isdir(adapter_dir):
        raise HTTPException(status_code=404, detail="Adapter not found")

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        # Sorted so the archive layout does not depend on filesystem order.
        for fname in sorted(os.listdir(adapter_dir)):
            full = os.path.join(adapter_dir, fname)
            if os.path.isfile(full):
                # Store under the bare file name so the archive has no
                # server-side directory prefix.
                zf.write(full, fname)
    # Rewind so the response streams the archive from its first byte.
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/zip",
        headers={"Content-Disposition": f"attachment; filename=adapter_{job_id[:8]}.zip"},
    )


class PushHubRequest(BaseModel):
    """Request body for pushing a job's adapter to the Hugging Face Hub."""

    # Repository id handed to the Hub API and used to build the returned URL.
    repo_name: str
    # Optional access token. When left empty, the push endpoint falls back to
    # the HF_TOKEN environment variable.
    hf_token: str = Field(default="")


@app_api.post("/jobs/{job_id}/push_hub")
async def push_to_hub(job_id: str, req: PushHubRequest):
    """Upload a job's adapter directory to a private Hugging Face Hub repo.

    The repository is created if it does not exist yet (``exist_ok=True``)
    and the whole ``OUTPUT_DIR/<job_id>`` folder is uploaded to it.

    Args:
        job_id: Identifier of the training job; used as the directory name
            under ``OUTPUT_DIR``.
        req: Target repository name and optional token. When the token is
            empty, the ``HF_TOKEN`` environment variable is used instead.

    Returns:
        A dict with ``status`` set to ``"pushed"`` and the ``repo_url``.

    Raises:
        HTTPException: 400 when no token is available; 404 when ``job_id`` is
            not a single path component or the adapter directory is missing;
            500 when the Hub client cannot be imported or the upload fails.
    """
    token = req.hf_token or os.getenv("HF_TOKEN", "")
    if not token:
        raise HTTPException(status_code=400, detail="HF token required")

    # job_id comes straight from the URL. A value such as ".." would make the
    # join below point at a directory outside the per-job folder, which would
    # then be uploaded to a repository chosen by the caller. Treat anything
    # that is not exactly one path component as an unknown job.
    if not job_id or job_id in {".", ".."} or os.path.basename(job_id) != job_id:
        raise HTTPException(status_code=404, detail="Adapter not found")

    adapter_dir = os.path.join(OUTPUT_DIR, job_id)
    if not os.path.isdir(adapter_dir):
        raise HTTPException(status_code=404, detail="Adapter not found")

    try:
        # Imported lazily so the API can start without huggingface_hub
        # installed; a missing package surfaces as a 500 on this endpoint.
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id=req.repo_name, repo_type="model", exist_ok=True, private=True)
        api.upload_folder(folder_path=adapter_dir, repo_id=req.repo_name, repo_type="model")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"status": "pushed", "repo_url": f"https://huggingface.co/{req.repo_name}"}


@app_api.get("/jobs/{job_id}/eval")
async def get_eval(job_id: str):
    """Return the evaluation metrics stored in Redis for a job.

    Metrics are read from the ``job:<job_id>:eval`` key and are expected to
    be a JSON object. When the key is missing or empty, a ``not_ready``
    placeholder with null metrics is returned instead.

    Args:
        job_id: Identifier of the training job.

    Returns:
        ``{"status": "done", ...metrics}`` when metrics exist, otherwise
        ``{"status": "not_ready", "perplexity": None, "bleu": None}``.

    Raises:
        HTTPException: 500 on any failure (Redis import or connection error,
            invalid JSON, non-object payload).
    """
    try:
        import redis as _redis

        client = _redis.from_url(REDIS_URL)
        payload = client.get(f"job:{job_id}:eval")
        if payload:
            metrics = json.loads(payload)
            # Stored keys are merged last, so they win over the default status.
            return {"status": "done", **metrics}
        return {"status": "not_ready", "perplexity": None, "bleu": None}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


class InferRequest(BaseModel):
    """Request body for running inference with a fine-tuned adapter.

    The numeric fields are validated up front. The inference endpoint always
    samples (``do_sample=True``), which needs a strictly positive temperature
    and at least one new token, so invalid values are rejected as a
    validation error instead of failing inside generation.
    """

    # Prompt text that is tokenized and fed to the model.
    prompt: str
    # Upper bound on the number of generated tokens; must be at least 1.
    max_new_tokens: int = Field(default=200, ge=1)
    # Sampling temperature; must be strictly greater than 0.
    temperature: float = Field(default=0.7, gt=0.0)


def _read_base_model_name(adapter_dir: str) -> str:
    """Return ``base_model_name_or_path`` from ``adapter_config.json``.

    Raises:
        HTTPException: 404 when the config file is missing; 500 when it
            cannot be read or parsed, or when the base model name is absent.
    """
    config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(config_path):
        raise HTTPException(status_code=404, detail="adapter_config.json not found")

    try:
        # Read as UTF-8 explicitly rather than relying on the platform's
        # default encoding.
        with open(config_path, encoding="utf-8") as f:
            adapter_cfg = json.load(f)
    except (OSError, ValueError) as exc:
        raise HTTPException(
            status_code=500, detail=f"Failed to read adapter_config.json: {exc}"
        ) from exc

    # A config that parses to something other than an object cannot carry the
    # key, so it is reported the same way as a missing key.
    base_model_name = (
        adapter_cfg.get("base_model_name_or_path", "")
        if isinstance(adapter_cfg, dict)
        else ""
    )
    if not base_model_name:
        raise HTTPException(status_code=500, detail="base_model_name_or_path missing in adapter_config.json")
    return base_model_name


@app_api.post("/jobs/{job_id}/infer")
async def infer(job_id: str, req: InferRequest):
    """Generate text with a job's fine-tuned adapter.

    The base model and adapter are loaded on first use and kept in the
    process-level ``_INFER_CACHE``. The cache holds at most one entry, so
    requesting a different job replaces the previously loaded model.

    Args:
        job_id: Identifier of the training job whose adapter is used.
        req: Prompt and sampling parameters.

    Returns:
        ``{"response": <generated text>}`` containing only the newly
        generated tokens, with special tokens removed.

    Raises:
        HTTPException: 400 when the job status is not ``"done"``; 404 when
            the adapter directory or its config file is missing; 500 when the
            config is invalid, the model fails to load, or generation fails.
    """
    import torch

    # NOTE(review): _get_job_status_from_redis is defined outside this block;
    # it is used here as a mapping with "status" and optional "output_path".
    state = _get_job_status_from_redis(job_id)
    if state.get("status") != "done":
        raise HTTPException(status_code=400, detail="Job not complete")

    adapter_dir = state.get("output_path") or os.path.join(OUTPUT_DIR, job_id)
    if not os.path.isdir(adapter_dir):
        raise HTTPException(status_code=404, detail="Adapter directory not found")

    if job_id not in _INFER_CACHE:
        # Validate the adapter config first, so a bad request for this job
        # does not throw away a model that is already loaded for another one.
        base_model_name = _read_base_model_name(adapter_dir)

        try:
            from peft import PeftModel
            from transformers import AutoModelForCausalLM, AutoTokenizer

            # Evict any previously cached model to free VRAM. This happens
            # only once the load is actually about to start.
            _INFER_CACHE.clear()

            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            model = PeftModel.from_pretrained(base_model, adapter_dir)
            model.eval()
            _INFER_CACHE[job_id] = (model, tokenizer)
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Failed to load model: {exc}") from exc

    model, tokenizer = _INFER_CACHE[job_id]

    try:
        inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Slice off the prompt tokens so only the continuation is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc

    return {"response": response}
2 changes: 2 additions & 0 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.components.layout import two_panel_layout
from app.pages.configure import configure_page
from app.pages.datasets import datasets_page
from app.pages.finetune import finetune_page
from app.pages.results import results_page
from app.pages.training import training_page
from app.pages.upload import upload_page
Expand Down Expand Up @@ -60,6 +61,7 @@ def index() -> rx.Component:
app.add_page(training_page, route="/training", title="Training — TuneOS")
app.add_page(results_page, route="/results", title="Results — TuneOS")
app.add_page(datasets_page, route="/datasets", title="Datasets — TuneOS")
app.add_page(finetune_page, route="/finetune", title="Fine-tune — TuneOS")

# Mount REST API endpoints. Imported here, after page registration, to avoid
# a circular import between the Reflex app module and the API router.
Expand Down
4 changes: 2 additions & 2 deletions app/components/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _expanded_sidebar() -> rx.Component:
active=AppState.current_view == "datasets",
on_click=AppState.set_view("datasets"),
),
_nav_item("flask-conical", "Techniques"),
_nav_item("flask-conical", "Fine-tune", on_click=rx.redirect("/finetune")),
spacing="1",
width="100%",
padding_x="8px",
Expand Down Expand Up @@ -309,7 +309,7 @@ def _collapsed_sidebar() -> rx.Component:
active=AppState.current_view == "datasets",
on_click=AppState.set_view("datasets"),
),
_collapsed_icon_btn("flask-conical"),
_collapsed_icon_btn("flask-conical", on_click=rx.redirect("/finetune")),
rx.spacer(),
_collapsed_icon_btn("settings"),
spacing="4",
Expand Down
Loading
Loading