From 13ea42939b0b59288e80fa1aaacd1e6c13787fe7 Mon Sep 17 00:00:00 2001 From: aliamerj Date: Thu, 19 Feb 2026 22:16:50 +0300 Subject: [PATCH] fix rag engine --- rag-engine/src/common/utils.py | 66 +++- .../chunking_embedding/chunk_document.py | 242 ++++++------ .../layers/chunking_embedding/embedding.py | 36 +- .../src/layers/chunking_embedding/models.py | 6 +- .../layers/data_extractor/extractor/pdf.py | 48 ++- .../src/layers/data_extractor/models.py | 18 +- rag-engine/src/layers/qdrant_store/store.py | 46 ++- .../layers/structure_analyzer/analyzer/pdf.py | 64 +-- .../src/layers/structure_analyzer/models.py | 10 +- rag-engine/src/main.py | 2 + rag-engine/src/query/__init__.py | 0 rag-engine/src/query/controller.py | 35 ++ rag-engine/src/query/model.py | 24 ++ rag-engine/src/query/service.py | 365 ++++++++++++++++++ rag-engine/src/store/controllers/pdf.py | 4 +- rag-engine/src/store/routers.py | 6 +- rag-engine/src/store/services/pdf.py | 5 +- 17 files changed, 748 insertions(+), 229 deletions(-) create mode 100644 rag-engine/src/query/__init__.py create mode 100644 rag-engine/src/query/controller.py create mode 100644 rag-engine/src/query/model.py create mode 100644 rag-engine/src/query/service.py diff --git a/rag-engine/src/common/utils.py b/rag-engine/src/common/utils.py index e5f5130..3228733 100644 --- a/rag-engine/src/common/utils.py +++ b/rag-engine/src/common/utils.py @@ -1,9 +1,20 @@ import os from pathlib import Path as FilePath -from fastembed import TextEmbedding +from fastembed import SparseTextEmbedding, TextEmbedding from fastembed.common.model_description import ModelSource, PoolingType from qdrant_client import QdrantClient, models -from qdrant_client.models import Distance, FieldCondition, MatchValue, VectorParams +from qdrant_client.conversions.common_types import SparseVectorParams +from fastembed.rerank.cross_encoder import TextCrossEncoder +import json +from fastapi import HTTPException, status +from qdrant_client.models import Optional +from 
qdrant_client.models import ( + Distance, + FieldCondition, + MatchValue, + VectorParams, +) + CACHE_DIR = FilePath("./models_cache") CACHE_DIR.mkdir(exist_ok=True) @@ -18,10 +29,18 @@ dim=VECTOR_SIZE, model_file="onnx/model.onnx", ) -embedding_model = TextEmbedding( +dense_embedding = TextEmbedding( model_name="intfloat/multilingual-e5-small", cache_dir=str(CACHE_DIR), ) +sparse_embedding = SparseTextEmbedding( + model_name="prithivida/Splade_PP_en_v1", + cache_dir=str(CACHE_DIR), +) +reranker = TextCrossEncoder( + model_name="Xenova/ms-marco-MiniLM-L-12-v2", + cache_dir=str(CACHE_DIR), +) qclient = QdrantClient( url=os.getenv("QDRANT_DB_URL"), @@ -31,20 +50,25 @@ if COLLECTION_NAME not in [c.name for c in qclient.get_collections().collections]: qclient.create_collection( collection_name=COLLECTION_NAME, - vectors_config=VectorParams( - size=VECTOR_SIZE, - distance=Distance.COSINE, - ), + vectors_config={ + "text-dense": VectorParams( + size=VECTOR_SIZE, + distance=Distance.COSINE, + ), + }, + sparse_vectors_config={ + "text-sparse": SparseVectorParams(index=models.SparseIndexParams()), + }, ) qclient.create_payload_index( - collection_name=COLLECTION_NAME, - field_name="_file_hash", - field_schema=models.PayloadSchemaType.KEYWORD, + collection_name=COLLECTION_NAME, + field_name="_file_hash", + field_schema=models.PayloadSchemaType.KEYWORD, ) qclient.create_payload_index( - collection_name=COLLECTION_NAME, - field_name="_user_id", - field_schema=models.PayloadSchemaType.KEYWORD, + collection_name=COLLECTION_NAME, + field_name="_user_id", + field_schema=models.PayloadSchemaType.KEYWORD, ) @@ -67,3 +91,19 @@ def document_exists(user_id: str, file_hash: str) -> bool: ) return len(results[0]) > 0 + + + + + +def parse_metadata(metadata_str: Optional[str]) -> dict: + """Parse JSON metadata string, return empty dict if None.""" + if metadata_str: + try: + return json.loads(metadata_str) + except json.JSONDecodeError: + raise HTTPException( + 
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Invalid JSON in metadata field", + ) + return {} diff --git a/rag-engine/src/layers/chunking_embedding/chunk_document.py b/rag-engine/src/layers/chunking_embedding/chunk_document.py index ce0c875..58cb63a 100644 --- a/rag-engine/src/layers/chunking_embedding/chunk_document.py +++ b/rag-engine/src/layers/chunking_embedding/chunk_document.py @@ -1,21 +1,28 @@ -import json from typing import List import uuid import tiktoken from src.layers.chunking_embedding.models import Chunk -from src.layers.structure_analyzer.models import Section, StructuredDocument +from src.layers.data_extractor.models import ImagePage, TablePage +from src.layers.structure_analyzer.models import Paragraph, Section, StructuredDocument _encoder = tiktoken.get_encoding("cl100k_base") +_token_cache = {} + def count_tokens(text: str) -> int: - return len(_encoder.encode(text)) + if text in _token_cache: + return _token_cache[text] + + val = len(_encoder.encode(text)) + _token_cache[text] = val + return val def chunk_document( structured_document: StructuredDocument, metadata: dict, - max_tokens: int = 400, + max_tokens: int = 450, min_tokens: int = 80, ) -> List[Chunk]: @@ -32,7 +39,6 @@ def chunk_document( section_path=["Preamble"], level=0, max_tokens=max_tokens, - min_tokens=min_tokens, metadata=metadata, ) ) @@ -50,9 +56,8 @@ def chunk_document( ) # ---- FINAL CLEANUP ---- - chunks = _deduplicate_chunks(chunks) - if len(chunks) == 0: - raise ValueError("No text found in your pdf!, make sure it is not image") + chunks = _merge_small_chunks(chunks, min_tokens, max_tokens) + chunks = _deduplicate_chunks_atttach_index(chunks) return chunks @@ -63,7 +68,6 @@ def _chunk_paragraphs( section_path: List[str], level: int, max_tokens: int, - min_tokens: int, metadata: dict, ) -> List[Chunk]: @@ -121,74 +125,95 @@ def _chunk_paragraphs( ) ) - return _merge_small_chunks(chunks, metadata, min_tokens, max_tokens) + return chunks -def _process_section( 
- section: Section, - parent_path: List[str], - max_tokens: int, - min_tokens: int, - metadata: dict, -) -> List[Chunk]: +def _process_section(section, parent_path, max_tokens, min_tokens, metadata): path = parent_path + [section.title] - chunks: List[Chunk] = [] + chunks = [] - # ---- TEXT CHUNKS ---- - if section.paragraphs: - chunks.extend( - _chunk_paragraphs( - paragraphs=section.paragraphs, - section_title=section.title, - section_path=path, - level=section.level, - max_tokens=max_tokens, - min_tokens=min_tokens, - metadata=metadata, - ) - ) + paragraph_buffer = [] - # ---- TABLE CHUNKS ---- - if section.tables: - chunks.extend( - _build_table_chunks_from_section( - section=section, - section_path=path, - metadata=metadata, + def flush_paragraph_buffer(): + nonlocal paragraph_buffer, chunks + + if paragraph_buffer: + chunks.extend( + _chunk_paragraphs( + paragraph_buffer, + section.title, + path, + section.level, + max_tokens, + metadata, + ) ) - ) + paragraph_buffer = [] - # ---- CHILD SECTIONS ---- - for child in section.children: - chunks.extend( - _process_section( - child, - parent_path=path, - max_tokens=max_tokens, - min_tokens=min_tokens, - metadata=metadata, + for item in section.content_stream: + if isinstance(item, Paragraph): + paragraph_buffer.append(item) + + elif isinstance(item, TablePage): + flush_paragraph_buffer() + chunks.extend(_build_table_chunk(item, section, path, metadata)) + + elif isinstance(item, Section): + flush_paragraph_buffer() + chunks.extend( + _process_section(item, path, max_tokens, min_tokens, metadata) ) + + elif isinstance(item, ImagePage): + flush_paragraph_buffer() + continue + + return chunks + + +def _build_table_chunk(table, section, section_path, metadata): + + table_metadata = metadata.copy() + table_metadata["_content_type"] = "table" + + headers = table.data[0] + rows = table.data[1:] + + max_rows_per_chunk = 20 + + chunks = [] + + for i in range(0, len(rows), max_rows_per_chunk): + group = rows[i : i + 
max_rows_per_chunk] + + # Natural language version + text_lines = [] + for row in group: + row_text = ", ".join(f"{headers[j]}: {row[j]}" for j in range(len(headers))) + text_lines.append(row_text) + + text = "Table:\n" + "\n".join(text_lines) + + chunk = Chunk( + id=str(uuid.uuid4()), + text=text, + token_count=count_tokens(text), + section_title=section.title, + section_path=section_path, + level=section.level, + page_start=table.page_number, + page_end=table.page_number, + metadata={ + **table_metadata, + "_table_headers": headers, + "_row_start": i, + "_row_end": i + len(group) - 1, + "_table_json": group, + }, ) - if ( - not section.paragraphs - and not section.tables - and not section.children - and section.title.strip() - ): - if not _is_pure_category_title(section.title): - chunks.append( - _build_chunk( - text=section.title.strip(), - section_title=section.title, - section_path=path, - level=section.level, - page_start=section.page_number, - page_end=section.page_number, - metadata=metadata, - ) - ) + chunks.append(chunk) return chunks @@ -213,13 +238,11 @@ def _build_chunk( page_start=page_start, page_end=page_end, metadata=metadata, - embedding=None, ) def _merge_small_chunks( chunks: List[Chunk], - metadata: dict, min_tokens: int, max_tokens: int, ) -> List[Chunk]: @@ -227,39 +250,43 @@ def _merge_small_chunks( if not chunks: return [] - merged = [] - buffer = chunks[0] + merged: List[Chunk] = [] - for chunk in chunks[1:]: - # If buffer too small, try merging - if buffer.token_count < min_tokens: - combined_text = buffer.text + "\n" + chunk.text + for chunk in chunks: + if not merged: + merged.append(chunk) + continue + + prev = merged[-1] + + # Merge if either side is small + if prev.token_count < min_tokens or chunk.token_count < min_tokens: + combined_text = prev.text + "\n" + chunk.text combined_tokens = count_tokens(combined_text) - # Only merge if we stay under max_tokens if combined_tokens <= max_tokens: - buffer = _build_chunk( + merged[-1] = 
_build_chunk( combined_text, - buffer.section_title, - buffer.section_path, - buffer.level, - buffer.page_start, + prev.section_title, + # combine section paths + prev.section_path + + [p for p in chunk.section_path if p not in prev.section_path], + prev.level, + prev.page_start, chunk.page_end, - metadata, + prev.metadata, ) continue - # Otherwise flush buffer - merged.append(buffer) - buffer = chunk + merged.append(chunk) - merged.append(buffer) return merged -def _deduplicate_chunks(chunks: List[Chunk]) -> List[Chunk]: +def _deduplicate_chunks_atttach_index(chunks: List[Chunk]) -> List[Chunk]: seen = set() unique = [] + index = 0 for chunk in chunks: normalized = chunk.text.strip() @@ -268,47 +295,8 @@ def _deduplicate_chunks(chunks: List[Chunk]) -> List[Chunk]: continue seen.add(normalized) + chunk.chunk_index = index + index += 1 unique.append(chunk) return unique - - -def _build_table_chunks_from_section( - section: Section, - section_path: List[str], - metadata: dict, -) -> List[Chunk]: - - chunks = [] - - for table in section.tables: - table_metadata = metadata.copy() - table_metadata["_content_type"] = "table" - table_json = json.dumps(table, ensure_ascii=False) - - chunks.append( - Chunk( - id=str(uuid.uuid4()), - text=table_json, - token_count=count_tokens(table_json), - section_title=section.title, - section_path=section_path, - level=section.level, - page_start=section.page_number, - page_end=section.page_number, - metadata=table_metadata, - embedding=None, - ) - ) - - return chunks - - -def _is_pure_category_title(title: str) -> bool: - clean = title.strip() - - # If fully uppercase and short → likely category - if clean.isupper() and len(clean.split()) <= 3: - return True - - return False diff --git a/rag-engine/src/layers/chunking_embedding/embedding.py b/rag-engine/src/layers/chunking_embedding/embedding.py index 0eaa73d..f3b1f9e 100644 --- a/rag-engine/src/layers/chunking_embedding/embedding.py +++ 
b/rag-engine/src/layers/chunking_embedding/embedding.py @@ -1,13 +1,39 @@ +from concurrent.futures import ThreadPoolExecutor +import os from typing import List +from qdrant_client.models import models from src.layers.chunking_embedding.models import Chunk -from src.common.utils import embedding_model +from src.common.utils import dense_embedding, sparse_embedding +_executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4) + def embed_chunks(chunks: List[Chunk], batch_size: int = 64) -> List[Chunk]: + for i in range(0, len(chunks), batch_size): + batch = chunks[i : i + batch_size] - texts = [c.text for c in batch] - vectors = list(embedding_model.embed(texts)) - for chunk, vector in zip(batch, vectors): - chunk.embedding = vector.tolist() + texts = [f"passage: {c.text.strip()}" for c in batch] + + def dense_task(): + return list(dense_embedding.embed(texts)) + + def sparse_task(): + return list(sparse_embedding.embed(texts)) + + future_dense = _executor.submit(dense_task) + future_sparse = _executor.submit(sparse_task) + + dense_vectors = future_dense.result() + sparse_vectors = future_sparse.result() + + for chunk, dv, sv in zip(batch, dense_vectors, sparse_vectors): + + chunk.dense_vectors = dv.tolist() + + chunk.sparse_vectors = models.SparseVector( + indices=sv.indices.tolist(), + values=sv.values.tolist(), + ) + return chunks diff --git a/rag-engine/src/layers/chunking_embedding/models.py b/rag-engine/src/layers/chunking_embedding/models.py index a565925..9e70f61 100644 --- a/rag-engine/src/layers/chunking_embedding/models.py +++ b/rag-engine/src/layers/chunking_embedding/models.py @@ -1,6 +1,8 @@ from pydantic import BaseModel from typing import List, Dict, Any +from qdrant_client.models import models + # ------------------------- # Output Model # ------------------------- @@ -17,6 +19,8 @@ class Chunk(BaseModel): page_start: int | None page_end: int | None - embedding: Any + chunk_index: int = 0 + dense_vectors: List[float] | None = None + 
sparse_vectors: models.SparseVector | None = None metadata: Dict[str, Any] = {} diff --git a/rag-engine/src/layers/data_extractor/extractor/pdf.py b/rag-engine/src/layers/data_extractor/extractor/pdf.py index d417dd6..0b838f4 100644 --- a/rag-engine/src/layers/data_extractor/extractor/pdf.py +++ b/rag-engine/src/layers/data_extractor/extractor/pdf.py @@ -4,7 +4,7 @@ import uuid import pdfplumber -from src.layers.data_extractor.models import ImagePage, Line, Page, Word +from src.layers.data_extractor.models import ImagePage, Line, Page, TablePage, Word # =============================== @@ -28,14 +28,9 @@ def extract_data(pdf_bytes: bytes) -> tuple[list[Page], dict]: metadata["_file_metadata"] = pdf_doc.metadata for page_number, page in enumerate(pdf_doc.pages, start=1): - tables_output = _extract_tables(page) - table_bboxes = [ - _expand_bbox(table.bbox, padding=TABLE_PADDING) - for table in page.find_tables() - ] - + tables_output = _extract_tables(page, page_number) words = _extract_words(page) - words = _filter_table_words(words, table_bboxes) + words = _filter_table_words(words, tables_output) lines_output = _group_words_into_lines(words) @@ -139,6 +134,7 @@ def _group_words_into_lines(words: List[Word]) -> List[Line]: x1 = max(w.x1 for w in cluster) top = min(w.top for w in cluster) + bottom = max(w.bottom for w in cluster) lines_output.append( Line( @@ -149,6 +145,7 @@ def _group_words_into_lines(words: List[Word]) -> List[Line]: is_bold=is_bold, x0=x0, x1=x1, + bottom=bottom, ) ) @@ -158,17 +155,31 @@ def _group_words_into_lines(words: List[Word]) -> List[Line]: return lines_output -def _extract_tables(page): - - tables_output = [] - +def _extract_tables(page, page_number): + tables_output: list[TablePage] = [] tables = page.find_tables() for table in tables: data = table.extract() - if data and any(any(cell for cell in row) for row in data): - tables_output.append(data) + if not data or not any(any(cell for cell in row) for row in data): + continue + + bbox 
= table.bbox + x0, top, x1, bottom = bbox + + tables_output.append( + TablePage( + id=str(uuid.uuid4()), + bbox=bbox, + data=data, + top=top, + x0=x0, + x1=x1, + bottom=bottom, + page_number=page_number, + ) + ) return tables_output @@ -221,15 +232,10 @@ def _fix_merged_words(text: str) -> str: return re.sub(r"([a-z])([A-Z])", r"\1 \2", text) -def _expand_bbox(bbox, padding=1.0): - x0, top, x1, bottom = bbox - return (x0 - padding, top - padding, x1 + padding, bottom + padding) - - -def _filter_table_words(words: list[Word], table_bboxes: list[tuple]) -> list[Word]: +def _filter_table_words(words: list[Word], tables: list[TablePage]) -> list[Word]: filtered = [] for word in words: - if not any(_is_inside_bbox(word, bbox) for bbox in table_bboxes): + if not any(_is_inside_bbox(word, table.bbox) for table in tables): filtered.append(word) return filtered diff --git a/rag-engine/src/layers/data_extractor/models.py b/rag-engine/src/layers/data_extractor/models.py index 040353b..316022a 100644 --- a/rag-engine/src/layers/data_extractor/models.py +++ b/rag-engine/src/layers/data_extractor/models.py @@ -17,8 +17,9 @@ class Line(BaseModel): top: float avg_size: float is_bold: bool - x0: float # new - x1: float # new + x0: float + x1: float + bottom: float class ImagePage(BaseModel): @@ -31,11 +32,22 @@ class ImagePage(BaseModel): height: float | None +class TablePage(BaseModel): + id: str + bbox: tuple[float, float, float, float] + data: list[list[str | None]] + top: float + x0: float + x1: float + bottom: float + page_number: int + + class Page(BaseModel): page_number: int text: str lines: list[Line] - tables: list[list[list[str | None]]] + tables: list[TablePage] images: list[ImagePage] width: float | None height: float | None diff --git a/rag-engine/src/layers/qdrant_store/store.py b/rag-engine/src/layers/qdrant_store/store.py index e1ee82f..d285c77 100644 --- a/rag-engine/src/layers/qdrant_store/store.py +++ b/rag-engine/src/layers/qdrant_store/store.py @@ -1,33 +1,49 
@@ from typing import List -from qdrant_client.conversions.common_types import PointStruct, Points +from qdrant_client.conversions.common_types import PointStruct from src.layers.chunking_embedding.models import Chunk from src.common.utils import VECTOR_SIZE, qclient, COLLECTION_NAME def store_chunks(chunks: List[Chunk], batch_size: int = 64) -> None: + for i in range(0, len(chunks), batch_size): batch = chunks[i : i + batch_size] - points:Points = [] + points: List[PointStruct] = [] + for chunk in batch: - if chunk.embedding is None: + # Correct validation + if chunk.dense_vectors is None or chunk.sparse_vectors is None: continue + payload = { - "text": chunk.text, - "token_count": chunk.token_count, - "section_title": chunk.section_title, - "section_path": chunk.section_path, - "level": chunk.level, - "page_start": chunk.page_start, - "page_end": chunk.page_end, + "_text": chunk.text, + "_chunk_index": chunk.chunk_index, + "_token_count": chunk.token_count, + "_section_title": chunk.section_title, + "_section_path": chunk.section_path, + "_level": chunk.level, + "_page_start": chunk.page_start, + "_page_end": chunk.page_end, **chunk.metadata, } + points.append( PointStruct( id=chunk.id, - vector=chunk.embedding, + vector={ + "text-dense": chunk.dense_vectors, + "text-sparse": chunk.sparse_vectors, + }, payload=payload, - ), + ) + ) + + # Correct validation + assert len(chunk.dense_vectors) == VECTOR_SIZE + + if points: + qclient.upsert( + collection_name=COLLECTION_NAME, + points=points, + wait=False ) - assert isinstance(chunk.embedding, list) - assert len(chunk.embedding) == VECTOR_SIZE - qclient.upsert(collection_name=COLLECTION_NAME, points=points) diff --git a/rag-engine/src/layers/structure_analyzer/analyzer/pdf.py b/rag-engine/src/layers/structure_analyzer/analyzer/pdf.py index 1efc336..6a2db6c 100644 --- a/rag-engine/src/layers/structure_analyzer/analyzer/pdf.py +++ b/rag-engine/src/layers/structure_analyzer/analyzer/pdf.py @@ -2,7 +2,7 @@ import uuid 
from typing import List -from src.layers.data_extractor.models import Line, Page +from src.layers.data_extractor.models import Line, Page, TablePage from src.layers.structure_analyzer.models import Paragraph, Section, StructuredDocument @@ -26,13 +26,16 @@ def analyze_layout(pages: List[Page]) -> StructuredDocument: # ---- detect columns ---- columns = _cluster_columns(page_lines) + for column_lines in columns: - blocks = _build_blocks(column_lines) - - for block in blocks: - if _is_garbage_block(block): + text_blocks = _build_blocks(column_lines) + layout_stream = _merge_layout_blocks(text_blocks, page.tables) + for kind, item in layout_stream: + if kind == "table": + if stack: + stack[-1].content_stream.append(item.table) continue - + block = item heading_level, confidence = _detect_heading(block, font_tiers) # ------------------------------- @@ -45,6 +48,7 @@ def analyze_layout(pages: List[Page]) -> StructuredDocument: level=heading_level, page_number=page.page_number, confidence=confidence, + content_stream=[] ) while stack and stack[-1].level >= heading_level: @@ -52,6 +56,7 @@ def analyze_layout(pages: List[Page]) -> StructuredDocument: if stack: stack[-1].children.append(section) + stack[-1].content_stream.append(section) else: document.sections.append(section) @@ -67,15 +72,10 @@ def analyze_layout(pages: List[Page]) -> StructuredDocument: ) if stack: - stack[-1].paragraphs.append(paragraph) + stack[-1].content_stream.append(paragraph) else: document.preamble.append(paragraph) - # ---- attach assets ---- - if stack: - stack[-1].tables.extend(page.tables) - stack[-1].images.extend(page.images) - return document @@ -199,24 +199,6 @@ def _detect_heading(block, font_tiers): return 0, round(score, 3) -def _is_garbage_block(block): - - text = block.text.strip() - - if not text: - return True - - # pure symbols - if len(text) <= 2 and not text.isalpha(): - return True - - # extremely tiny font - if block.avg_size < 5: - return True - - return False - - def 
_clean_title(text: str) -> str: return re.sub(r"^\d+(\.\d+)*\s*", "", text).strip() @@ -290,3 +272,25 @@ def _looks_like_toc_block(block) -> bool: return True return False + + +class TableBlock: + def __init__(self, table: TablePage): + self.table = table + self.top = table.top + self.x0 = table.x0 + +def _merge_layout_blocks( + text_blocks: List[Block], + tables: List[TablePage] +): + + layout_items = [] + + for b in text_blocks: + layout_items.append(("text", b)) + + for t in tables: + layout_items.append(("table", TableBlock(t))) + + return sorted(layout_items, key=lambda x: (round(x[1].top, 1), x[1].x0)) diff --git a/rag-engine/src/layers/structure_analyzer/models.py b/rag-engine/src/layers/structure_analyzer/models.py index df001ac..83bd0c8 100644 --- a/rag-engine/src/layers/structure_analyzer/models.py +++ b/rag-engine/src/layers/structure_analyzer/models.py @@ -1,7 +1,7 @@ +from __future__ import annotations from pydantic import BaseModel, Field from typing import List - -from src.layers.data_extractor.models import ImagePage +from src.layers.data_extractor.models import ImagePage, TablePage class Paragraph(BaseModel): @@ -14,12 +14,8 @@ class Section(BaseModel): title: str level: int page_number: int - - paragraphs: List[Paragraph] = Field(default_factory=list) + content_stream: List[Paragraph | TablePage | ImagePage | Section] children: List["Section"] = Field(default_factory=list) - - tables: list[list[list[str | None]]] = Field(default_factory=list) - images: List[ImagePage] = Field(default_factory=list) confidence: float diff --git a/rag-engine/src/main.py b/rag-engine/src/main.py index 7bf0a21..f91a71a 100644 --- a/rag-engine/src/main.py +++ b/rag-engine/src/main.py @@ -1,6 +1,7 @@ from dotenv import load_dotenv from fastapi import FastAPI from src.store.routers import store_upload_router, store_url_router +from src.query.controller import query_router from .logging import configure_logging, LogLevels from pathlib import Path @@ -12,3 +13,4 @@ app = 
FastAPI() app.include_router(store_upload_router) app.include_router(store_url_router) +app.include_router(query_router) diff --git a/rag-engine/src/query/__init__.py b/rag-engine/src/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rag-engine/src/query/controller.py b/rag-engine/src/query/controller.py new file mode 100644 index 0000000..32113e6 --- /dev/null +++ b/rag-engine/src/query/controller.py @@ -0,0 +1,35 @@ +from typing import List +from fastapi import APIRouter, Form, HTTPException, status +from src.query.model import QueryResponse +from qdrant_client.models import Optional + +import logging +from src.common.utils import parse_metadata +from src.query.service import query + +query_router = APIRouter(tags=["Query"]) + + +@query_router.post( + "/query", + summary="query chunk", + response_model=List[QueryResponse], + status_code=status.HTTP_200_OK, +) +def chunk_query( + queries: list[str] = Form(..., description="query list"), + metadata: Optional[str] = Form(..., description="Metadata for chunks (JSON)"), + top_result: int | None = Form(None, description="top results"), +): + meta = parse_metadata(metadata) + logging.info(f"parse Metadata: {len(meta)}") + user_id = meta.get("_user_id") + if user_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing '_user_id' in metadata", + ) + final_top_k = top_result if top_result else 10 + chunks = query(queries, meta, final_top_k) + logging.info(f"Query sucscesfully :{len(chunks)}") + return chunks diff --git a/rag-engine/src/query/model.py b/rag-engine/src/query/model.py new file mode 100644 index 0000000..aeebfd1 --- /dev/null +++ b/rag-engine/src/query/model.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel +from qdrant_client.models import Any, Dict + + +class Hit(BaseModel): + id: str + score: float + payload: Dict[str, Any] + + +class Reference(BaseModel): + file: str + section: str + pages: list[int] + + +class QueryResponse(BaseModel): + 
text: str + score: float + token_count: int + content_type: str + metadata: Dict + table_json: list[list[str | None]] + reference: Reference diff --git a/rag-engine/src/query/service.py b/rag-engine/src/query/service.py new file mode 100644 index 0000000..49106d0 --- /dev/null +++ b/rag-engine/src/query/service.py @@ -0,0 +1,365 @@ +import os +import numpy as np +import math +from typing import List, Dict, Optional, Tuple +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from functools import lru_cache +import logging + +from src.query.model import QueryResponse, Reference +from qdrant_client.http import models +from qdrant_client.models import ( + Filter, + FieldCondition, + MatchValue, + Condition, +) + +from src.common.utils import ( + COLLECTION_NAME, + dense_embedding, + sparse_embedding, + reranker, + qclient, +) +from src.query.model import Hit + + +_executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4) + +_CANDIDATE_POOL = 50 +_RERANK_BATCH_SIZE = 64 + +_MAX_CONTEXT_TOKENS = 1400 + +_MIN_RERANK_SCORE = 0.15 +_RELATIVE_RERANK_DROP = 0.30 +_NEIGHBOR_WINDOW = 2 + + +def query( + queries: List[str], + metadata: Optional[Dict], + final_top_k: int, +) -> List[QueryResponse]: + + expanded_queries = _expand_queries(queries) + logging.info(f"expanded queries and clean up to {len(expanded_queries)}") + + meta_filter = _build_metadata_filter(metadata) + logging.info("Build filtering..") + + query_embeddings = _embed_queries(expanded_queries) + logging.info(f"query embeded : {len(query_embeddings)}") + + hits = _hybrid_query(query_embeddings, meta_filter, _CANDIDATE_POOL) + logging.info(f"hybrid query first hits : {len(hits)}") + + hits = _normalize_scores(hits) + logging.info(f"hits after normalized scores : {len(hits)}") + + hits = _rerank(expanded_queries[0], hits, math.floor(final_top_k / 2) + 1) + logging.info( + f"hits after reranking : {len(hits)} with final_top: {math.floor(final_top_k / 2) + 1}" + ) + + neighbors 
= _expand_neighbors(hits) + logging.info(f"getting neighbors of the hits : {len(neighbors)}") + + hits.extend(neighbors) + logging.info(f"add all neighbors to hits , Total Hits: {len(neighbors)}") + + hits = _rerank(expanded_queries[0], hits, final_top_k) + logging.info( + f"make second reranking for hits : {len(hits)} with final_top: {final_top_k}" + ) + + return _pack_context(hits) + + +def _expand_queries(queries: List[str]) -> List[str]: + + expanded = set() + + for q in queries: + q = q.strip() + expanded.add(q) + expanded.add(q.replace("-", " ")) + expanded.add(q.replace("_", " ")) + expanded.add(" ".join(q.split())) + + return list(expanded) + + +@lru_cache(maxsize=512) +def _embed_single_query(query: str): + + prefixed = f"query: {query}" + + def dense_task(): + return list(dense_embedding.embed([prefixed]))[0] + + def sparse_task(): + return list(sparse_embedding.embed([prefixed]))[0] + + fd = _executor.submit(dense_task) + fs = _executor.submit(sparse_task) + + dense = fd.result() + sparse = fs.result() + + return { + "dense": dense.tolist(), + "sparse": models.SparseVector( + indices=sparse.indices.tolist(), + values=sparse.values.tolist(), + ), + } + + +def _embed_queries(queries: List[str]): + return [_embed_single_query(q) for q in queries] + + +def _build_metadata_filter(meta: Optional[Dict]) -> Optional[Filter]: + + if not meta: + return None + + cond: List[Condition] = [] + + for k, v in meta.items(): + if v is None: + continue + + cond.append( + FieldCondition( + key=k, + match=MatchValue(value=v), + ) + ) + + return Filter(must=cond) if cond else None + + +def _hybrid_query( + embeddings: List[Dict], meta_filter: Optional[Filter], limit: int +) -> List[Hit]: + + results: List[Hit] = [] + + for emb in embeddings: + res = qclient.query_points( + collection_name=COLLECTION_NAME, + prefetch=[ + models.Prefetch( + query=emb["dense"], + using="text-dense", + filter=meta_filter, + limit=limit, + ), + models.Prefetch( + query=emb["sparse"], + 
def _normalize_scores(hits: List["Hit"]) -> List["Hit"]:
    """Min-max scale the scores of ``hits`` into the range [0, 1], in place.

    The lowest score becomes 0.0 and the highest becomes 1.0. When the list
    is empty, or every score is identical (no range to scale by), the hits
    are returned untouched.

    Args:
        hits: Retrieval hits; each ``score`` attribute is overwritten.

    Returns:
        The same list object that was passed in.
    """
    if not hits:
        return hits

    scores = np.array([h.score for h in hits], dtype=float)
    lo, hi = scores.min(), scores.max()

    # A zero-width range would divide by zero; leave the scores as they are.
    if hi == lo:
        return hits

    scaled = ((scores - lo) / (hi - lo)).tolist()
    for hit, value in zip(hits, scaled):
        hit.score = value

    return hits


def _rerank(query: str, hits: List["Hit"], top_k: int) -> List["Hit"]:
    """Re-score ``hits`` against ``query`` with the cross-encoder.

    The ``_text`` payload of every hit is scored by ``reranker`` in batches
    of ``_RERANK_BATCH_SIZE``. The raw scores are min-max scaled to [0, 1]
    (skipped when all scores are equal), the hits are ordered best-first and
    the final selection is delegated to ``_adaptive_rerank_cutoff``.

    Args:
        query: The user query the candidates are compared with.
        hits: Candidate hits; hits without ``_text`` are scored as "".
        top_k: Maximum number of hits to keep.

    Returns:
        The selected hits, best first. Empty when ``hits`` is empty.
    """
    # Nothing to score: return before touching the reranker at all.
    if not hits:
        return []

    texts = [h.payload.get("_text", "") for h in hits]

    raw_scores: List[float] = []
    for start in range(0, len(texts), _RERANK_BATCH_SIZE):
        batch = texts[start : start + _RERANK_BATCH_SIZE]
        raw_scores.extend(reranker.rerank(query, batch))

    if not raw_scores:
        return []

    scores_np = np.array(raw_scores, dtype=float)
    lo, hi = scores_np.min(), scores_np.max()
    # With identical scores there is no range to scale by, so the raw
    # values are passed on unchanged.
    if hi > lo:
        scores_np = (scores_np - lo) / (hi - lo)

    scored = sorted(
        zip(hits, scores_np.tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )

    return _adaptive_rerank_cutoff(scored, top_k)


def _adaptive_rerank_cutoff(
    scored: List[Tuple["Hit", float]], top_k: int
) -> List["Hit"]:
    """Pick the hits worth keeping from a best-first list of rerank scores.

    A candidate is skipped when its score is below ``_MIN_RERANK_SCORE``.
    Selection stops at the first candidate whose score, relative to the best
    one, falls under ``_RELATIVE_RERANK_DROP``, or once ``top_k`` hits have
    been collected. Selected hits have ``score`` overwritten with their
    rerank score.

    If no candidate passes, the first (best) candidate is returned on its
    own so a non-empty input never yields an empty result; its ``score`` is
    left at whatever value it carried before reranking.

    Args:
        scored: ``(hit, rerank_score)`` pairs; the first pair is treated as
            the best one, so the list is expected to be sorted descending.
        top_k: Maximum number of hits to return.

    Returns:
        The selected hits in the order they appear in ``scored``.
    """
    if not scored:
        return []

    # Floor the reference score so the ratio below never divides by zero.
    best = max(scored[0][1], 1e-6)

    selected: List["Hit"] = []

    for hit, score in scored:
        if score < _MIN_RERANK_SCORE:
            continue

        if score / best < _RELATIVE_RERANK_DROP:
            break

        hit.score = float(score)
        selected.append(hit)

        if len(selected) >= top_k:
            break

    if not selected:
        selected.append(scored[0][0])

    return selected


# ============================================================
# CONTEXT PACKING
# ============================================================
def _pack_context(hits: List["Hit"]) -> List["QueryResponse"]:
    """Convert hits into response chunks within the context token budget.

    Hits are considered from highest to lowest score, so when the budget of
    ``_MAX_CONTEXT_TOKENS`` runs out it is the weakest chunks that are left
    out. A chunk that does not fit is skipped rather than ending the packing,
    which lets a smaller, lower-ranked chunk still use the remaining room.

    The token cost of a chunk is its ``_token_count`` payload value, or the
    whitespace word count of ``_text`` when that value is absent or ``None``.

    Args:
        hits: Hits to pack; only their ``score`` and ``payload`` are read.

    Returns:
        ``QueryResponse`` objects ordered by score, highest first.
    """
    if not hits:
        return []

    # sorted() is stable, so equally scored hits keep their incoming order.
    ranked = sorted(hits, key=lambda h: h.score, reverse=True)

    chunks: List["QueryResponse"] = []
    tokens_used = 0

    for hit in ranked:
        payload = hit.payload
        text = payload.get("_text", "")

        token_count = payload.get("_token_count")
        if token_count is None:
            token_count = len(text.split())

        if tokens_used + token_count > _MAX_CONTEXT_TOKENS:
            continue

        tokens_used += token_count
        chunks.append(
            QueryResponse(
                text=text,
                score=float(hit.score),
                token_count=token_count,
                content_type=payload.get("_content_type", "text"),
                metadata=payload.get("_file_metadata", {}),
                table_json=payload.get("_table_json", []),
                reference=Reference(
                    file=payload.get("_source_file", ""),
                    section=payload.get("_section_title", ""),
                    pages=[
                        payload.get("_page_start", 0),
                        payload.get("_page_end", 0),
                    ],
                ),
            )
        )

    # Built in descending score order, so no further sort is needed.
    return chunks


def _expand_neighbors(hits: List["Hit"]) -> List["Hit"]:
    """Collect the chunks that surround each hit in its source file.

    For every hit that carries both ``_file_hash`` and ``_chunk_index``, the
    surrounding chunks are loaded with ``_fetch_neighbors``. Chunks that are
    already present in ``hits`` are ignored. Each neighbor is returned once,
    even when it lies in the window of several hits; in that case it keeps
    the highest score it was offered.

    Args:
        hits: The hits to expand. They are not modified and are not part of
            the returned list.

    Returns:
        Newly created ``Hit`` objects for the neighbors only, in the order
        they were first encountered.
    """
    seen_ids = {h.id for h in hits}
    expanded = {}

    for hit in hits:
        fh = hit.payload.get("_file_hash")
        idx = hit.payload.get("_chunk_index")

        if fh is None or idx is None:
            continue

        # A neighbor was not matched by the query itself, so it inherits a
        # discounted copy of the score of the hit that pulled it in.
        neighbor_score = hit.score * 0.85

        for point in _fetch_neighbors(fh, idx):
            nid = str(point.id)

            if nid in seen_ids:
                continue

            known = expanded.get(nid)
            if known is None:
                expanded[nid] = Hit(
                    id=nid,
                    score=neighbor_score,
                    payload=point.payload or {},
                )
            elif neighbor_score > known.score:
                known.score = neighbor_score

    return list(expanded.values())


def _fetch_neighbors(file_hash: str, idx: int) -> list:
    """Load the stored chunks around position ``idx`` of one file.

    Scrolls ``COLLECTION_NAME`` for points whose ``_file_hash`` equals
    ``file_hash`` and whose ``_chunk_index`` lies within
    ``_NEIGHBOR_WINDOW`` of ``idx`` on either side (bounds inclusive, so the
    chunk at ``idx`` itself matches too).

    Args:
        file_hash: Value to match against the ``_file_hash`` payload field.
        idx: Chunk index at the centre of the window.

    Returns:
        The points returned by the scroll call; the pagination offset that
        comes with them is discarded.
    """
    # NOTE(review): the filter has no `_user_id` condition. If several users
    # can store a file with the same hash, this may return another user's
    # points and the limit below may be too small - confirm against the
    # storage layer.
    result, _ = qclient.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="_file_hash",
                    match=models.MatchValue(value=file_hash),
                ),
                models.FieldCondition(
                    key="_chunk_index",
                    range=models.Range(
                        gte=idx - _NEIGHBOR_WINDOW,
                        lte=idx + _NEIGHBOR_WINDOW,
                    ),
                ),
            ]
        ),
        # One slot per index in the inclusive window.
        limit=_NEIGHBOR_WINDOW * 2 + 1,
    )

    return result
max_tokens=400, + max_tokens=450, + min_tokens=80, ) logging.info(f"chunked pdf to : {len(chunks)} chunks") chunks = embed_chunks(chunks) - logging.info(f"embedding chunks: {len(chunks[0].embedding)}") + logging.info("embedding chunks") store_chunks(chunks) logging.info("stored chunked") return makeResponse(metadata | extractor_meta, chunks)