Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,33 @@ def __init__(self, settings: Settings):
print()
print(f"Loading model [bold]{settings.model}[/]...")

self.tokenizer = AutoTokenizer.from_pretrained(
# Load the raw configuration dictionary to detect the model type.
# Some upstream repositories declare an incorrect tokenizer class in their
# config metadata, which produces corrupted (space-stripped) tokens; knowing
# the model type lets us select the correct tokenizer class explicitly.
config_dict, _ = PretrainedConfig.get_config_dict(
settings.model,
trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
)

tokenizer_kwargs = {
"trust_remote_code": settings.trust_remote_code,
**self.revision_kwargs,
}

if config_dict.get("model_type") == "qwen2":
from transformers import Qwen2TokenizerFast # ty:ignore[unresolved-import]

self.tokenizer = Qwen2TokenizerFast.from_pretrained(
settings.model,
**tokenizer_kwargs,
Comment on lines +90 to +95

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Checking only for "qwen2" misses other Qwen2-family model types (such as "qwen2_moe", "qwen2_vl", or "qwen2_audio"), which also use the Qwen2 tokenizer and could suffer from the same upstream configuration issue. Matching any model_type that starts with "qwen2" is more robust and future-proof.

Suggested change
if config_dict.get("model_type") == "qwen2":
from transformers import Qwen2TokenizerFast # ty:ignore[unresolved-import]
self.tokenizer = Qwen2TokenizerFast.from_pretrained(
settings.model,
**tokenizer_kwargs,
model_type = config_dict.get("model_type")
if isinstance(model_type, str) and model_type.startswith("qwen2"):
from transformers import Qwen2TokenizerFast
self.tokenizer = Qwen2TokenizerFast.from_pretrained(
settings.model,
**tokenizer_kwargs,
)

)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
**tokenizer_kwargs,
)

# Multimodal models have a processor we'll want to save.
self.processor = None
if get_model_class(settings.model) == AutoModelForImageTextToText:
Expand Down