166 changes: 135 additions & 31 deletions rag-engine/src/layers/chunking/chunk_document.py
@@ -1,8 +1,11 @@
+import json
 from typing import List
 import uuid
 from src.layers.chunking.models import Chunk
 import tiktoken
 
+from src.layers.structure_analyzer.models import Section, StructuredDocument
+
 _encoder = tiktoken.get_encoding("cl100k_base")
 
 
@@ -11,36 +14,45 @@ def count_tokens(text: str) -> int:
 
 
 def chunk_document(
-    structured_document,
+    structured_document: StructuredDocument,
     metadata: dict,
     max_tokens: int = 400,
+    min_tokens: int = 80,
 ) -> List[Chunk]:
 
     chunks: List[Chunk] = []
 
-    # handle preamble
+    # ---- PREAMBLE ----
     if structured_document.preamble:
         preamble_text = "\n".join(p.text for p in structured_document.preamble)
-        if not _looks_like_toc(preamble_text):
+        if preamble_text.strip():
             chunks.extend(
                 _chunk_paragraphs(
                     paragraphs=structured_document.preamble,
                     section_title="Preamble",
                     section_path=["Preamble"],
                     level=0,
                     max_tokens=max_tokens,
+                    min_tokens=min_tokens,
                     metadata=metadata,
                 )
             )
 
-    # handle sections
+    # ---- SECTIONS ----
     for section in structured_document.sections:
         chunks.extend(
             _process_section(
-                section, parent_path=[], max_tokens=max_tokens, metadata=metadata
+                section,
+                parent_path=[],
+                max_tokens=max_tokens,
+                min_tokens=min_tokens,
+                metadata=metadata,
             )
         )
 
+    # ---- FINAL CLEANUP ----
+    chunks = _deduplicate_chunks(chunks)
+
     return chunks
 
 
@@ -50,6 +62,7 @@ def _chunk_paragraphs(
     section_path: List[str],
     level: int,
     max_tokens: int,
+    min_tokens: int,
     metadata: dict,
 ) -> List[Chunk]:
 
@@ -107,20 +120,21 @@ def _chunk_paragraphs(
            )
        )
 
-    return _merge_small_chunks(chunks, metadata)
+    return _merge_small_chunks(chunks, metadata, min_tokens, max_tokens)
 
 
 def _process_section(
-    section,
+    section: Section,
     parent_path: List[str],
     max_tokens: int,
+    min_tokens: int,
     metadata: dict,
 ) -> List[Chunk]:
 
     path = parent_path + [section.title]
     chunks: List[Chunk] = []
 
-    # chunk this section's paragraphs
+    # ---- TEXT CHUNKS ----
     if section.paragraphs:
         chunks.extend(
             _chunk_paragraphs(
@@ -129,18 +143,52 @@ def _process_section(
                 section_path=path,
                 level=section.level,
                 max_tokens=max_tokens,
+                min_tokens=min_tokens,
                 metadata=metadata,
             )
         )
 
+    # ---- TABLE CHUNKS ----
+    if section.tables:
+        chunks.extend(
+            _build_table_chunks_from_section(
+                section=section,
+                section_path=path,
+                metadata=metadata,
+            )
+        )
+
-    # recursively process children
+    # ---- CHILD SECTIONS ----
     for child in section.children:
         chunks.extend(
             _process_section(
-                child, parent_path=path, max_tokens=max_tokens, metadata=metadata
+                child,
+                parent_path=path,
+                max_tokens=max_tokens,
+                min_tokens=min_tokens,
+                metadata=metadata,
             )
         )
 
+    if (
+        not section.paragraphs
+        and not section.tables
+        and not section.children
+        and section.title.strip()
+    ):
+        if not _is_pure_category_title(section.title):
+            chunks.append(
+                _build_chunk(
+                    text=section.title.strip(),
+                    section_title=section.title,
+                    section_path=path,
+                    level=section.level,
+                    page_start=section.page_number,
+                    page_end=section.page_number,
+                    metadata=metadata,
+                )
+            )
+
     return chunks
 
@@ -168,40 +216,96 @@ def _build_chunk(
 
 
 def _merge_small_chunks(
-    chunks: List[Chunk], metadata: dict, min_tokens: int = 80
+    chunks: List[Chunk],
+    metadata: dict,
+    min_tokens: int,
+    max_tokens: int,
 ) -> List[Chunk]:
 
     if not chunks:
-        return chunks
+        return []
 
     merged = []
     buffer = chunks[0]
 
     for chunk in chunks[1:]:
+        # If buffer too small, try merging
         if buffer.token_count < min_tokens:
             combined_text = buffer.text + "\n" + chunk.text
-            buffer = _build_chunk(
-                combined_text,
-                buffer.section_title,
-                buffer.section_path,
-                buffer.level,
-                buffer.page_start,
-                chunk.page_end,
-                metadata=metadata,
-            )
-        else:
-            merged.append(buffer)
-            buffer = chunk
+            combined_tokens = count_tokens(combined_text)
+
+            # Only merge if we stay under max_tokens
+            if combined_tokens <= max_tokens:
+                buffer = _build_chunk(
+                    combined_text,
+                    buffer.section_title,
+                    buffer.section_path,
+                    buffer.level,
+                    buffer.page_start,
+                    chunk.page_end,
+                    metadata,
+                )
+                continue
+
+        # Otherwise flush buffer
+        merged.append(buffer)
+        buffer = chunk
 
     merged.append(buffer)
     return merged
 
 
-def _looks_like_toc(text: str) -> bool:
-    lines = text.split("\n")
-    digit_lines = sum(1 for lin in lines if lin.strip().split()[-1].isdigit())
-    dotted_lines = sum(1 for lin in lines if "..." in lin or ". ." in lin)
+def _deduplicate_chunks(chunks: List[Chunk]) -> List[Chunk]:
+    seen = set()
+    unique = []
+
+    for chunk in chunks:
+        normalized = chunk.text.strip()
+
+        if normalized in seen:
+            continue
+
+        seen.add(normalized)
+        unique.append(chunk)
+
+    return unique
+
+
+def _build_table_chunks_from_section(
+    section: Section,
+    section_path: List[str],
+    metadata: dict,
+) -> List[Chunk]:
+
+    chunks = []
+
+    for table in section.tables:
+        table_metadata = metadata.copy()
+        table_metadata["_content_type"] = "table"
+        table_json = json.dumps(table, ensure_ascii=False)
+
+        chunks.append(
+            Chunk(
+                id=str(uuid.uuid4()),
+                text=table_json,
+                token_count=count_tokens(table_json),
+                section_title=section.title,
+                section_path=section_path,
+                level=section.level,
+                page_start=section.page_number,
+                page_end=section.page_number,
+                metadata=table_metadata,
+            )
+        )
+
+    return chunks
+
+
+def _is_pure_category_title(title: str) -> bool:
+    clean = title.strip()
+
-    if len(lines) == 0:
-        return False
+    # If fully uppercase and short → likely category
+    if clean.isupper() and len(clean.split()) <= 3:
+        return True
 
-    return (digit_lines / len(lines)) > 0.4 or (dotted_lines / len(lines)) > 0.3
+    return False
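
A minimal usage sketch of the updated chunk_document signature, assuming StructuredDocument, Section, and the preamble paragraph objects expose the fields this diff reads (title, level, page_number, paragraphs, tables, children, .text); the constructor keyword arguments below are illustrative assumptions, not the models' actual API.

# Hypothetical wiring: field names follow what chunk_document reads in this diff,
# but the real model constructors may differ.
from src.layers.chunking.chunk_document import chunk_document
from src.layers.structure_analyzer.models import Section, StructuredDocument

doc = StructuredDocument(
    preamble=[],  # paragraph objects exposing a .text attribute
    sections=[
        Section(
            title="1. Introduction",
            level=1,
            page_number=1,
            paragraphs=[],  # paragraph objects consumed by _chunk_paragraphs
            tables=[],      # JSON-serializable table payloads
            children=[],
        )
    ],
)

chunks = chunk_document(
    doc,
    metadata={"source": "example.pdf"},
    max_tokens=400,
    min_tokens=80,
)

for chunk in chunks:
    print(chunk.section_path, chunk.token_count)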