Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions fastembed/text/builtin_sentence_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Iterable, Type


from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.common.model_description import DenseModelDescription, ModelSource


supported_builtin_sentence_embedding_models: list[DenseModelDescription] = [
DenseModelDescription(
model="google/embeddinggemma-300m",
dim=768,
description=(
"Text embeddings, Unimodal (text), multilingual, 2048 input tokens truncation, "
"Prefixes for queries/documents: `task: search result | query: {content}` for query, "
"`title: {title | 'none'} | text: {content}` for documents, 2025 year."
),
license="apache-2.0",
size_in_GB=1.24,
sources=ModelSource(
hf="onnx-community/embeddinggemma-300m-ONNX",
),
model_file="onnx/model.onnx",
additional_files=["onnx/model.onnx_data"],
),
]


class BuiltinSentenceEmbedding(OnnxTextEmbedding):
"""Builtin Sentence Embedding uses built-in pooling and normalization of underlying onnx models"""

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
return BuiltinSentenceEmbeddingWorker

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
"""Lists the supported models.

Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_builtin_sentence_embedding_models

def _post_process_onnx_output(
self, output: OnnxOutputContext, **kwargs: Any
) -> Iterable[NumpyArray]:
return output.model_output

def _run_model(
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
) -> NumpyArray:
return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr]


class BuiltinSentenceEmbeddingWorker(OnnxTextEmbeddingWorker):
def init_embedding(
self,
model_name: str,
cache_dir: str,
**kwargs: Any,
) -> OnnxTextEmbedding:
return BuiltinSentenceEmbedding(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
11 changes: 9 additions & 2 deletions fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,21 @@ def onnx_embed(
[np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
)
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
model_output = self._run_model(
onnx_input=onnx_input, onnx_output_names=self.ONNX_OUTPUT_NAMES
)

model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
model_output=model_output,
attention_mask=onnx_input.get("attention_mask", attention_mask),
input_ids=onnx_input.get("input_ids", input_ids),
)

def _run_model(
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
) -> NumpyArray:
return self.model.run(onnx_output_names, onnx_input)[0] # type: ignore[union-attr]

def _embed_documents(
self,
model_name: str,
Expand Down
2 changes: 2 additions & 0 deletions fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.builtin_sentence_embedding import BuiltinSentenceEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
Expand All @@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase):
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
BuiltinSentenceEmbedding,
CustomTextEmbedding,
]

Expand Down
52 changes: 52 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
"google/embeddinggemma-300m": np.array(
[-0.08181356, 0.0214127, 0.05120273, -0.03690156, -0.0254504]
),
}


DOC_PREFIXES = {
"google/embeddinggemma-300m": "title: none | text: ",
}
QUERY_PREFIXES = {
"google/embeddinggemma-300m": "task: search result | query: ",
}
CANONICAL_QUERY_VECTOR_VALUES = {
"google/embeddinggemma-300m": np.array(
[-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477]
)
}

MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
Expand Down Expand Up @@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None:

with model_cache(model_desc.model) as model:
docs = ["hello world", "flag embedding"]
if model_desc.model in DOC_PREFIXES:
docs = [DOC_PREFIXES[model_desc.model] + doc for doc in docs]

embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim)
Expand All @@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None:
), model_desc.model


def test_query_embedding(model_cache) -> None:
is_ci = os.getenv("CI")
is_mac = platform.system() == "Darwin"
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

for model_desc in TextEmbedding._list_supported_models():
if model_desc.model in MULTI_TASK_MODELS or (
is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
):
continue

if model_desc.model not in CANONICAL_QUERY_VECTOR_VALUES:
continue

if not should_test_model(model_desc, "", is_ci, is_manual):
continue

dim = model_desc.dim
with model_cache(model_desc.model) as model:
queries = ["hello world", "flag embedding"]
if model_desc.model in QUERY_PREFIXES:
queries = [QUERY_PREFIXES[model_desc.model] + query for query in queries]

embeddings = list(model.query_embed(queries))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim)

canonical_vector = CANONICAL_QUERY_VECTOR_VALUES[model_desc.model]
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc.model


@pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
with model_cache(model_name) as model:
Expand Down
Loading