Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/test_bugfix_generate_flush_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import importlib
import importlib.util
import os
import unittest
from unittest import mock


class GenerateFlushBugfixOptionalTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
enabled = os.getenv("VOXMLX_ENABLE_MLX_RUNTIME_TESTS", "").strip().lower() in {
"1",
"true",
"yes",
"on",
}
if not enabled:
raise unittest.SkipTest(
"Set VOXMLX_ENABLE_MLX_RUNTIME_TESTS=1 to run MLX runtime optional tests"
)
if importlib.util.find_spec("mlx") is None:
raise unittest.SkipTest("mlx is required for runtime optional tests")

def test_generate_flushes_last_pending_non_eos_token(self):
import mlx.core as mx
import numpy as np

generate_module = importlib.import_module("voxmlx.generate")

class _FakeLanguageModel:
def __init__(self, hidden_dim: int, vocab_size: int):
self.layers = [object()]
self._hidden_dim = hidden_dim
self._vocab_size = vocab_size

def embed(self, input_ids):
bsz, seq = input_ids.shape
return mx.zeros((bsz, seq, self._hidden_dim), dtype=mx.float32)

class _FakeModel:
def __init__(self):
self.hidden_dim = 8
self.vocab_size = 128
self.n_audio = 4 # prefix=2, loop runs 2 steps
self._decode_tokens = [11, 12, 13] # prefill + 2 steps
self._decode_calls = 0
self.language_model = _FakeLanguageModel(self.hidden_dim, self.vocab_size)

def time_embedding(self, t):
return mx.zeros((1, self.hidden_dim), dtype=mx.float32)

def encode(self, mel):
del mel
return mx.zeros((self.n_audio, self.hidden_dim), dtype=mx.float32)

def decode(self, embeddings, t_cond, mask=None, cache=None):
del t_cond, mask, cache
token = self._decode_tokens[self._decode_calls]
self._decode_calls += 1
seq = embeddings.shape[1]
logits = np.full((1, seq, self.vocab_size), -1000.0, dtype=np.float32)
logits[:, :, token] = 1000.0
return mx.array(logits)

model = _FakeModel()
with mock.patch.object(generate_module, "load_audio", return_value=np.zeros(1, dtype=np.float32)), mock.patch.object(
generate_module, "pad_audio", side_effect=lambda x: x
), mock.patch.object(
generate_module, "log_mel_spectrogram", return_value=mx.zeros((128, 8), dtype=mx.float32)
):
out = generate_module.generate(
model=model,
audio_path="dummy.wav",
prompt_tokens=[1, 2],
n_delay_tokens=0,
temperature=0.0,
eos_token_id=99,
sliding_window=16,
)

self.assertEqual(out, [11, 12, 13])


if __name__ == "__main__":
unittest.main()
7 changes: 4 additions & 3 deletions voxmlx/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def step(token, pos):

y = next_y

# Check the last token
if output_tokens and output_tokens[-1] == eos_token_id:
output_tokens = output_tokens[:-1]
# Flush last pending token when loop ends without EOS.
token_id = y.item()
if token_id != eos_token_id:
output_tokens.append(token_id)

return output_tokens