diff --git a/API/dependencies/di.py b/API/dependencies/di.py index 5641e25..0a018f2 100644 --- a/API/dependencies/di.py +++ b/API/dependencies/di.py @@ -18,6 +18,7 @@ from services.expense_service import ExpenseService, IExpenseService from services.group_log_service import GroupLogService, IGroupLogService from services.group_service import GroupService, IGroupService +from services.receipt_service import IReceiptService, ReceiptService from services.user_group_service import IUserGroupService, UserGroupService from services.user_service import IUserService, UserService from sqlalchemy.orm import Session @@ -84,4 +85,7 @@ def get_expense_payment_service( return ExpensePaymentService(repo, expense_repository, group_repository, user_repository) def get_category_service(repo: ICategoryRepository = Depends(get_category_repository)) -> ICategoryService: - return CategoryService(repo) \ No newline at end of file + return CategoryService(repo) + +def get_receipt_service(category_repository: ICategoryRepository = Depends(get_category_repository)) -> IReceiptService: + return ReceiptService(category_repository) \ No newline at end of file diff --git a/API/main.py b/API/main.py index 97bd864..3095dd0 100644 --- a/API/main.py +++ b/API/main.py @@ -8,6 +8,9 @@ from routes.expense_routes import router as expense_router from routes.group_log_routes import router as group_log_router from routes.group_routes import router as group_router + +# ReceiptService specific +from routes.receipt_routes import router as receipt_router from routes.user_routes import router as user_router app = FastAPI(title="GitPushForce API") @@ -30,6 +33,7 @@ app.include_router(category_router, prefix="/categories", tags=["Categories"]) app.include_router(expense_payment_router, prefix="/expenses_payments") app.include_router(group_log_router, prefix="/group_logs") +app.include_router(receipt_router, prefix="/receipt") @app.get("/") def root(): diff --git a/API/repositories/category_repository.py b/API/repositories/category_repository.py index 95b1260..41ac10d 100644 --- a/API/repositories/category_repository.py +++ b/API/repositories/category_repository.py @@ -44,6 +44,12 @@ def get_all(self, sort_by: str, order: str) -> List[Category]: statement = select(Category).order_by(sort_order) return list(self.db.scalars(statement)) + def get_by_user(self, user_id: int, sort_by: str, order: str) -> List[Category]: + sort_column = getattr(Category, sort_by, Category.title) + sort_order = asc(sort_column) if order == "asc" else desc(sort_column) + statement = select(Category).where(Category.user_id == user_id).order_by(sort_order) + return list(self.db.scalars(statement)) + def get_by_id(self, category_id: int) -> Category: statement = select(Category).where(Category.id == category_id) return self.db.scalars(statement).first() diff --git a/API/requirements.txt b/API/requirements.txt index 3808645..8653696 100644 Binary files a/API/requirements.txt and b/API/requirements.txt differ diff --git a/API/routes/category_routes.py b/API/routes/category_routes.py index 5d8ade0..7bc5b07 100644 --- a/API/routes/category_routes.py +++ b/API/routes/category_routes.py @@ -28,6 +28,10 @@ def get_all_categories( ): return category_service.get_all_categories(sort_by, order) +@router.get("/{user_id}") +def get_user_categories(user_id: int = Depends(get_current_user_id), category_service: ICategoryService = Depends(get_category_service), sort_by: str = Query("title"), order: str = Query("asc", regex="^(asc|desc)$")): + return category_service.get_user_categories(user_id, sort_by, order) + @router.put("/{category_id}") def update_category(category_id: int, category_in: CategoryUpdate, requester_id: int = Depends(get_current_user_id), category_service: ICategoryService = Depends(get_category_service)): return category_service.update_category(category_id, category_in, requester_id) diff --git a/API/routes/receipt_routes.py b/API/routes/receipt_routes.py new file mode 100644 index 0000000..5a90df0 --- /dev/null +++ b/API/routes/receipt_routes.py @@ -0,0 +1,16 @@ +from dependencies.di import get_receipt_service +from fastapi import APIRouter, Depends, File, Request, UploadFile +from services.receipt_service import IReceiptService +from utils.helpers.jwt_utils import JwtUtils + +router = APIRouter(prefix="/receipt", tags=["Receipt"]) + +def get_current_user_id(request: Request) -> int: + """ + Returns the authenticated user id. + """ + return JwtUtils.auth_wrapper(request) + +@router.post("/process-receipt") +def process_receipt(image: UploadFile = File(...), user_id: int = Depends(get_current_user_id), receipt_service: IReceiptService = Depends(get_receipt_service)): + return receipt_service.process_receipt_photo(image, user_id) diff --git a/API/services/category_service.py b/API/services/category_service.py index 55b8d50..aa4eb65 100644 --- a/API/services/category_service.py +++ b/API/services/category_service.py @@ -74,6 +74,16 @@ def get_all_categories(self, sort_by: str, order: str) -> APIResponse: data=categories_response ) + def get_user_categories(self, user_id: int, sort_by: str, order: str) -> APIResponse: + self.logger.info(f"Fetching categories for user with id {user_id}") + + categories = self.repository.get_by_user(user_id, sort_by, order) + categories_response = [CategoryResponse.model_validate(category) for category in categories] + return APIResponse( + success=True, + data=categories_response + ) + def update_category(self, category_id: int, data: CategoryUpdate, requester_id: int) -> APIResponse: self.logger.info(f"Updating category with id {category_id}") diff --git a/API/services/receipt_service.py b/API/services/receipt_service.py new file mode 100644 index 0000000..b4fe0ff --- /dev/null +++ b/API/services/receipt_service.py @@ -0,0 +1,166 @@ +import io +import json +import os +import re +import time +from abc import ABC, abstractmethod +from typing import List + +from dotenv import load_dotenv +from fastapi import HTTPException, UploadFile +from google import genai +from google.genai import types +from PIL import Image +from repositories.category_repository import ICategoryRepository + +load_dotenv() + +class IReceiptService(ABC): + @abstractmethod + def extract_json_from_response(self, text: str) -> str: ... + + @abstractmethod + def generate_prompt(self, categories: dict[str, List[str]]) -> str: ... + + @abstractmethod + def process_receipt_photo(self, image: UploadFile, user_id: int): ... + + +class ReceiptService: + def __init__(self, category_repository: ICategoryRepository): + self.category_repository = category_repository + self.max_retries = 3 + self.delay = 2 + self.API_KEY = os.getenv("API_KEY") + self.client = genai.Client(api_key=self.API_KEY) + self.SYSTEM_CONFIG = types.GenerateContentConfig( + system_instruction=(""" + You are a receipt-processing assistant. + Your task is to analyze a photo of a receipt and extract only the purchased items, then categorize each item into one of the categories provided in the user prompt. + The categories will be provided in this format: category (a list of relevant keywords for this category), category ... + A keyword may be one word or a sentence describing the category. + You must output a single valid JSON object with the following structure: + { + "items": [ + { + "name": string, + "quantity": number, + "price": number, + "category": string, + "keywords": List(string) + } + ], + "total": number + } + + Rules: + - Output ONLY JSON. No explanations. No commentary. No text outside the JSON. + - Output STRICT valid JSON (double quotes, no trailing commas, correct types). + - Do not include any fields other than those defined in the JSON schema. + - Prices, totals, and quantities must be numeric values only (no currency symbols). + - If the receipt does not specify quantity, use 1. + - If the receipt breaks an item into multiple lines, merge them into one coherent item. + - If there are multiple totals (e.g., subtotal, total with tax), always choose the total that includes taxes. + - The semantic meaning of the category always takes precedence over keyword matches. Keywords help, but only when the category meaning aligns with the actual type of the item. + - If an item matches one of the provided categories, add it in the response, along with its keywords; Do NOT generate any additional keywords for this category and make sure to include all of the provided keywords. + - If an item clearly does not match any provided category, you may create one new category, but: + - Name it concisely (1–2 words). + - Make it a general category that could reasonably include similar items, avoiding overly specific or niche categories. + - Generate 5 relevant keywords for the category to include in the response. + - Only create a new category if absolutely necessary; prefer mapping items to broader existing categories whenever possible. + - If a field is missing or ambiguous, deduce it cautiously from surrounding information. + - If there is no receipt in the provided image, return an empty JSON. + """ + ) + ) + + def extract_json_from_response(self, text: str) -> str: + fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE) + if fenced: + return fenced.group(1).strip() + return text.strip() + + def _validate_image(self, image_bytes: bytes): + try: + Image.open(io.BytesIO(image_bytes)) + except Exception: + raise ValueError("The provided image is not valid.") + + def _validate_receipt_response(self, response_json: str): + try: + data = json.loads(response_json) + except json.JSONDecodeError: + raise ValueError("Response JSON is not valid.") + if not data: + raise ValueError("The provided image is not a valid receipt image.") + + required_top = {"items", "total"} + if not required_top.issubset(data.keys()): + raise ValueError("Response JSON is missing 'items' or 'total' fields.") + extra_top = set(data.keys()) - required_top + if extra_top: + raise ValueError(f"Unexpected top-level fields: {extra_top}") + + required_item_fields = {"name", "quantity", "price", "category", "keywords"} + for item in data["items"]: + missing = required_item_fields - set(item.keys()) + if missing: + raise ValueError(f"Item is missing required fields: {missing}") + extra = set(item.keys()) - required_item_fields + if extra: + raise ValueError(f"Item contains unexpected fields: {extra}") + + def generate_prompt(self, categories: dict[str, List[str]]) -> str: + if categories: + parts = [] + for category, keywords in categories.items(): + keyword_str = ", ".join(keywords) + parts.append(f"{category} ({keyword_str})") + joined = ", ".join(parts) + return f"Analyze the receipt image and categorize each purchased item into one of these categories: {joined}" + else: + return "Analyze the receipt image and categorize each purchased item into one category" + + def _load_user_categories(self, user_id: int) -> dict[str, List[str]]: + categories = self.category_repository.get_by_user( + user_id=user_id, + sort_by="title", + order="asc", + ) + result: dict[str, List[str]] = {} + + for category in categories: + keywords = category.keywords or [] + result[category.title] = keywords + + return result + + def process_receipt_photo(self, image: UploadFile, user_id: int): + image_bytes = image.file.read() + self._validate_image(image_bytes) + categories = self._load_user_categories(user_id) + prompt = self.generate_prompt(categories) + + for attempt in range(1, self.max_retries + 1): + response = self.client.models.generate_content( + model="gemini-2.5-flash", + contents=[ + types.Part.from_bytes( + data=image_bytes, + mime_type="image/jpeg", + ), + prompt, + ], + config=self.SYSTEM_CONFIG, + ) + response_json = self.extract_json_from_response(response.text) + try: + self._validate_receipt_response(response_json) + return json.loads(response_json) + except ValueError as e: + error_msg = str(e) + if "not a valid receipt image" in error_msg: + raise HTTPException(status_code=400, detail=error_msg) + if attempt >= self.max_retries: + raise HTTPException(status_code=500, detail=f"Failed after {self.max_retries} attempts: {error_msg}") + time.sleep(self.delay) diff --git a/ReceiptService/main.py b/ReceiptService/main.py index 6941673..78b563c 100644 --- a/ReceiptService/main.py +++ b/ReceiptService/main.py @@ -17,6 +17,8 @@ system_instruction=(""" You are a receipt-processing assistant. Your task is to analyze a photo of a receipt and extract only the purchased items, then categorize each item into one of the categories provided in the user prompt. + The categories will be provided in this format: category (a list of relevant keywords for this category), category ... + A keyword may be one word or a sentence describing the category. You must output a single valid JSON object with the following structure: { "items": [ @@ -24,7 +26,8 @@ "name": string, "quantity": number, "price": number, - "category": string + "category": string, + "keywords": List(string) } ], "total": number @@ -33,13 +36,17 @@ Rules: - Output ONLY JSON. No explanations. No commentary. No text outside the JSON. - Output STRICT valid JSON (double quotes, no trailing commas, correct types). + - Do not include any fields other than those defined in the JSON schema. - Prices, totals, and quantities must be numeric values only (no currency symbols). - If the receipt does not specify quantity, use 1. - If the receipt breaks an item into multiple lines, merge them into one coherent item. - If there are multiple totals (e.g., subtotal, total with tax), always choose the total that includes taxes. + - The semantic meaning of the category always takes precedence over keyword matches. Keywords help, but only when the category meaning aligns with the actual type of the item. + - If an item matches one of the provided categories, add it in the response, along with its keywords; Do NOT generate any additional keywords for this category and make sure to include all of the provided keywords. - If an item clearly does not match any provided category, you may create one new category, but: - Name it concisely (1–2 words). - Make it a general category that could reasonably include similar items, avoiding overly specific or niche categories. + - Generate 5 relevant keywords for the category to include in the response. - Only create a new category if absolutely necessary; prefer mapping items to broader existing categories whenever possible. - If a field is missing or ambiguous, deduce it cautiously from surrounding information. - If there is no receipt in the provided image, return an empty JSON. @@ -67,37 +74,43 @@ def _validate_receipt_response(response_json: str): data = json.loads(response_json) except json.JSONDecodeError: raise ValueError("Response JSON is not valid.") - if "items" not in data or "total" not in data: + if not data: + raise ValueError("The provided image is not a valid receipt image.") + + required_top = {"items", "total"} + if not required_top.issubset(data.keys()): raise ValueError("Response JSON is missing 'items' or 'total' fields.") + extra_top = set(data.keys()) - required_top + if extra_top: + raise ValueError(f"Unexpected top-level fields: {extra_top}") + + required_item_fields = {"name", "quantity", "price", "category", "keywords"} for item in data["items"]: - if "name" not in item: - raise ValueError("Response JSON is missing 'name' field.") - if "quantity" not in item: - raise ValueError("Response JSON is missing 'quantity' field.") - if "price" not in item: - raise ValueError("Response JSON is missing 'price' field.") - if "category" not in item: - raise ValueError("Response JSON is missing 'category' field.") + missing = required_item_fields - set(item.keys()) + if missing: + raise ValueError(f"Item is missing required fields: {missing}") + extra = set(item.keys()) - required_item_fields + if extra: + raise ValueError(f"Item contains unexpected fields: {extra}") -def generate_prompt(categories: List[str]) -> str: +def generate_prompt(categories: dict[str, List[str]]) -> str: if categories: - category_titles = ", ".join([c for c in categories]) - prompt = ( - f"Analyze the receipt image and categorize each purchased item into one of these categories: " - f"{category_titles}." - ) + parts = [] + for category, keywords in categories.items(): + keyword_str = ", ".join(keywords) + parts.append(f"{category} ({keyword_str})") + joined = ", ".join(parts) + return f"Analyze the receipt image and categorize each purchased item into one of these categories: {joined}" else: - prompt = "Analyze the receipt image and categorize each purchased item into one category" - return prompt + return "Analyze the receipt image and categorize each purchased item into one category" -def process_receipt_photo(image_bytes: bytes, categories: List[str], max_retries=3, delay=2): +def process_receipt_photo(image_bytes: bytes, categories: dict[str, List[str]], max_retries=3, delay=2): _validate_image(image_bytes) prompt = generate_prompt(categories) for attempt in range(1, max_retries + 1): - print(f"Attempt {attempt}...") response = client.models.generate_content( model="gemini-2.5-flash", contents=[ @@ -110,16 +123,13 @@ def process_receipt_photo(image_bytes: bytes, categories: List[str], max_retries config=SYSTEM_CONFIG, ) response_json = extract_json_from_response(response.text) - if response_json == "{}": - raise ValueError("The provided image is not a valid receipt image.") try: _validate_receipt_response(response_json) return response_json except ValueError as e: - print(f"Validation failed: {e}") - if attempt < max_retries: - print(f"Retrying in {delay} seconds...") - time.sleep(delay) - else: - raise RuntimeError(f"Failed after {max_retries} attempts: {e}") - + error_msg = str(e) + if "not a valid receipt image" in error_msg: + raise ValueError(error_msg) + if attempt >= max_retries: + raise RuntimeError(f"Failed after {max_retries} attempts: {error_msg}") + time.sleep(delay)