
Add a transformer-based default English fallback #74

Merged: hexgrad merged 1 commit into hexgrad:main from PeterReid:main on Apr 14, 2025

@PeterReid
Contributor

I trained some BartForConditionalGeneration models from the transformers package to serve as a default fallback. This potentially drops the need for espeak for out-of-domain words. I used us_gold and gb_gold as the basis for the training data.

The main tweak I made to the training data was adding regular plurals, because they do not seem to appear in the dictionaries. I scraped pluralized words from wiki.train.raw and derived their pronunciations the same way misaki does when it sees a plural.
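
As a rough sketch of that plural rule (an illustration, not misaki's actual code): the ending is chosen from the last phoneme of the singular form. The function name `pluralize_phonemes` and the sample phoneme strings are made up for this example.

```python
VOICELESS = set("ptkfθ")    # stems ending in these phonemes take /s/
SIBILANTS = set("szʃʒʧʤ")   # stems ending in these take a reduced vowel plus /z/

def pluralize_phonemes(stem_phonemes: str, british: bool = False) -> str:
    """Append the regular English plural ending to a phoneme string."""
    last = stem_phonemes[-1]
    if last in VOICELESS:
        return stem_phonemes + "s"
    if last in SIBILANTS:
        # The reduced vowel differs between the British and American phoneme sets.
        return stem_phonemes + ("ɪ" if british else "ᵻ") + "z"
    return stem_phonemes + "z"

print(pluralize_phonemes("kˈæt"))                # ends in a voiceless stop
print(pluralize_phonemes("dˈɔɡ"))                # ends in a voiced consonant
print(pluralize_phonemes("ʤˈʌʤ", british=True))  # ends in a sibilant
```

The key point is that the test is on the last phoneme, not the last letter, so a word spelled with a final silent "e" still gets the right ending.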

Here are some sample lines from the poem Jabberwocky, with the G2P output from the new fallback.

"Beware the Jubjub bird, and shun
The frumious Bandersnatch!"

frumious -> fɹˈuːmɪəs
Jubjub -> ʤˈʌbʤʌb

twas_brillig_en.mp4

"Twas brillig, and the slithy toves
Did gyre and gimble in the wabe;"

wabe -> wˈAb
gimble -> ɡˈɪmbᵊl
toves -> tˈOvz
brillig -> bɹˈɪlɪɡ

jubjub_gb.mp4

Here's my training code. It is not exactly cleaned up, but it at least shows what I did.

import os
import json
from datasets import Dataset
from transformers import (
    BartForConditionalGeneration,
    BartConfig,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    PreTrainedTokenizer,
)
import numpy as np
import sys
import torch
import re

british = True
if british:
    dataset_path = "/home/peter/phonemizer/gb_gold.json"
    model_dir = "./en_gb_model"
else:
    dataset_path = "/home/peter/phonemizer/us_gold.json" 
    model_dir = "./en_us_model"
# -------------------------------
# 1. Define the training data
# -------------------------------
# The training data is a list of (source, target) pairs:
# (English word, phoneme string).
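def load_dataset(path):
    with open(path, encoding="utf-8") as f:
        entries = json.load(f)
    dataset = []
    for english, phonemes in entries.items():
        # Keep only entries with a single pronunciation string; skip any other value type.
        if isinstance(phonemes, str):
            dataset.append((english, phonemes))
    return dataset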

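def get_plural_phonemes(word, phoneme_lookup):
    if len(word) < 3 or not word.endswith('s'):
        return None
    # Try each candidate stem in turn and keep the first one found in the dictionary.
    # Without the lookup in each condition, the first branch always wins and the
    # "-es" and "-ies" branches can never be reached.
    if not word.endswith('ss') and word[:-1] in phoneme_lookup:
        stem = word[:-1]
    elif (word.endswith("'s") or (len(word) > 4 and word.endswith('es') and not word.endswith('ies'))) and word[:-2] in phoneme_lookup:
        stem = word[:-2]
    elif len(word) > 4 and word.endswith('ies') and word[:-3] + 'y' in phoneme_lookup:
        stem = word[:-3] + 'y'
    else:
        return None
    stem_phonemes = phoneme_lookup[stem]
    if not stem_phonemes:
        return None

    # The ending depends on the last phoneme of the stem, not on its last letter.
    if stem_phonemes[-1] in 'ptkfθ':
        return stem_phonemes + 's'
    elif stem_phonemes[-1] in 'szʃʒʧʤ':
        return stem_phonemes + ('ɪ' if british else 'ᵻ') + 'z'
    return stem_phonemes + 'z'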

# The dataset does not have regular plural versions of words, but the model does need
# to be able to handle them. To get examples of plural words that are actually used,
# we ingest a large text file and check if each word is a plural of a known-good word.
# If so, we can derive its pronunciation by following a simple rule.
# (There are some plural versions in the dataset, but they are mostly the unusual ones.)
def augment_dataset(dataset, text_file):
    phoneme_lookup = {graphemes: phonemes for (graphemes, phonemes) in dataset}
    added = set()
    extra_dataset = []
    with open(text_file, encoding="utf-8") as f:
        for line in f:
            words = re.findall(r'\b\w+\b', line)
            for word in words:
                word = word.lower()
                if word in phoneme_lookup or word in added: continue

                plural_phonemes = get_plural_phonemes(word, phoneme_lookup)
                if plural_phonemes is not None:
                    extra_dataset.append((word, plural_phonemes))
                    added.add(word)
    return extra_dataset

def vocab_for_data(dataset):
    grapheme_chars = set()
    phoneme_chars = set()
    for graphemes, phonemes in dataset:
        grapheme_chars.update(list(graphemes))
        phoneme_chars.update(list(phonemes))
    print(grapheme_chars)
    print(phoneme_chars)
    # To have a clear char -> token id mapping, the characters that occur
    # as both graphemes and phonemes will be identically ordered at the
    # start for both.
    chars_in_common = set(grapheme_chars).intersection(set(phoneme_chars))
    only_grapheme = set(grapheme_chars).difference(chars_in_common)
    only_phoneme = set(phoneme_chars).difference(chars_in_common)

    grapheme_chars = sorted(list(chars_in_common)) + sorted(list(only_grapheme))
    phoneme_chars = sorted(list(chars_in_common)) + sorted(list(only_phoneme))
    return grapheme_chars, phoneme_chars

# -------------------------------
# 2. Build a character-level tokenizer
# -------------------------------
class CharTokenizer(PreTrainedTokenizer):
    """
    A very simple character-level tokenizer.
    """
    def __init__(self, grapheme_chars, phoneme_chars):
        # Reserve the first few indices for special tokens:
        #   <pad>: 0, <s>: 1, </s>: 2, <unk>: 3.
        special_vocab = ["<pad>", "<s>", "</s>", "<unk>"]
        grapheme_vocab_list = special_vocab + grapheme_chars
        phoneme_vocab_list = special_vocab + phoneme_chars
        print(grapheme_vocab_list)
        print(phoneme_vocab_list)

        # Characters shared by graphemes and phonemes have the same id in both
        # lists, so merging the two mappings is consistent for them.
        vocab = {token: idx for idx, token in enumerate(grapheme_vocab_list)}
        vocab.update({token: idx for idx, token in enumerate(phoneme_vocab_list)})
        self.vocab = vocab
        # Build inverse mappings from id to token.
        self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
        self.ids_to_phonemes = {i: t for i, t in enumerate(phoneme_vocab_list)}
        self.ids_to_graphemes = {i: t for i, t in enumerate(grapheme_vocab_list)}
        # The vocab must exist before the base class is initialized. Passing the
        # special tokens to the base class keeps them set regardless of how a given
        # transformers version initializes its special-token attributes.
        super().__init__(bos_token="<s>", eos_token="</s>", pad_token="<pad>", unk_token="<unk>")
    
    def _tokenize(self, text):
        # Split text into characters.
        return list(text)
    
    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab.get("<unk>"))
    
    def convert_tokens_to_string(self, tokens):
        # Rejoin tokens as a string.
        return "".join(tokens)
   
    def decode_phonemes(self, tokens):
        return "".join([self.ids_to_phonemes.get(token, "<unk>") for token in tokens.tolist() if token > 3])

    def myencode(self, text):
        return [1] + [self.vocab.get(c, 3) for c in text] + [2]

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # For a single sentence, add bos at beginning and eos at end.
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        # For sequence pair, you can define your own method.
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
    
    @property
    def vocab_size(self):
        return 4 + max(len(self.ids_to_phonemes), len(self.ids_to_graphemes))
    def get_vocab(self):
        return self.vocab
    def save_vocabulary(self, save_directory, filename_prefix=None):
        return ()


# -------------------------------
# 3. Encode the training examples
# -------------------------------
def encode_example(example):
    # Each example is a tuple: (input_text, target_text)
    source, target = example
    # Encode input and target with special tokens
    source_ids = tokenizer.myencode(source)
    target_ids = tokenizer.myencode(target)
    return {"input_ids": source_ids, "labels": target_ids}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Phonemizer:
    def __init__(self):
        self.model = BartForConditionalGeneration.from_pretrained(model_dir)
        # grapheme_chars and phoneme_chars start with "____" for the convenience of the runtime version of this.
        # Slice off exactly those four placeholders; lstrip('_') would also drop a real leading '_' character.
        self.tokenizer = CharTokenizer(list(self.model.config.grapheme_chars[4:]), list(self.model.config.phoneme_chars[4:]))
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total parameters: {total_params:,}")
        self.model.to(device)
        self.model.eval()

    def run(self, input_text):
        print("Running on", input_text)
        input_ids = torch.tensor([self.tokenizer.myencode(input_text)], device = device)

        # === 4. Generate output ===
        with torch.no_grad():
            generated_ids = self.model.generate(input_ids = input_ids)

        output_text = self.tokenizer.decode_phonemes(generated_ids[0])
        return output_text
# -------------------------------
# 5. Prepare training arguments and trainer
# -------------------------------
mode = sys.argv[1]
if mode == 'train':
    training_data = load_dataset(dataset_path)
    unstemmed_dataset = augment_dataset(training_data, '../wiki.train.raw')

    training_data += unstemmed_dataset
    grapheme_chars, phoneme_chars = vocab_for_data(training_data)
    print(grapheme_chars, phoneme_chars)
    print(training_data[-20:])
    tokenizer = CharTokenizer(grapheme_chars, phoneme_chars)
    # Tokenize all training samples
    encoded_data = [encode_example(pair) for pair in training_data]
    dataset = Dataset.from_list(encoded_data)

    # -------------------------------
    # 4. Define the model
    # -------------------------------
    # We use BartForConditionalGeneration from scratch.
    # Set a small configuration appropriate for the task.
    vocab_size = tokenizer.vocab_size
    config = BartConfig(
        vocab_size=vocab_size,
        d_model=128,                # hidden size
        encoder_layers=1,           # number of encoder layers
        decoder_layers=1,           # number of decoder layers
        encoder_attention_heads=1,  # attention heads in encoder
        decoder_attention_heads=1,  # attention heads in decoder
        decoder_start_token_id=tokenizer.bos_token_id,
        max_position_embeddings = 64,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    config.grapheme_chars = "____" + "".join(grapheme_chars)
    config.phoneme_chars = "____" + "".join(phoneme_chars)
    model = BartForConditionalGeneration(config)

    training_args = Seq2SeqTrainingArguments(
        output_dir=model_dir,
        num_train_epochs=100,
        per_device_train_batch_size=32,
        learning_rate=5e-4,
        logging_steps=10,
        #evaluation_strategy="no",  # change to "steps" and add eval_dataset for evaluation
        predict_with_generate=True,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # -------------------------------
    # 6. Start training!
    # -------------------------------
    trainer.train()

    # Save the final model (the character vocabularies are stored in its config)
    model.save_pretrained(model_dir)
    
elif mode == 'run':
    # === 2. Load the model ===
    p = Phonemizer()

    # === 3. Prepare input ===
    input_text = sys.argv[2]
    output_text = p.run(input_text)
    print(f"Phonemes for '{input_text}':", output_text)
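
The character-level scheme used above can be sketched without transformers. This is a simplified illustration under my own naming (`build_vocabs`, `encode`, `decode` are hypothetical helpers, and the two word pairs are toy data): characters that occur as both graphemes and phonemes share one id, and ids 0 to 3 are reserved for the special tokens.

```python
SPECIALS = ["<pad>", "<s>", "</s>", "<unk>"]  # ids 0..3

def build_vocabs(pairs):
    """Return (grapheme_list, phoneme_list); shared characters come first in both."""
    graphemes = set("".join(g for g, _ in pairs))
    phonemes = set("".join(p for _, p in pairs))
    shared = sorted(graphemes & phonemes)
    grapheme_list = SPECIALS + shared + sorted(graphemes - phonemes)
    phoneme_list = SPECIALS + shared + sorted(phonemes - graphemes)
    return grapheme_list, phoneme_list

def encode(text, vocab_list):
    """Map characters to ids, wrapped in <s> (1) and </s> (2); unknowns become <unk> (3)."""
    index = {ch: i for i, ch in enumerate(vocab_list)}
    return [1] + [index.get(ch, 3) for ch in text] + [2]

def decode(ids, vocab_list):
    """Map ids back to characters, dropping the special tokens."""
    return "".join(vocab_list[i] for i in ids if i > 3)

pairs = [("cat", "kˈæt"), ("tab", "tˈæb")]  # toy data for illustration only
g_list, p_list = build_vocabs(pairs)
ids = encode("kˈæt", p_list)
print(ids, decode(ids, p_list))
```

Because the shared characters sit at the same positions in both lists, a single embedding table can serve the encoder and the decoder without an id meaning two different things for those characters.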

@PeterReid
Contributor Author

Something seems wrong with the en_us one I uploaded to huggingface... I will try training that again soon.

@hexgrad hexgrad mentioned this pull request Apr 13, 2025
@hexgrad
Owner

hexgrad commented Apr 13, 2025

Thanks for the PR! @PeterReid

A fallback model has long been on the bucket list, and I am impressed with the size being <10 MB each.

The dictionary files are periodically patched when errors are found. I have pushed the latest bump to #75 (still a draft PR, may need verification) and also uploaded the latest dictionaries just now to https://huggingface.co/datasets/hexgrad/misaki so if you are retraining, consider grabbing the latest dictionaries off HF.

Edit: Consider also using the silver dictionaries, either for training, validation, or both?

@PeterReid
Contributor Author

PeterReid commented Apr 14, 2025

Thanks for the advice, and I'm glad this seems like the right direction to you!

I have published updated models to Hugging Face. I updated to the newer dictionaries from #75 and used 90% of the silver dictionaries as additional training data and the remaining 10% for evaluation. After a lot of experimenting with the training parameters, I got the model down to 3 MB, and it performs better than before (based on eval loss and its reading of that poem).
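
The comment does not say how the 90/10 split was made. One simple way to do it, sketched here with a hypothetical `split_entries` helper and placeholder data, is a seeded shuffle followed by a slice:

```python
import random

def split_entries(entries, eval_fraction=0.1, seed=0):
    """Shuffle (word, phonemes) pairs and split them into train and eval lists."""
    items = sorted(entries)             # fixed starting order so the split is reproducible
    random.Random(seed).shuffle(items)  # seeded shuffle, independent of global random state
    n_eval = int(len(items) * eval_fraction)
    return items[n_eval:], items[:n_eval]

# Placeholder stand-in for a silver dictionary; real entries would be loaded from JSON.
silver = {f"word{i}": f"phonemes{i}" for i in range(50)}
pairs = [(w, p) for w, p in silver.items() if isinstance(p, str)]
train, eval_set = split_entries(pairs)
print(len(train), len(eval_set))
```

Splitting by dictionary entry keeps every evaluation word out of the training set, so the eval loss reflects words the model has not seen.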

jabberwocky_us.mp4
jabberwocky_gb.mp4

@PeterReid PeterReid closed this Apr 14, 2025
@PeterReid PeterReid reopened this Apr 14, 2025
@hexgrad hexgrad merged commit 232a5e3 into hexgrad:main Apr 14, 2025