Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion unstructured/partition/utils/ocr_models/paddle_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
from numba import njit
from PIL import Image as PILImage

from unstructured.documents.elements import ElementType
Expand Down Expand Up @@ -75,7 +76,8 @@ def get_layout_from_image(self, image: PILImage.Image) -> TextRegions:
# have the mapping for paddle lang code
# see CORE-2034
ocr_data = self.agent.ocr(np.array(image), cls=True)
ocr_regions = self.parse_data(ocr_data)
# Fast path: push parsing to numba-accelerated helper if possible
ocr_regions = self._parse_data_fast(ocr_data)

return ocr_regions

Expand Down Expand Up @@ -142,3 +144,75 @@ def parse_data(self, ocr_data: list[Any]) -> TextRegions:
# FIXME (yao): find out if paddle supports a vectorized output format so we can skip the
# step of parsing a list
return TextRegions.from_list(text_regions)

@staticmethod
def _parse_data_fast(ocr_data: list[Any]) -> TextRegions:
    """Convert raw PaddleOCR output into ``TextRegions``.

    Every truthy entry of ``ocr_data`` is read as a sequence of detected lines, where
    ``line[0]`` is the polygon (a sequence of points indexed as ``point[0]`` = x and
    ``point[1]`` = y) and ``line[1][0]`` is the recognized text. Each line whose text is
    non-blank after ``strip()`` yields one region whose box is the axis-aligned bounding
    rectangle of its polygon.

    There is no fallback path: malformed input propagates whatever exception the
    indexing or the array conversion raises.

    NOTE(review): polygon coordinates are cast to ``int32`` before min/max is taken, so
    fractional coordinates are truncated. Confirm this matches ``parse_data``.

    Args:
        ocr_data: Value returned by the OCR agent; falsy entries are skipped.

    Returns:
        ``TextRegions`` built from the kept lines (from an empty list when none is kept).
    """
    # Deferred imports, kept local to avoid cyclic imports and overhead when unused.
    from unstructured_inference.inference.elements import TextRegions

    from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords
    from unstructured.partition.utils.constants import Source

    # Pass 1: validate the line structure and keep only lines carrying non-blank text.
    kept_lines = []
    for res in ocr_data:
        if not res:
            continue
        for line in res:
            coords = line[0]
            text = line[1][0]
            if not text:
                continue
            cleaned_text = text.strip()
            if cleaned_text:
                kept_lines.append((coords, cleaned_text))

    # Pass 2: reduce each polygon to its bounding rectangle and build the region directly,
    # without staging the corners in an intermediate array.
    text_regions = []
    for box, cleaned_text in kept_lines:
        x_arr = np.array([point[0] for point in box], dtype=np.int32)
        y_arr = np.array([point[1] for point in box], dtype=np.int32)
        x1, y1, x2, y2 = _get_minmax_numba(x_arr, y_arr)
        text_regions.append(
            build_text_region_from_coords(
                int(x1),
                int(y1),
                int(x2),
                int(y2),
                text=cleaned_text,
                source=Source.OCR_PADDLE,
            )
        )

    return TextRegions.from_list(text_regions)


@njit(cache=True)
def _get_minmax_numba(x_arr: np.ndarray, y_arr: np.ndarray):
    """Return ``(min_x, min_y, max_x, max_y)`` for the given x and y coordinate arrays.

    Compiled with numba's ``njit`` (nopython mode) and ``cache=True``.
    """
    return x_arr.min(), y_arr.min(), x_arr.max(), y_arr.max()