Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions eland/ml/pytorch/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Utility functions for PyTorch module that can be shared across modules
without causing circular imports.
"""

from typing import Any, Set, Tuple, Union


def _is_tokenizer_type(
tokenizer: Any, tokenizer_class_names: Union[str, Tuple[str, ...], Set[str]]
) -> bool:
"""
Check if tokenizer is one of the specified types by class name.
Works even if tokenizer classes are not directly importable.

Args:
tokenizer: The tokenizer instance to check
tokenizer_class_names: String or tuple of strings with class names

Returns:
bool: True if tokenizer matches any of the specified types
"""
if isinstance(tokenizer_class_names, str):
tokenizer_class_names = (tokenizer_class_names,)
tokenizer_class_name = tokenizer.__class__.__name__
return tokenizer_class_name in tokenizer_class_names
159 changes: 107 additions & 52 deletions eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from torch import Tensor
from torch.profiler import profile # type: ignore
from transformers import (
BertTokenizer,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
Expand Down Expand Up @@ -104,23 +103,41 @@
"text_similarity": TextSimilarityInferenceOptions,
}
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))

# Try to import tokenizer classes, but don't fail if they don't exist
_SUPPORTED_TOKENIZER_NAMES = {
"BertTokenizer",
"BertJapaneseTokenizer",
"MPNetTokenizer",
"DPRContextEncoderTokenizer",
"DPRQuestionEncoderTokenizer",
"DistilBertTokenizer",
"ElectraTokenizer",
"MobileBertTokenizer",
"RetriBertTokenizer", # May be deprecated but models still work
"RobertaTokenizer",
"BartTokenizer",
"SqueezeBertTokenizer", # May be deprecated but models still work
"XLMRobertaTokenizer",
"DebertaV2Tokenizer",
}

# Try to build tuple of classes for backward compatibility where possible
_SUPPORTED_TOKENIZER_CLASSES = []
for name in _SUPPORTED_TOKENIZER_NAMES:
tokenizer_class = getattr(transformers, name, None)
if tokenizer_class is not None:
_SUPPORTED_TOKENIZER_CLASSES.append(tokenizer_class)

SUPPORTED_TOKENIZERS = (
transformers.BertTokenizer,
transformers.BertJapaneseTokenizer,
transformers.MPNetTokenizer,
transformers.DPRContextEncoderTokenizer,
transformers.DPRQuestionEncoderTokenizer,
transformers.DistilBertTokenizer,
transformers.ElectraTokenizer,
transformers.MobileBertTokenizer,
transformers.RetriBertTokenizer,
transformers.RobertaTokenizer,
transformers.BartTokenizer,
transformers.SqueezeBertTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
tuple(_SUPPORTED_TOKENIZER_CLASSES)
if _SUPPORTED_TOKENIZER_CLASSES
else (PreTrainedTokenizer,)
)
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted(_SUPPORTED_TOKENIZER_NAMES))


from eland.ml.pytorch._utils import _is_tokenizer_type # noqa: E402

TracedModelTypes = Union[
torch.nn.Module,
Expand Down Expand Up @@ -205,18 +222,18 @@ def _compatible_inputs(self) -> tuple[Tensor, ...]:
inputs["token_type_ids"] = torch.zeros(
inputs["input_ids"].size(1), dtype=torch.long
)
if isinstance(
if _is_tokenizer_type(
self._tokenizer,
(
transformers.BartTokenizer,
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
"BartTokenizer",
"MPNetTokenizer",
"RobertaTokenizer",
"XLMRobertaTokenizer",
),
):
return (inputs["input_ids"], inputs["attention_mask"])

if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
if _is_tokenizer_type(self._tokenizer, "DebertaV2Tokenizer"):
return (
inputs["input_ids"],
inputs["attention_mask"],
Expand Down Expand Up @@ -399,9 +416,9 @@ def __init__(
)

# check for a supported tokenizer
if not isinstance(self._tokenizer, SUPPORTED_TOKENIZERS):
if not _is_tokenizer_type(self._tokenizer, _SUPPORTED_TOKENIZER_NAMES):
raise TypeError(
f"Tokenizer type {self._tokenizer} not supported, must be one of: {SUPPORTED_TOKENIZERS_NAMES}"
f"Tokenizer type {self._tokenizer.__class__.__name__} not supported, must be one of: {SUPPORTED_TOKENIZERS_NAMES}"
)

self._traceable_model = self._create_traceable_model()
Expand All @@ -423,10 +440,43 @@ def _load_vocab(self) -> dict[str, list[str]]:
]
vocab_obj["merges"] = merges

if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
sp_model = self._tokenizer._tokenizer.spm
else:
sp_model = getattr(self._tokenizer, "sp_model", None)
# Try to extract scores from sentencepiece model or fast tokenizer backend
scores = self._extract_vocab_scores(vocabulary)
if scores:
vocab_obj["scores"] = scores

return vocab_obj

def _extract_vocab_scores(self, vocabulary: list[str]) -> list[float] | None:
"""
Extract vocabulary scores from the tokenizer.

For slow tokenizers: uses sp_model.get_score()
For fast tokenizers: extracts from the tokenizers library JSON backend
"""
# First, try to get scores from fast tokenizer backend (transformers v5)
if getattr(self._tokenizer, "is_fast", False):
backend = getattr(self._tokenizer, "_tokenizer", None)
if backend is not None:
try:
import json

tok_json = json.loads(backend.to_str())
model_info = tok_json.get("model", {})
# Unigram models store vocab as [[token, score], ...]
if model_info.get("type") == "Unigram":
backend_vocab = model_info.get("vocab", [])
if backend_vocab:
# Build a token->score map from backend vocab
score_map = {token: score for token, score in backend_vocab}
# Return scores in the same order as vocabulary
scores = [score_map.get(token, 0.0) for token in vocabulary]
return scores
except Exception:
pass # Fall through to sp_model approach

# Try legacy sp_model approach for slow tokenizers
sp_model = getattr(self._tokenizer, "sp_model", None)

if sp_model:
id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
Expand All @@ -438,34 +488,32 @@ def _load_vocab(self) -> dict[str, list[str]]:
scores.append(sp_model.get_score(token_id - id_correction))
except IndexError:
scores.append(0.0)
pass
if len(scores) > 0:
vocab_obj["scores"] = scores
return vocab_obj
if scores:
return scores

return None

def _create_tokenization_config(self) -> NlpTokenizationConfig:
if self._max_model_input_size:
_max_sequence_length = self._max_model_input_size
else:
_max_sequence_length = self._find_max_sequence_length()

if isinstance(self._tokenizer, transformers.MPNetTokenizer):
if _is_tokenizer_type(self._tokenizer, "MPNetTokenizer"):
return NlpMPNetTokenizationConfig(
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
max_sequence_length=_max_sequence_length,
)
elif isinstance(
self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
):
elif _is_tokenizer_type(self._tokenizer, ("RobertaTokenizer", "BartTokenizer")):
return NlpRobertaTokenizationConfig(
add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
max_sequence_length=_max_sequence_length,
)
elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
elif _is_tokenizer_type(self._tokenizer, "XLMRobertaTokenizer"):
return NlpXLMRobertaTokenizationConfig(
max_sequence_length=_max_sequence_length
)
elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
elif _is_tokenizer_type(self._tokenizer, "DebertaV2Tokenizer"):
return NlpDebertaV2TokenizationConfig(
max_sequence_length=_max_sequence_length,
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
Expand Down Expand Up @@ -509,7 +557,7 @@ def _find_max_sequence_length(self) -> int:
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
return int(max_len)

if isinstance(self._tokenizer, BertTokenizer):
if _is_tokenizer_type(self._tokenizer, "BertTokenizer"):
return 512

raise UnknownModelInputSizeError("Cannot determine model max input length")
Expand Down Expand Up @@ -692,18 +740,25 @@ def _make_inputs_compatible(
inputs["token_type_ids"] = torch.zeros(
inputs["input_ids"].size(1), dtype=torch.long
)
if isinstance(
if _is_tokenizer_type(
self._tokenizer,
(
transformers.BartTokenizer,
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
"BartTokenizer",
"MPNetTokenizer",
"RobertaTokenizer",
"XLMRobertaTokenizer",
),
):
del inputs["token_type_ids"]
return (inputs["input_ids"], inputs["attention_mask"])

if _is_tokenizer_type(self._tokenizer, "DebertaV2Tokenizer"):
return (
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
)

position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
inputs["position_ids"] = position_ids
return (
Expand All @@ -716,7 +771,7 @@ def _make_inputs_compatible(
def _create_traceable_model(self) -> _TransformerTraceableModel:
if self._task_type == "auto":
model = transformers.AutoModel.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
maybe_task_type = task_type_from_model_config(model.config)
if maybe_task_type is None:
Expand All @@ -728,28 +783,28 @@ def _create_traceable_model(self) -> _TransformerTraceableModel:

if self._task_type == "text_expansion":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextExpansionModel(self._tokenizer, model)

if self._task_type == "fill_mask":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableFillMaskModel(self._tokenizer, model)

elif self._task_type == "ner":
model = transformers.AutoModelForTokenClassification.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableNerModel(self._tokenizer, model)

elif self._task_type == "text_classification":
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextClassificationModel(self._tokenizer, model)
Expand All @@ -766,7 +821,7 @@ def _create_traceable_model(self) -> _TransformerTraceableModel:

elif self._task_type == "zero_shot_classification":
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableZeroShotClassificationModel(self._tokenizer, model)
Expand All @@ -779,14 +834,14 @@ def _create_traceable_model(self) -> _TransformerTraceableModel:

elif self._task_type == "text_similarity":
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextSimilarityModel(self._tokenizer, model)

elif self._task_type == "pass_through":
model = transformers.AutoModel.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
self._model_id, token=self._access_token
)
return _TraceablePassThroughModel(self._tokenizer, model)

Expand Down
Loading