Skip to content

Commit e14914a

Browse files
committed
refactor: rename builtin pooling normalized embedding to builtin sentence embedding
1 parent 8f04f57 commit e14914a

2 files changed

Lines changed: 11 additions & 10 deletions

File tree

fastembed/text/builtin_pooling_normalized_embedding.py renamed to fastembed/text/builtin_sentence_embedding.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,11 @@
33

44
from fastembed.common.types import NumpyArray
55
from fastembed.common.onnx_model import OnnxOutputContext
6-
from fastembed.common.utils import normalize
76
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
87
from fastembed.common.model_description import DenseModelDescription, ModelSource
98

109

11-
supported_builtin_pooling_normalized_models: list[DenseModelDescription] = [
10+
supported_builtin_sentence_embedding_models: list[DenseModelDescription] = [
1211
DenseModelDescription(
1312
model="google/embeddinggemma-300m",
1413
dim=768,
@@ -28,10 +27,12 @@
2827
]
2928

3029

31-
class BuiltinPoolingNormalizedEmbedding(OnnxTextEmbedding):
30+
class BuiltinSentenceEmbedding(OnnxTextEmbedding):
31+
"""Builtin Sentence Embedding uses built-in pooling and normalization of underlying onnx models"""
32+
3233
@classmethod
3334
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
34-
return BuiltinPoolingNormalizedEmbeddingWorker
35+
return BuiltinSentenceEmbeddingWorker
3536

3637
@classmethod
3738
def _list_supported_models(cls) -> list[DenseModelDescription]:
@@ -40,27 +41,27 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
4041
Returns:
4142
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
4243
"""
43-
return supported_builtin_pooling_normalized_models
44+
return supported_builtin_sentence_embedding_models
4445

4546
def _post_process_onnx_output(
4647
self, output: OnnxOutputContext, **kwargs: Any
4748
) -> Iterable[NumpyArray]:
48-
return normalize(output.model_output)
49+
return output.model_output
4950

5051
def _run_model(
5152
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
5253
) -> NumpyArray:
5354
return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr]
5455

5556

56-
class BuiltinPoolingNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
57+
class BuiltinSentenceEmbeddingWorker(OnnxTextEmbeddingWorker):
5758
def init_embedding(
5859
self,
5960
model_name: str,
6061
cache_dir: str,
6162
**kwargs: Any,
6263
) -> OnnxTextEmbedding:
63-
return BuiltinPoolingNormalizedEmbedding(
64+
return BuiltinSentenceEmbedding(
6465
model_name=model_name,
6566
cache_dir=cache_dir,
6667
threads=1,

fastembed/text/text_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
99
from fastembed.text.pooled_embedding import PooledEmbedding
1010
from fastembed.text.multitask_embedding import JinaEmbeddingV3
11-
from fastembed.text.builtin_pooling_normalized_embedding import BuiltinPoolingNormalizedEmbedding
11+
from fastembed.text.builtin_sentence_embedding import BuiltinSentenceEmbedding
1212
from fastembed.text.onnx_embedding import OnnxTextEmbedding
1313
from fastembed.text.text_embedding_base import TextEmbeddingBase
1414
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
@@ -21,7 +21,7 @@ class TextEmbedding(TextEmbeddingBase):
2121
PooledNormalizedEmbedding,
2222
PooledEmbedding,
2323
JinaEmbeddingV3,
24-
BuiltinPoolingNormalizedEmbedding,
24+
BuiltinSentenceEmbedding,
2525
CustomTextEmbedding,
2626
]
2727

0 commit comments

Comments
 (0)