diff --git a/.dockerignore b/.dockerignore index e550c7d..de1cb9e 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,3 +7,14 @@ __pycache__/ .idea/ .vscode/ *.log + +hf_cache +.venv +.git +__pycache__ +**/__pycache__ +**/.pytest_cache +**/.mypy_cache +infrastructure/terraform/.terraform +*.tfstate +*.tfstate.* \ No newline at end of file diff --git a/.gitignore b/.gitignore index 625f901..756b158 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ service-account.json get-pip.py .DS_Store hf_cache/ +infrastructure/deploy.ps1
diff --git a/app/api/endpoints/media_endpoints.py b/app/api/endpoints/media_endpoints.py
new file mode 100644
index 0000000..cdce18a
--- /dev/null
+++ b/app/api/endpoints/media_endpoints.py
@@ -0,0 +1,19 @@
+from fastapi import APIRouter, Depends, File, UploadFile
+
+from app.api.dependencies import get_current_user
+from app.models.domain.user import User
+from app.services.openfake_service import verify_media_with_openfake
+
+router = APIRouter(prefix="/media")
+
+
+@router.post("/verify")
+async def verify_media(
+    file: UploadFile = File(...),
+    # Auth gate only: the resolved user is not read below, but declaring the
+    # dependency keeps this endpoint limited to logged-in users.
+    current_user: User = Depends(get_current_user),
+):
+    """Send the uploaded media file to the OpenFake service and return its result."""
+    detection = await verify_media_with_openfake(file)
+    return detection
diff --git a/app/api/router.py b/app/api/router.py index 4774254..e5fdf84 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -1,18 +1,20 @@ from fastapi import APIRouter + from app.api.endpoints import ( - user_endpoints, - claim_endpoints, analysis_endpoints, - source_endpoints, - search_endpoints, - feedback_endpoints, + claim_conversation_endpoints, + claim_endpoints, conversation_endpoints, - message_endpoints, + discussion_endpoints, domain_endpoints, + feedback_endpoints, health_endpoints, - claim_conversation_endpoints, - discussion_endpoints, + media_endpoints, + message_endpoints, post_endpoints, + search_endpoints, + source_endpoints, + user_endpoints, ) router = APIRouter() @@ -30,3 +32,4 @@ 
router.include_router(post_endpoints.router, tags=["posts"]) router.include_router(claim_conversation_endpoints.router, tags=["claim-conversations"]) router.include_router(health_endpoints.router, tags=["health"]) +router.include_router(media_endpoints.router, tags=["media"]) diff --git a/app/core/config.py b/app/core/config.py index 38a0d7c..b51ac98 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,8 +1,9 @@ +import logging import os +from functools import lru_cache from typing import Optional + from pydantic_settings import BaseSettings # type: ignore -from functools import lru_cache -import logging logger = logging.getLogger(__name__) @@ -40,6 +41,10 @@ class Settings(BaseSettings): DEBUG: bool = False + # For OpenFake + OPENFAKE_API_URL: str = "https://complexdatalab-openfakedemo.hf.space/api/predict" + MEDIA_VERIFICATION_ENABLED: bool = True + def __init__(self, **kwargs): super().__init__(**kwargs) if self.DEBUG: diff --git a/app/core/llm/prompts.py b/app/core/llm/prompts.py index 9b62df1..d437b38 100644 --- a/app/core/llm/prompts.py +++ b/app/core/llm/prompts.py @@ -45,6 +45,8 @@ class AnalysisPrompt: "2. Ensure all special characters in the analysis text are properly escaped\n" "3. The analysis field should be a single line with newlines represented as \\n\n" "4. Do not include any control characters\n" + Before responding, check that your answer is valid JSON and matches this exact structure: + {"veracity_score":NUMBER,"analysis":"TEXT"} """ @@ -62,6 +64,8 @@ class AnalysisPrompt: "2. Ensure all special characters in the analysis text are properly escaped\n" "3. The analysis field should be a single line with newlines represented as \\n\n" "4. Do not include any control characters\n" + Before responding, check that your answer is valid JSON and matches this exact structure: + {"veracity_score":NUMBER,"analysis":"TEXT"} """ @@ -82,7 +86,8 @@ class AnalysisPrompt: "2. 
Assurez-vous que tous les caractères spéciaux dans le texte d'analyse sont correctement retranscrits\n" "3. Le champ "analysis" doit être une seule ligne avec des nouvelles lignes représentées par \\n\n" "4. N'ajoutez aucune entité de caractère\n" - + Avant de répondre, vérifie que ta réponse est un JSON valide et qu’elle respecte exactement cette structure : + {"veracity_score":NUMBER,"analysis":"TEXT"} """ GET_CONFIDENCE = """
diff --git a/app/core/text_safety.py b/app/core/text_safety.py
new file mode 100644
index 0000000..0747a48
--- /dev/null
+++ b/app/core/text_safety.py
@@ -0,0 +1,110 @@
+import json
+import logging
+import re
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+def clean_unicode_text(value: Any) -> str:
+    """
+    Make text safe for JSON responses and PostgreSQL text fields.
+
+    ``None`` becomes ``""`` and non-strings go through ``str()``. NUL and SUB
+    characters are removed, and anything that cannot be encoded as UTF-8
+    (e.g. a lone surrogate) is replaced with ``?``.
+    """
+    if value is None:
+        return ""
+
+    text = value if isinstance(value, str) else str(value)
+
+    text = text.replace("\x00", "")
+    text = text.replace("\x1a", "")
+
+    return text.encode("utf-8", errors="replace").decode("utf-8", errors="replace")
+
+
+def extract_json_candidate(raw_text: str) -> str:
+    """
+    Return the part of ``raw_text`` most likely to be a JSON object.
+
+    A leading/trailing Markdown code fence is stripped, then the span from the
+    first ``{`` to the last ``}`` is returned. Without such a span the cleaned
+    text is returned unchanged.
+    """
+    text = clean_unicode_text(raw_text).strip()
+
+    text = re.sub(r"^```(?:json)?\s*", "", text)
+    text = re.sub(r"\s*```$", "", text)
+
+    start = text.find("{")
+    end = text.rfind("}")
+
+    if start != -1 and end > start:
+        return text[start : end + 1]
+
+    return text
+
+
+def parse_analysis_response(raw_text: str) -> tuple[int, str]:
+    """
+    Parse an LLM analysis reply into ``(veracity_score, analysis_text)``.
+
+    Strict JSON parsing is tried first. If the reply is not a JSON object, or
+    its score cannot be converted to an integer, both values are recovered
+    with regular expressions instead. The score is clamped to 0..100.
+    """
+    candidate = extract_json_candidate(raw_text)
+
+    try:
+        data = json.loads(candidate)
+        if not isinstance(data, dict):
+            # A bare string, number or list is valid JSON but has no fields.
+            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
+        # int() raises OverflowError for Infinity, which json.loads accepts.
+        score = int(float(data.get("veracity_score", 0)))
+        analysis = clean_unicode_text(data.get("analysis", "No analysis provided"))
+        return max(0, min(100, score)), analysis
+
+    except (json.JSONDecodeError, TypeError, ValueError, OverflowError) as e:
+        logger.warning(
+            "Malformed analysis JSON from LLM. Falling back to regex extraction. Error=%s Raw=%r",
+            e,
+            str(raw_text)[:2000],
+        )
+
+    score_match = re.search(
+        r'"veracity_score"\s*:\s*([0-9]+(?:\.\d+)?)',
+        candidate,
+    )
+    # Cap as a float first so an absurdly long digit run (inf) cannot overflow int().
+    score = int(min(100.0, float(score_match.group(1)))) if score_match else 0
+
+    analysis_match = re.search(
+        r'"analysis"\s*:\s*(.*)$',
+        candidate,
+        flags=re.DOTALL,
+    )
+
+    if analysis_match:
+        analysis = analysis_match.group(1).strip()
+        analysis = analysis.rstrip("}").rstrip(",").strip()
+
+        if analysis.startswith('"'):
+            analysis = analysis[1:]
+
+        if analysis.endswith('"'):
+            analysis = analysis[:-1]
+
+        analysis = analysis.replace('\\"', '"')
+        analysis = analysis.replace("\\n", "\n")
+    else:
+        analysis = candidate
+
+    analysis = clean_unicode_text(analysis)
+
+    if not analysis:
+        analysis = "Analysis was generated, but the response format was malformed."
+
+    return score, analysis
diff --git a/app/services/analysis_orchestrator.py b/app/services/analysis_orchestrator.py index 4d4bd8c..2b8c91c 100644 --- a/app/services/analysis_orchestrator.py +++ b/app/services/analysis_orchestrator.py @@ -1,33 +1,33 @@ -import logging -from typing import AsyncGenerator, Dict, Any, List, Optional, NamedTuple -from uuid import UUID, uuid4 -from datetime import UTC, datetime import json +import logging +import math import re from copy import deepcopy -import math +from datetime import UTC, datetime +from typing import Any, AsyncGenerator, Dict, List, NamedTuple, Optional +from uuid import UUID, uuid4 from app.core.exceptions import NotAuthorizedException, NotFoundException, ValidationError from app.core.llm.interfaces import LLMProvider +from app.core.llm.messages import Message as LLMMessage +from app.core.llm.prompts import AnalysisPrompt +from app.core.text_safety import clean_unicode_text, parse_analysis_response from app.models.database.models import AnalysisStatus, ClaimStatus, ConversationStatus, MessageSenderType -from app.models.domain.claim import Claim from app.models.domain.analysis import Analysis, 
LogProbsData -from app.models.domain.search import Search -from app.models.domain.message import Message -from app.core.llm.messages import Message as LLMMessage -from app.models.domain.conversation import Conversation +from app.models.domain.claim import Claim from app.models.domain.claim_conversation import ClaimConversation +from app.models.domain.conversation import Conversation +from app.models.domain.message import Message +from app.models.domain.search import Search +from app.repositories.implementations.analysis_repository import AnalysisRepository from app.repositories.implementations.claim_conversation_repository import ClaimConversationRepository from app.repositories.implementations.claim_repository import ClaimRepository -from app.repositories.implementations.analysis_repository import AnalysisRepository -from app.repositories.implementations.message_repository import MessageRepository from app.repositories.implementations.conversation_repository import ConversationRepository -from app.repositories.implementations.source_repository import SourceRepository +from app.repositories.implementations.message_repository import MessageRepository from app.repositories.implementations.search_repository import SearchRepository +from app.repositories.implementations.source_repository import SourceRepository from app.services.interfaces.web_search_service import WebSearchServiceInterface -from app.core.llm.prompts import AnalysisPrompt - logger = logging.getLogger(__name__) console_handler = logging.StreamHandler() @@ -218,38 +218,42 @@ async def _generate_analysis( # logger.warning(f"length {len(log_probs)}, {log_probs}") try: + # OLD VERSION KEPT FOR REFERENCE: # Clean the text before parsing # fmt: off - cleaned_text = ( - full_text.strip() - .replace("\r", "") # Remove carriage returns - .replace("\x00", "") # Remove null bytes - .replace("\x1a", "") # Remove SUB characters - .replace("\n", "") - .replace("\\\'", "'") - .replace("\t", "") - ) + # cleaned_text = 
( + # full_text.strip() + # .replace("\r", "") # Remove carriage returns + # .replace("\x00", "") # Remove null bytes + # .replace("\x1a", "") # Remove SUB characters + # .replace("\n", "") + # .replace("\\\'", "'") + # .replace("\t", "") + # ) # fmt: on # Try to find the JSON object if there's additional text - try: - start_idx = cleaned_text.find("{") - end_idx = cleaned_text.rindex("}") + 1 - if start_idx != -1 and end_idx != -1: - cleaned_text = cleaned_text[start_idx:end_idx] - except ValueError: - pass + # try: + # start_idx = cleaned_text.find("{") + # end_idx = cleaned_text.rindex("}") + 1 + # if start_idx != -1 and end_idx != -1: + # cleaned_text = cleaned_text[start_idx:end_idx] + # except ValueError: + # pass - response_data = json.loads(cleaned_text) + # response_data = json.loads(cleaned_text) - logger.debug(response_data) + # logger.debug(response_data) - veracity_score = int(response_data.get("veracity_score", 0)) - analysis_content = str(response_data.get("analysis", "No analysis provided")) + # veracity_score = int(response_data.get("veracity_score", 0)) + # analysis_content = str(response_data.get("analysis", "No analysis provided")) - veracity_score = max(0, min(100, veracity_score)) + # veracity_score = max(0, min(100, veracity_score)) + veracity_score, analysis_content = parse_analysis_response(full_text) current_analysis.veracity_score = float(veracity_score) / 100 - current_analysis.analysis_text = analysis_content + # current_analysis.analysis_text = analysis_content + + current_analysis.analysis_text = clean_unicode_text(analysis_content) current_analysis.status = AnalysisStatus.completed.value current_analysis.updated_at = datetime.now(UTC) @@ -281,8 +285,8 @@ async def _generate_analysis( import re try: - veracity_match = re.search(r'"veracity_score"\s*:\s*([0-9]+(?:\.\d+)?)', cleaned_text) - analysis_match = re.search(r'"analysis"\s*:\s*"((?:[^"\\]|\\.)*)"', cleaned_text, re.DOTALL) + veracity_match = 
re.search(r'"veracity_score"\s*:\s*([0-9]+(?:\.\d+)?)', full_text) + analysis_match = re.search(r'"analysis"\s*:\s*"((?:[^"\\]|\\.)*)"', full_text, re.DOTALL) logger.info("Successfully found regex matches") if veracity_match and analysis_match: veracity_score = int(veracity_match.group(1)) @@ -314,17 +318,38 @@ async def _generate_analysis( yield {"type": "error", "content": f"Error parsing analysis response: {str(e)}"} raise + # except Exception as e: + # logger.error(f"Error processing analysis: {str(e)}") + # current_analysis.status = AnalysisStatus.failed.value + # await self._analysis_repo.update(current_analysis) + # yield {"type": "error", "content": f"Error creating analysis: {str(e)}"} + # raise + except Exception as e: - logger.error(f"Error processing analysis: {str(e)}") + logger.error( + "Error processing analysis response: %s\nFull text: %r", + str(e), + full_text[:2000], + exc_info=True, + ) + current_analysis.status = AnalysisStatus.failed.value + current_analysis.analysis_text = clean_unicode_text(full_text) await self._analysis_repo.update(current_analysis) - yield {"type": "error", "content": f"Error creating analysis: {str(e)}"} - raise + + yield { + "type": "error", + "content": "Sorry, we couldn’t complete the analysis. Please try again.", + } + return except Exception as e: logger.error(f"Error in _generate_analysis: {str(e)}", exc_info=True) - yield {"type": "error", "content": str(e)} - raise + yield { + "type": "error", + "content": "Sorry, we couldn’t complete the analysis. 
Please try again.", + } + return async def initialize_claim_conversation( self,
diff --git a/app/services/openfake_service.py b/app/services/openfake_service.py
new file mode 100644
index 0000000..079f45c
--- /dev/null
+++ b/app/services/openfake_service.py
@@ -0,0 +1,106 @@
+import math
+import os
+
+import httpx
+from fastapi import HTTPException, UploadFile
+
+OPENFAKE_API_URL = os.getenv(
+    "OPENFAKE_API_URL",
+    "https://complexdatalab-openfakedemo.hf.space/api/predict",
+)
+
+ALLOWED_TYPES = {
+    "image/jpeg",
+    "image/png",
+    "image/webp",
+    "image/gif",
+    "video/mp4",
+    "video/quicktime",
+    "video/webm",
+    "video/x-matroska",
+}
+
+
+async def verify_media_with_openfake(file: UploadFile) -> dict:
+    """
+    Forward an uploaded image or video to the OpenFake detector and summarize its reply.
+
+    Raises:
+        HTTPException: 400 when the upload's content type is not in ``ALLOWED_TYPES``;
+            502 when the detector cannot be reached, answers with a non-200 status,
+            or returns a body without a usable ``p_fake`` probability.
+    """
+    if file.content_type not in ALLOWED_TYPES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported file type: {file.content_type}",
+        )
+
+    await file.seek(0)
+
+    # This forwards the uploaded file object to OpenFake
+    files = {
+        "file": (
+            file.filename,
+            file.file,
+            file.content_type,
+        )
+    }
+
+    try:
+        async with httpx.AsyncClient(timeout=120) as client:
+            response = await client.post(
+                OPENFAKE_API_URL,
+                files=files,
+            )
+    except httpx.HTTPError as exc:
+        # Timeouts and connection failures would otherwise surface as a bare 500.
+        raise HTTPException(
+            status_code=502,
+            detail="OpenFake detector unreachable",
+        ) from exc
+
+    if response.status_code != 200:
+        raise HTTPException(
+            status_code=502,
+            detail="OpenFake detector failed",
+        )
+
+    try:
+        data = response.json()
+        # No default for p_fake: a reply without a score must not read as "0% fake".
+        p_fake = float(data["p_fake"])
+        reliability = float(data.get("reliability", 1 - p_fake))
+    except (KeyError, TypeError, ValueError) as exc:
+        raise HTTPException(
+            status_code=502,
+            detail="OpenFake detector returned an invalid response",
+        ) from exc
+
+    # p_fake is used as a probability below; NaN fails this range check too.
+    if not (0.0 <= p_fake <= 1.0) or not math.isfinite(reliability):
+        raise HTTPException(
+            status_code=502,
+            detail="OpenFake detector returned an invalid response",
+        )
+
+    if p_fake >= 0.75:
+        verdict = "Likely fake"
+    elif p_fake >= 0.45:
+        verdict = "Uncertain"
+    else:
+        verdict = "Likely real"
+
+    return {
+        "media_type": data.get("media_type"),
+        "p_fake": p_fake,
+        "reliability": reliability,
+        "reliability_score": round(reliability * 100),
+        "verdict": verdict,
+        "n_frames": data.get("n_frames"),
+        "frame_probs": data.get("frame_probs"),
+        "explanation": (
+            f"The detector estimates a {round(p_fake * 100)}% probability "
+            f"that this media is fake."
+        ),
+    }
diff --git a/infrastructure/terraform/.terraform.lock.hcl b/infrastructure/terraform/.terraform.lock.hcl index 
537dde3..589b02e 100644 --- a/infrastructure/terraform/.terraform.lock.hcl +++ b/infrastructure/terraform/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/google" { constraints = "~> 4.0" hashes = [ "h1:ZVDZuhYSIWhCkSuDkwFeSIJjn0/DcCxak2W/cHW4OQQ=", + "h1:sld/eTvevl/Af3upWX1TesnLLCCUMBQlczxo5lPzA48=", "zh:17d60a6a6c1741cf1e09ac6731433a30950285eac88236e623ab4cbf23832ca3", "zh:1c70254c016439dbb75cab646b4beace6ceeff117c75d81f2cc27d41c312f752", "zh:35e2aa2cc7ac84ce55e05bb4de7b461b169d3582e56d3262e249ff09d64fe008", @@ -26,6 +27,7 @@ provider "registry.terraform.io/hashicorp/kubernetes" { constraints = "~> 2.0" hashes = [ "h1:3j4XBR5UWQA7xXaiEnzZp0bHbcwOhWetHYKTWIrUTI0=", + "h1:yfV3jmsFAvnByjJEsL4DjjQmGViS+MMcBeZWnDH3mPo=", "zh:0e715d7fb13a8ad569a5fdc937b488590633f6942e986196fdb17cd7b8f7720e", "zh:495fc23acfe508ed981e60af9a3758218b0967993065e10a297fdbc210874974", "zh:4b930a8619910ef528bc90dae739cb4236b9b76ce41367281e3bc3cf586101c7",