Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion misaki/en.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import re
import spacy
import unicodedata
from transformers import BartForConditionalGeneration
import torch

def merge_tokens(tokens: List[MToken], unk: Optional[str] = None) -> MToken:
stress = {tk._.stress for tk in tokens if tk._.stress is not None}
Expand Down Expand Up @@ -492,6 +494,30 @@ def __call__(self, tk, ctx):
# return apply_stress(self.append_currency(ps, tk._.currency), tk._.stress), rating
return None, None

class FallbackNetwork:
    """Neural grapheme-to-phoneme converter backed by a pretrained BART model.

    An instance is callable: given a token object exposing a ``.text`` string,
    it encodes the characters as ids, runs the model's ``generate`` and decodes
    the generated ids into a phoneme string.
    """

    # Special token ids, named after how this class uses them: START_ID is
    # prepended and END_ID appended to every encoded word, and UNK_ID stands
    # in for any character missing from the grapheme vocabulary.
    START_ID = 1
    END_ID = 2
    UNK_ID = 3
    # Generated ids up to and including this value are dropped when decoding.
    # NOTE(review): id 0 is presumably padding -- confirm against the model.
    LAST_SPECIAL_ID = 3

    def __init__(self, british):
        """Load the pretrained checkpoint and build the lookup tables.

        Args:
            british: If truthy, load the "gb" checkpoint; otherwise "us".
        """
        # Run on the GPU when one is available, else on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        variant = "gb" if british else "us"
        self.model = BartForConditionalGeneration.from_pretrained(
            "PeterReid/graphemes_to_phonemes_en_" + variant)
        self.model.to(self.device)
        self.model.eval()
        # The checkpoint's config carries both vocabularies; an entry's
        # position in each sequence is its token id.
        config = self.model.config
        self.grapheme_to_token = {g: i for i, g in enumerate(config.grapheme_chars)}
        self.token_to_phoneme = dict(enumerate(config.phoneme_chars))

    def graphemes_to_tokens(self, graphemes) -> list:
        """Encode an iterable of characters as a list of token ids.

        The result is wrapped in START_ID / END_ID, and characters missing
        from ``self.grapheme_to_token`` are encoded as UNK_ID.
        """
        lookup = self.grapheme_to_token.get
        encoded = [lookup(g, self.UNK_ID) for g in graphemes]
        return [self.START_ID, *encoded, self.END_ID]

    def tokens_to_phonemes(self, tokens) -> str:
        """Decode token ids into a phoneme string.

        Special ids (<= LAST_SPECIAL_ID) are skipped, and ids missing from
        ``self.token_to_phoneme`` contribute nothing to the output.
        """
        lookup = self.token_to_phoneme.get
        return "".join(
            lookup(t, '') for t in tokens if t > self.LAST_SPECIAL_ID)

    def __call__(self, input_token) -> tuple:
        """Predict phonemes for ``input_token.text``.

        Returns:
            A ``(phonemes, 1)`` tuple. The second element is always 1;
            presumably it is the rating callers expect -- confirm.
        """
        # A batch holding a single sequence.
        input_ids = torch.tensor(
            [self.graphemes_to_tokens(input_token.text)], device=self.device)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            generated_ids = self.model.generate(input_ids=input_ids)
        return self.tokens_to_phonemes(generated_ids[0].tolist()), 1

class G2P:
def __init__(self, version=None, trf=False, british=False, fallback=None, unk='❓'):
self.version = version
Expand All @@ -502,7 +528,7 @@ def __init__(self, version=None, trf=False, british=False, fallback=None, unk='
components = ['transformer' if trf else 'tok2vec', 'tagger']
self.nlp = spacy.load(name, enable=components)
self.lexicon = Lexicon(british)
self.fallback = fallback if fallback else None
self.fallback = fallback if fallback else FallbackNetwork(british)
self.unk = unk

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
requires-python = ">=3.8, <3.13"

requires-python = ">=3.9, <3.14"
dependencies = [
"addict",
"pip>=25.0.1",
# ^ spaCy needs pip to be available inside the Python environment; uv does not install pip by default
"regex",
]

[project.optional-dependencies]
en = ["num2words", "spacy", "spacy-curated-transformers", "phonemizer-fork", "espeakng-loader"]
en = ["num2words", "spacy", "spacy-curated-transformers", "phonemizer-fork", "espeakng-loader", "torch", "transformers"]
ja = ["fugashi", "jaconv", "mojimoji", "unidic", "pyopenjtalk"]
ko = ["jamo", "nltk"]
zh = ["jieba", "ordered-set", "pypinyin", "cn2an", "pypinyin-dict"]
vi = ["num2words", "spacy", "spacy-curated-transformers", "underthesea"]
he = ["mishkal-hebrew>=0.3.2"]

[build-system]
requires = ["hatchling"]
Expand Down
Loading