Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastembed/common/model_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ class SparseModelDescription(BaseModelDescription):
class PoolingType(str, Enum):
    """Pooling strategy identifiers used in model descriptions.

    String-valued so descriptions can serialize/compare pooling modes directly.
    The actual pooling implementations live elsewhere; these are just the tags.
    """

    # Pool using the first ([CLS]) token's embedding.
    CLS = "CLS"
    # Pool by averaging token embeddings.
    MEAN = "MEAN"
    # Pool using the last non-padding token (e.g. Qwen3-Embedding models).
    LAST_TOKEN = "LAST_TOKEN"
    # No pooling is applied.
    DISABLED = "DISABLED"
9 changes: 8 additions & 1 deletion fastembed/common/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,14 +409,21 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An
try:
cache_kwargs = deepcopy(kwargs)
cache_kwargs["local_files_only"] = True
return Path(
cached_dir = Path(
cls.download_files_from_huggingface(
hf_source,
cache_dir=cache_dir,
extra_patterns=extra_patterns,
**cache_kwargs,
)
)
# Verify all required files exist in cache before returning
missing = [p for p in extra_patterns if not (cached_dir / p).exists()]
if missing:
raise FileNotFoundError(
f"Cached snapshot missing files: {missing}"
)
return cached_dir
except Exception:
pass
finally:
Expand Down
19 changes: 15 additions & 4 deletions fastembed/common/preprocessor_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import json
import logging
from typing import Any
from pathlib import Path

from tokenizers import AddedToken, Tokenizer

from fastembed.image.transform.operators import Compose

logger = logging.getLogger(__name__)


def load_special_tokens(model_dir: Path) -> dict[str, Any]:
tokens_map_path = model_dir / "special_tokens_map.json"
if not tokens_map_path.exists():
raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
return {}

with open(str(tokens_map_path)) as tokens_map_file:
tokens_map = json.load(tokens_map_file)
Expand Down Expand Up @@ -51,9 +54,17 @@ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]:
tokenizer = Tokenizer.from_file(str(tokenizer_path))
tokenizer.enable_truncation(max_length=max_context)
if not tokenizer.padding:
tokenizer.enable_padding(
pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
)
pad_token_id = config.get("pad_token_id")
if pad_token_id is None:
logger.warning(
"pad_token_id not found in config.json for %s, defaulting to 0",
model_dir.name,
)
pad_token_id = 0
pad_token = tokenizer_config.get("pad_token", "")
if isinstance(pad_token, dict):
pad_token = pad_token.get("content", "")
tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token)

for token in tokens_map.values():
if isinstance(token, str):
Expand Down
22 changes: 22 additions & 0 deletions fastembed/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,28 @@
T = TypeVar("T")


def last_token_pool(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    """Pool a batch of token embeddings down to the last real token of each row.

    Qwen3-Embedding models pool at the final non-padding token rather than via
    CLS or mean pooling. Both left-padded and right-padded batches are handled.

    Args:
        input_array: Token embeddings, shape (batch_size, seq_len, hidden_dim).
        attention_mask: Mask with 1 for real tokens, shape (batch_size, seq_len).

    Returns:
        One embedding per sequence, shape (batch_size, hidden_dim).
    """
    batch_size = attention_mask.shape[0]

    # When every row attends at the final position the batch is left-padded,
    # so the last column already holds each sequence's final real token.
    if int(attention_mask[:, -1].sum()) == batch_size:
        return input_array[:, -1]

    # Right padding: gather each row at its own last real-token position.
    last_positions = attention_mask.sum(axis=1).astype(np.int64) - 1
    return input_array[np.arange(batch_size), last_positions]


def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> NumpyArray:
# Calculate the Lp norm along the specified dimension
norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
Expand Down
225 changes: 225 additions & 0 deletions fastembed/rerank/cross_encoder/qwen3_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""Qwen3 reranker using causal LM with yes/no logit scoring.

Unlike traditional cross-encoder rerankers (which concatenate query+document
as a pair, feed through a BERT-class model, and read a relevance head), the
Qwen3 reranker:

1. Formats input as a **chat template** with system/user/assistant turns.
2. Runs a **causal language model** (Qwen3ForCausalLM).
3. Extracts the **last-token logits** for the "yes" and "no" tokens.
4. Applies **softmax** to obtain the relevance probability.

This means the ONNX model output has shape ``(batch, seq_len, vocab_size)``
instead of the typical ``(batch, num_labels)`` from cross-encoders.
"""

from typing import Any

import numpy as np

from fastembed.common.model_description import BaseModelDescription, ModelSource
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import (
OnnxTextCrossEncoder,
TextCrossEncoderWorker,
)
from fastembed.rerank.cross_encoder.onnx_text_model import TextRerankerWorker

# ---------------------------------------------------------------------------
# Qwen3 reranker constants
# ---------------------------------------------------------------------------
# Token IDs in the Qwen3 tokenizer vocabulary
# NOTE(review): these IDs are hard-coded; verify they equal
# tokenizer.token_to_id("yes") / tokenizer.token_to_id("no") for the exact
# checkpoint below — TODO confirm.
TOKEN_YES_ID = 9693
TOKEN_NO_ID = 2132

# System turn injected into every rerank prompt; constrains the model to
# answer only "yes" or "no" so the yes/no logit comparison is meaningful.
SYSTEM_PROMPT = (
    "Judge whether the Document meets the requirements based on the Query "
    'and the Instruct provided. Note that the answer can only be "yes" or "no".'
)

# Used when the caller does not pass an explicit `instruction` kwarg.
DEFAULT_INSTRUCTION = (
    "Given a query and a document, judge whether the document is relevant to the query."
)

# Chat template with an empty <think></think> block pre-filled in the
# assistant turn, so the next generated token is the yes/no answer itself.
RERANK_TEMPLATE = (
    "<|im_start|>system\n{system}<|im_end|>\n"
    "<|im_start|>user\n<Instruct>: {instruction}\n"
    "<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)

# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# NOTE(review): both variants list size_in_GB=0.57 even though one is INT8 and
# the other Q4F16 — presumably a copy-paste; verify against the HF repo.
supported_qwen3_reranker_models: list[BaseModelDescription] = [
    BaseModelDescription(
        model="Qwen/Qwen3-Reranker-0.6B",
        description=(
            "Qwen3 reranker (0.6B) using causal LM yes/no scoring. "
            "INT8 dynamic quantized. Multilingual, 40960 input tokens, "
            "instruction-aware, 2025 year."
        ),
        license="apache-2.0",
        size_in_GB=0.57,
        sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"),
        model_file="onnx/model_quantized.onnx",
    ),
    BaseModelDescription(
        model="Qwen/Qwen3-Reranker-0.6B-Q4F16",
        description=(
            "Qwen3 reranker (0.6B) using causal LM yes/no scoring. "
            "INT4 weights + FP16 activations (Q4F16). Multilingual, "
            "40960 input tokens, instruction-aware, 2025 year."
        ),
        license="apache-2.0",
        size_in_GB=0.57,
        sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"),
        model_file="onnx/model_q4f16.onnx",
    ),
]


# ---------------------------------------------------------------------------
# Qwen3 reranker implementation
# ---------------------------------------------------------------------------
class Qwen3CrossEncoder(OnnxTextCrossEncoder):
    """Qwen3 reranker scored via causal-LM yes/no logits.

    Each (query, document) pair is rendered into the Qwen3 chat template,
    run through the causal LM, and scored as the softmax probability of the
    "yes" token over {"yes", "no"} at the last position.

    Usage::

        from fastembed import TextCrossEncoder

        reranker = TextCrossEncoder("Qwen/Qwen3-Reranker-0.6B")
        scores = list(reranker.rerank("What is AI?", ["doc1", "doc2"]))

        # Custom instruction
        scores = list(reranker.rerank(
            "What is AI?",
            ["doc1", "doc2"],
            instruction="Judge document relevance for code search.",
        ))
    """

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        """Return the registry of Qwen3 reranker model descriptions."""
        return supported_qwen3_reranker_models

    @staticmethod
    def _format_rerank_input(
        query: str,
        document: str,
        instruction: str = DEFAULT_INSTRUCTION,
    ) -> str:
        """Render one query-document pair into the Qwen3 chat template."""
        prompt = RERANK_TEMPLATE.format(
            system=SYSTEM_PROMPT,
            instruction=instruction,
            query=query,
            document=document,
        )
        return prompt

    @staticmethod
    def _compute_yes_no_scores(model_output: NumpyArray) -> NumpyArray:
        """Turn raw causal-LM logits into P(yes) relevance scores.

        Args:
            model_output: Raw model output, shape ``(batch, seq_len, vocab_size)``.

        Returns:
            Relevance scores (P(yes)), shape ``(batch,)``.
        """
        # Only the final position's logits matter for the yes/no answer.
        final_logits: NumpyArray = model_output[:, -1, :]  # (batch, vocab_size)

        # Restrict the softmax to the two answer tokens, order [no, yes].
        pair = np.stack(
            [final_logits[:, TOKEN_NO_ID], final_logits[:, TOKEN_YES_ID]], axis=1
        )  # (batch, 2)

        # Shift by the row max for numerical stability before exponentiating.
        shifted = pair - np.max(pair, axis=1, keepdims=True)
        weights = np.exp(shifted)
        probs = weights / np.sum(weights, axis=1, keepdims=True)

        # Column 1 is the "yes" probability.
        return probs[:, 1]

    def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
        """Score one query against many documents via the chat template."""
        instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION)
        prompts = [
            self._format_rerank_input(query, document, instruction)
            for document in documents
        ]
        return self._onnx_embed_texts(prompts, **kwargs)

    def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
        """Score pre-formed (query, document) pairs."""
        instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION)
        prompts = [
            self._format_rerank_input(query, document, instruction)
            for query, document in pairs
        ]
        return self._onnx_embed_texts(prompts, **kwargs)

    def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContext:
        """Run each prompt through the model individually and join the scores.

        The exported ONNX graph has a static batch size of 1, so prompts are
        scored one at a time and the per-prompt P(yes) values concatenated.
        """
        assert self.tokenizer is not None, "Tokenizer not loaded. Call load_onnx_model() first."

        model_inputs: set[str] = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
        score_chunks: list[NumpyArray] = []
        for text in texts:
            encoding = self.tokenizer.encode_batch([text])[0]
            feed: dict[str, NumpyArray] = {
                "input_ids": np.array([encoding.ids], dtype=np.int64),
            }
            if "attention_mask" in model_inputs:
                feed["attention_mask"] = np.array([encoding.attention_mask], dtype=np.int64)
            if "token_type_ids" in model_inputs:
                feed["token_type_ids"] = np.zeros_like(feed["input_ids"], dtype=np.int64)

            feed = self._preprocess_onnx_input(feed, **kwargs)
            raw_output = self.model.run(self.ONNX_OUTPUT_NAMES, feed)[0]  # type: ignore[union-attr]
            # Quantized exports may emit fp16; promote before the softmax.
            if raw_output.dtype == np.float16:
                raw_output = raw_output.astype(np.float32)
            score_chunks.append(self._compute_yes_no_scores(raw_output))

        joined = np.concatenate(score_chunks).astype(np.float32)
        return OnnxOutputContext(model_output=joined)

    @classmethod
    def _get_worker_class(cls) -> type[TextRerankerWorker]:
        """Return the worker class used for parallel processing."""
        return Qwen3CrossEncoderWorker


class Qwen3CrossEncoderWorker(TextCrossEncoderWorker):
    """Worker that runs Qwen3 reranker inference in a parallel process."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxTextCrossEncoder:
        """Build the Qwen3CrossEncoder instance owned by this worker.

        Parallelism comes from the process pool, so the per-worker session is
        pinned to a single thread.
        """
        encoder = Qwen3CrossEncoder(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return encoder
2 changes: 2 additions & 0 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastembed.common.types import Device
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder
from fastembed.rerank.cross_encoder.qwen3_cross_encoder import Qwen3CrossEncoder

from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.model_description import (
Expand All @@ -16,6 +17,7 @@
class TextCrossEncoder(TextCrossEncoderBase):
CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [
OnnxTextCrossEncoder,
Qwen3CrossEncoder,
CustomTextCrossEncoder,
]

Expand Down
5 changes: 4 additions & 1 deletion fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ def onnx_embed(
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
embeddings = model_output[0]
if embeddings.dtype == np.float16:
embeddings = embeddings.astype(np.float32)
return OnnxOutputContext(
model_output=model_output[0],
model_output=embeddings,
attention_mask=onnx_input.get("attention_mask", attention_mask),
input_ids=onnx_input.get("input_ids", input_ids),
)
Expand Down
Loading