Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions fastembed/text/builtin_sentence_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Iterable, Type


from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.common.model_description import DenseModelDescription, ModelSource


supported_builtin_sentence_embedding_models: list[DenseModelDescription] = [
DenseModelDescription(
model="google/embeddinggemma-300m",
dim=768,
description=(
"Text embeddings, Unimodal (text), multilingual, 2048 input tokens truncation, "
"Prefixes for queries/documents: `task: search result | query: {content}` for query, "
"`title: {title | 'none'} | text: {content}` for documents, 2025 year."
),
license="apache-2.0",
size_in_GB=1.24,
sources=ModelSource(
hf="onnx-community/embeddinggemma-300m-ONNX",
),
model_file="onnx/model.onnx",
additional_files=["onnx/model.onnx_data"],
),
]


class BuiltinSentenceEmbedding(OnnxTextEmbedding):
"""Builtin Sentence Embedding uses built-in pooling and normalization of underlying onnx models"""

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
return BuiltinSentenceEmbeddingWorker

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
"""Lists the supported models.

Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_builtin_sentence_embedding_models

def _post_process_onnx_output(
self, output: OnnxOutputContext, **kwargs: Any
) -> Iterable[NumpyArray]:
return output.model_output

def _run_model(
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
) -> NumpyArray:
return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr]
Comment on lines +51 to +54
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Hardcoded output index [1] is brittle and lacks validation.

The method assumes the ONNX model always has at least two outputs with sentence_embedding at index 1. If the model's output structure changes or differs across versions, this will raise an IndexError without a clear diagnostic message.

Consider adding validation or a more descriptive error:

🛡️ Proposed fix with validation
 def _run_model(
     self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
 ) -> NumpyArray:
-    return self.model.run(onnx_output_names, onnx_input)[1]  # type: ignore[union-attr]
+    result = self.model.run(onnx_output_names, onnx_input)  # type: ignore[union-attr]
+    if len(result) < 2:
+        raise ValueError(
+            f"Expected at least 2 ONNX outputs (last_hidden_state, sentence_embedding), "
+            f"but got {len(result)}. Ensure the model has built-in pooling."
+        )
+    return result[1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fastembed/text/builtin_sentence_embedding.py` around lines 51 - 54, The
_run_model currently returns self.model.run(...)[1] which assumes a second
output exists and contains sentence embeddings; instead validate the outputs and
pick the correct tensor: if onnx_output_names is provided, use its index or name
to extract the desired output; otherwise inspect the actual output names
returned by self.model (or the model metadata) to locate "sentence_embedding"
(or fall back to the last/only output if that's the intended behavior), and
raise a clear ValueError indicating available outputs when the expected output
name/index is missing; update the _run_model implementation to perform these
checks before accessing an index and include the symbols _run_model, model.run,
and "sentence_embedding" in your logic and error message.



class BuiltinSentenceEmbeddingWorker(OnnxTextEmbeddingWorker):
def init_embedding(
self,
model_name: str,
cache_dir: str,
**kwargs: Any,
) -> OnnxTextEmbedding:
return BuiltinSentenceEmbedding(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
11 changes: 9 additions & 2 deletions fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,21 @@ def onnx_embed(
[np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
)
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
model_output = self._run_model(
onnx_input=onnx_input, onnx_output_names=self.ONNX_OUTPUT_NAMES
)

model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
model_output=model_output,
attention_mask=onnx_input.get("attention_mask", attention_mask),
input_ids=onnx_input.get("input_ids", input_ids),
)

def _run_model(
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
) -> NumpyArray:
return self.model.run(onnx_output_names, onnx_input)[0] # type: ignore[union-attr]

def _embed_documents(
self,
model_name: str,
Expand Down
2 changes: 2 additions & 0 deletions fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.builtin_sentence_embedding import BuiltinSentenceEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
Expand All @@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase):
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
BuiltinSentenceEmbedding,
CustomTextEmbedding,
]

Expand Down
52 changes: 52 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
"google/embeddinggemma-300m": np.array(
[-0.08181356, 0.0214127, 0.05120273, -0.03690156, -0.0254504]
),
}


DOC_PREFIXES = {
"google/embeddinggemma-300m": "title: none | text: ",
}
QUERY_PREFIXES = {
"google/embeddinggemma-300m": "task: search result | query: ",
}
CANONICAL_QUERY_VECTOR_VALUES = {
"google/embeddinggemma-300m": np.array(
[-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477]
)
}

MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
Expand Down Expand Up @@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None:

with model_cache(model_desc.model) as model:
docs = ["hello world", "flag embedding"]
if model_desc.model in DOC_PREFIXES:
docs = [DOC_PREFIXES[model_desc.model] + doc for doc in docs]

embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim)
Expand All @@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None:
), model_desc.model


def test_query_embedding(model_cache) -> None:
is_ci = os.getenv("CI")
is_mac = platform.system() == "Darwin"
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

for model_desc in TextEmbedding._list_supported_models():
if model_desc.model in MULTI_TASK_MODELS or (
is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
):
continue

if model_desc.model not in CANONICAL_QUERY_VECTOR_VALUES:
continue

if not should_test_model(model_desc, "", is_ci, is_manual):
continue

dim = model_desc.dim
with model_cache(model_desc.model) as model:
queries = ["hello world", "flag embedding"]
if model_desc.model in QUERY_PREFIXES:
queries = [QUERY_PREFIXES[model_desc.model] + query for query in queries]

embeddings = list(model.query_embed(queries))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim)

canonical_vector = CANONICAL_QUERY_VECTOR_VALUES[model_desc.model]
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc.model


@pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
with model_cache(model_name) as model:
Expand Down
Loading