diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..e69de29 diff --git a/model-lab/app/main.py b/model-lab/app/main.py index bb63904..61815f0 100644 --- a/model-lab/app/main.py +++ b/model-lab/app/main.py @@ -6,15 +6,21 @@ from pathlib import Path from fastapi.middleware.cors import CORSMiddleware from dotenv import load_dotenv + load_dotenv() app = FastAPI(title="Face Sentiment API") -# path to the model +# Path to the model (fixed typo: onxx_models → onnx_models) BASE_DIR = Path(__file__).resolve().parent.parent -MODEL_PATH = BASE_DIR / "models"/ "onxx_models" / "emotion-ferplus-8.onnx" +MODEL_PATH = BASE_DIR / "models" / "onnx_models" / "emotion-ferplus-8.onnx" +if not MODEL_PATH.exists(): + raise FileNotFoundError(f"Model file not found: {MODEL_PATH}") +# Validate CORS origins properly ALLOWED_ORIGINS = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000").split(",") +if not ALLOWED_ORIGINS or ALLOWED_ORIGINS == [""]: + raise RuntimeError("No allowed origins configured in CORS_ALLOWED_ORIGINS") app.add_middleware( CORSMiddleware, @@ -23,16 +29,31 @@ allow_headers=["Content-Type"], ) +# ✅ Pre-load both model options at startup to avoid race conditions +emotion_models = { + 1: Model(model_path=MODEL_PATH, model_option=1), # HuggingFace ViT + 2: Model(model_path=MODEL_PATH, model_option=2) # ONNX with CV2 +} + @app.post("/predict") async def predict(file: UploadFile = File(...), model_option: int = Form(1)): if not file.content_type.startswith("image/"): raise HTTPException(status_code=400, detail="File must be an image") + file_bytes = await file.read() try: image = Image.open(io.BytesIO(file_bytes)) except (OSError, ValueError) as err: raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from err - emotion_model = Model(model_path=MODEL_PATH, model_option=model_option) - emotion, prob = emotion_model.predict(pil_image=image) - - return {"emotion": emotion, "probabilities": prob} \ No newline at end of file + + # ✅ Validate model_option before using + if model_option not in emotion_models: + raise HTTPException(status_code=400, detail=f"Invalid model_option: {model_option}") + + try: + # ✅ Select the correct preloaded model without mutating shared state + emotion, prob = emotion_models[model_option].predict(pil_image=image) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e + + return {"emotion": emotion, "probabilities": prob}