diff --git a/misaki/data/gb_gold.json b/misaki/data/gb_gold.json index 8c7de92..f34f4f2 100644 --- a/misaki/data/gb_gold.json +++ b/misaki/data/gb_gold.json @@ -66252,7 +66252,7 @@ "DEFAULT": "ɹˈiːd", "VBD": "ɹˈɛd", "VBN": "ɹˈɛd", - "VBP": "ɹˈɛd" + "VBP": "ɹˈiːd" }, "read's": "ɹˈiːdz", "read-in": "ɹˈiːdɪn", @@ -68045,7 +68045,7 @@ "NOUN": "ɹˈiːɹiːd", "VBD": "ɹiːɹˈɛd", "VBN": "ɹiːɹˈɛd", - "VBP": "ɹiːɹˈɛd" + "VBP": "ɹiːɹˈiːd" }, "reread's": "ɹˈiːɹiːdz", "rereading": "ɹiːɹˈiːdɪŋ", @@ -89069,7 +89069,7 @@ "DEFAULT": "wˈuːnd", "VBD": "wˈWnd", "VBN": "wˈWnd", - "VBP": "wˈWnd" + "VBP": "wˈuːnd" }, "wound's": "wˈuːndz", "wounded": "wˈuːndɪd", diff --git a/misaki/data/us_gold.json b/misaki/data/us_gold.json index 8ef4bdc..381933e 100644 --- a/misaki/data/us_gold.json +++ b/misaki/data/us_gold.json @@ -68102,7 +68102,7 @@ "DEFAULT": "ɹˈid", "VBD": "ɹˈɛd", "VBN": "ɹˈɛd", - "VBP": "ɹˈɛd" + "VBP": "ɹˈid" }, "read's": "ɹˈidz", "read-in": "ɹˈidˌɪn", @@ -69997,7 +69997,7 @@ "NOUN": "ɹˈiɹid", "VBD": "ɹiɹˈɛd", "VBN": "ɹiɹˈɛd", - "VBP": "ɹiɹˈɛd" + "VBP": "ɹˌiɹˈid" }, "reread's": "ɹˈiɹidz", "rereading": "ɹˌiɹˈidɪŋ", @@ -91908,7 +91908,7 @@ "DEFAULT": "wˈund", "VBD": "wˈWnd", "VBN": "wˈWnd", - "VBP": "wˈWnd" + "VBP": "wˈund" }, "wound's": "wˈundz", "wounded": "wˈundᵻd", diff --git a/misaki/en.py b/misaki/en.py index 222c170..4602f40 100644 --- a/misaki/en.py +++ b/misaki/en.py @@ -1,4 +1,5 @@ from . import data +import os from .token import MToken from dataclasses import dataclass, replace from num2words import num2words @@ -25,7 +26,7 @@ def merge_tokens(tokens: List[MToken], unk: Optional[str] = None) -> MToken: phonemes += ' ' phonemes += unk if tk.phonemes is None else tk.phonemes return MToken( - text=''.join(tk.text + tk.whitespace for tk in tokens[:-1]) + tokens[-1].text, + text=(''.join(tk.text + tk.whitespace for tk in tokens[:-1]) + tokens[-1].text).strip(), tag=max(tokens, key=lambda tk: sum(1 if c == c.lower() else 2 for c in tk.text)).tag, whitespace=tokens[-1].whitespace, phonemes=phonemes, @@ -495,10 +496,19 @@ def __call__(self, tk, ctx): return None, None class FallbackNetwork: - def __init__(self, british): + def __init__(self, british, local_files_only=None): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Respect offline env vars if not explicitly set + if local_files_only is None: + local_files_only = ( + os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" or + os.environ.get("HF_HUB_OFFLINE", "0") == "1" + ) + self.model = BartForConditionalGeneration.from_pretrained( - "PeterReid/graphemes_to_phonemes_en_" + ("gb" if british else "us")) + "PeterReid/graphemes_to_phonemes_en_" + ("gb" if british else "us"), + local_files_only=local_files_only) self.model.to(self.device) self.model.eval() self.grapheme_to_token = {g: i for i, g in enumerate(self.model.config.grapheme_chars)} @@ -519,7 +529,7 @@ def __call__(self, input_token): return (output_text, 1) class G2P: - def __init__(self, version=None, trf=False, british=False, fallback=None, unk='❓'): + def __init__(self, version=None, trf=False, british=False, fallback=None, unk='❓', local_files_only=None): self.version = version self.british = british name = f"en_core_web_{'trf' if trf else 'sm'}" @@ -528,7 +538,7 @@ def __init__(self, version=None, trf=False, british=False, fallback=None, unk=' components = ['transformer' if trf else 'tok2vec', 'tagger'] self.nlp = spacy.load(name, enable=components) self.lexicon = Lexicon(british) - self.fallback = fallback if fallback else FallbackNetwork(british) + self.fallback = fallback if fallback else FallbackNetwork(british, local_files_only=local_files_only) self.unk = unk @staticmethod diff --git a/pyproject.toml b/pyproject.toml index e0ee5aa..30be697 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ ] [project.optional-dependencies] -en = ["num2words", "spacy", "spacy-curated-transformers", "phonemizer-fork", "espeakng-loader", "torch", "transformers"] +en = ["num2words", "spacy<4", "spacy-curated-transformers", "phonemizer-fork", "espeakng-loader", "torch", "transformers"] ja = ["fugashi", "jaconv", "mojimoji", "unidic", "pyopenjtalk"] ko = ["jamo", "nltk"] zh = ["jieba", "ordered-set", "pypinyin", "cn2an", "pypinyin-dict"]