diff --git a/.gitignore b/.gitignore index 8f5aeb6..4c3bbf5 100644 --- a/.gitignore +++ b/.gitignore @@ -84,3 +84,9 @@ scripts/release_e2e/stage6_perf/last_run.json # Repo-local model-snapshot dir (RMLX_O_MODELS_ROOT fallback). Drop snapshots here. /models/ + +# Local audio test fixtures (real recordings + reference transcripts). The +# long-form regression test (crates/rmlx-audio/tests/transcribe.rs) scans this +# dir for `*.{m4a,wav,mp3,…}` paired with a sibling `*.transcript.vtt`. Drop +# your own audio + VTT here; nothing in it is tracked. +/crates/rmlx-audio/tests/fixtures/ diff --git a/Cargo.lock b/Cargo.lock index 5d0cc88..06eff8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1947,6 +1947,7 @@ dependencies = [ "dhat", "libc", "parking_lot", + "rmlx-audio", "rmlx-core", "rmlx-kv-quant", "rmlx-kv-ssd", @@ -2472,8 +2473,10 @@ checksum = "1758d6c853020a7244de03cc3e0185eaea3f58715122422dd3cc7452e6d4c16a" dependencies = [ "lazy_static", "symphonia-bundle-mp3", + "symphonia-codec-aac", "symphonia-codec-pcm", "symphonia-core", + "symphonia-format-isomp4", "symphonia-format-riff", "symphonia-metadata", ] @@ -2489,6 +2492,18 @@ dependencies = [ "symphonia-core", ] +[[package]] +name = "symphonia-codec-aac" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1979c515a76371b186aad2feff5f23e21cbec775bf95de08bf1e3af92a2ad76" +dependencies = [ + "lazy_static", + "log", + "symphonia-common", + "symphonia-core", +] + [[package]] name = "symphonia-codec-pcm" version = "0.6.0" @@ -2499,6 +2514,17 @@ dependencies = [ "symphonia-core", ] +[[package]] +name = "symphonia-common" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8257891ffa7f05e02b58f4761e2abf7e5278c8744fd59e981559e050f86eef55" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "symphonia-core" version = "0.6.0" @@ -2513,6 +2539,18 @@ dependencies = [ "smallvec", ] +[[package]] +name = "symphonia-format-isomp4" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d179a01305b3505940135a9f0180d6ef4b487912748fe97554756f120fbd05e" +dependencies = [ + "log", + "symphonia-common", + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "symphonia-format-riff" version = "0.6.0" diff --git a/Cargo.toml b/Cargo.toml index 2d30a38..23874c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,14 +123,17 @@ parking_lot = "0.12" # Default features fine (SIMD acceleration on aarch64 is always-on in 6.x). rustfft = "6.4" -# Pure-Rust audio decode for WAV + MP3. -# wav — RIFF/WAVE container format reader (symphonia-format-riff). -# pcm — PCM codec required to decode the audio inside WAV files -# (symphonia-codec-pcm; NOT included by wav alone). -# mp3 — MPEG Layer 3 codec (symphonia-bundle-mp3/mp3). -# default-features=false drops AIFF/OGG/FLAC/AAC/MKV/ADPCM from the default -# registry so only the three features we need are compiled in. -symphonia = { version = "0.6", default-features = false, features = ["wav", "pcm", "mp3"] } +# Pure-Rust audio decode for WAV + MP3 + AAC/MP4 (`.m4a`). +# wav — RIFF/WAVE container format reader (symphonia-format-riff). +# pcm — PCM codec required to decode the audio inside WAV files +# (symphonia-codec-pcm; NOT included by wav alone). +# mp3 — MPEG Layer 3 codec (symphonia-bundle-mp3/mp3). +# isomp4 — ISO-BMFF / MP4 container reader (needed for `.m4a`, `.mp4`). +# aac — AAC codec (the audio inside most `.m4a` recordings). +# These are all features of the existing `symphonia` dep — no new crate. They +# let `rmlx transcribe` / the audio endpoint ingest real meeting recordings +# (`.m4a`) directly. default-features=false still drops AIFF/OGG/FLAC/MKV/ADPCM. +symphonia = { version = "0.6", default-features = false, features = ["wav", "pcm", "mp3", "isomp4", "aac"] } # Testing tempfile = "3" diff --git a/crates/rmlx-audio/src/lib.rs b/crates/rmlx-audio/src/lib.rs index 16f4c35..96b92f7 100644 --- a/crates/rmlx-audio/src/lib.rs +++ b/crates/rmlx-audio/src/lib.rs @@ -53,10 +53,14 @@ pub mod mel; pub mod npz; pub mod tokenizer; +pub mod transcribe; pub mod tts; pub mod vad; pub mod wav; pub mod whisper; +pub use transcribe::{ + render, resample_to_16k, OutputFormat, Segment, TranscribeOptions, Transcriber, Transcription, +}; pub use vad::{voiced_segments, SileroVad, VadState}; pub use wav::{WavDecoder, WavEncoder}; diff --git a/crates/rmlx-audio/src/npz.rs b/crates/rmlx-audio/src/npz.rs index 1009d7a..3019af4 100644 --- a/crates/rmlx-audio/src/npz.rs +++ b/crates/rmlx-audio/src/npz.rs @@ -604,9 +604,11 @@ pub fn parse_npy_array(name: &str, data: &[u8]) -> Result { let raw = &data[header_end..]; let n_elems: usize = shape.iter().product(); let elem_bytes = match dtype { + Dtype::U8 => 1, Dtype::F16 => 2, - Dtype::F32 => 4, - _ => { + Dtype::F32 | Dtype::U32 | Dtype::I32 => 4, + Dtype::Bf16 => { + // extract_npy_dtype never yields Bf16, but the match must be total. return Err(npy_err(&format!("unsupported dtype {dtype:?}"))); } }; @@ -625,6 +627,12 @@ pub fn parse_npy_array(name: &str, data: &[u8]) -> Result { // ── NPY header field extractors (pub for tests) ─────────────────────────────── /// Extract the numpy dtype descriptor from a `.npy` header string. +/// +/// Whisper `weights.npz` carries float weights (`f2`/`f4`) plus a small +/// `alignment_heads` mask. That mask is stored as a NumPy boolean (`b1`) or +/// small-int array; older parsers hard-errored on it with +/// "cannot parse dtype". We map booleans / 1-byte ints to `U8` and 4-byte ints +/// to `U32`/`I32` so the mask loads instead of aborting the whole archive. pub fn extract_npy_dtype(header: &str) -> Option { let start = header.find("'descr'")?; let rest = &header[start + 7..]; @@ -635,6 +643,10 @@ pub fn extract_npy_dtype(header: &str) -> Option { match s { "f2" => Some(Dtype::F16), "f4" => Some(Dtype::F32), + // boolean + 1-byte ints → U8 (alignment_heads mask). + "b1" | "u1" | "i1" => Some(Dtype::U8), + "u4" => Some(Dtype::U32), + "i4" => Some(Dtype::I32), _ => None, } } diff --git a/crates/rmlx-audio/src/npz_tests.rs b/crates/rmlx-audio/src/npz_tests.rs index e99939f..3146feb 100644 --- a/crates/rmlx-audio/src/npz_tests.rs +++ b/crates/rmlx-audio/src/npz_tests.rs @@ -17,6 +17,19 @@ fn npy_dtype_f4() { assert_eq!(extract_npy_dtype(hdr), Some(Dtype::F32)); } +#[test] +fn npy_dtype_bool_and_ints() { + // alignment_heads ships as a boolean (`b1`) mask — must parse, not error. + let hdr_bool = "{'descr': '|b1', 'fortran_order': False, 'shape': (20, 32), }"; + assert_eq!(extract_npy_dtype(hdr_bool), Some(Dtype::U8)); + let hdr_u1 = "{'descr': '|u1', 'fortran_order': False, 'shape': (10,), }"; + assert_eq!(extract_npy_dtype(hdr_u1), Some(Dtype::U8)); + let hdr_i4 = "{'descr': '`=50259 … `<|yue|>`=50358), +//! one more than v1/v2 — this shifts every special after the language block up +//! by one vs the v2 layout. Getting these wrong made the decoder emit the wrong +//! task token and treat `<|notimestamps|>` as the timestamp-begin sentinel, +//! producing empty / garbage transcripts. //! //! | Token | ID | //! |---|---| //! | `<|endoftext|>` (eot) | 50 257 | //! | `<|startoftranscript|>` (sot) | 50 258 | //! | `<|en|>` | 50 259 | -//! | `<|translate|>` | 50 358 | -//! | `<|transcribe|>` | 50 359 | -//! | `<|nospeech|>` | 50 362 | -//! | `<|notimestamps|>` | 50 363 | -//! | `<|0.00|>` (timestamp_begin) | 50 364 | +//! | `<|yue|>` (last language) | 50 358 | +//! | `<|translate|>` | 50 359 | +//! | `<|transcribe|>` | 50 360 | +//! | `<|startoflm|>` | 50 361 | +//! | `<|startofprev|>` | 50 362 | +//! | `<|nospeech|>` | 50 363 | +//! | `<|notimestamps|>` | 50 364 | +//! | `<|0.00|>` (timestamp_begin) | 50 365 | //! -//! Total vocabulary: 51 866 tokens (50 257 base GPT-2 + 1 609 added). +//! Total vocabulary: 51 866 tokens (50 257 base GPT-2 + 1 609 added; the last +//! timestamp `<|30.00|>` is 51 865). //! //! ## Tokenizer loading //! @@ -49,16 +58,22 @@ pub const TOK_EOT: u32 = 50_257; pub const TOK_SOT: u32 = 50_258; /// English language token (`<|en|>`). pub const TOK_EN: u32 = 50_259; +/// Last language token (`<|yue|>`). large-v3 has 100 languages (50259..=50358). +pub const TOK_LANG_LAST: u32 = 50_358; /// Translate task token (`<|translate|>`). -pub const TOK_TRANSLATE: u32 = 50_358; +pub const TOK_TRANSLATE: u32 = 50_359; /// Transcribe task token (`<|transcribe|>`). -pub const TOK_TRANSCRIBE: u32 = 50_359; +pub const TOK_TRANSCRIBE: u32 = 50_360; +/// Start-of-LM token (`<|startoflm|>`). +pub const TOK_SOT_LM: u32 = 50_361; +/// Start-of-previous-context token (`<|startofprev|>`), for prompt conditioning. +pub const TOK_SOT_PREV: u32 = 50_362; /// No-speech token (`<|nospeech|>`). -pub const TOK_NOSPEECH: u32 = 50_362; +pub const TOK_NOSPEECH: u32 = 50_363; /// No-timestamps token (`<|notimestamps|>`). -pub const TOK_NO_TIMESTAMPS: u32 = 50_363; -/// First timestamp token `<|0.00|>`. -pub const TOK_TIMESTAMP_BEGIN: u32 = 50_364; +pub const TOK_NO_TIMESTAMPS: u32 = 50_364; +/// First timestamp token `<|0.00|>`. Timestamps run 50365..=51865 in 0.02 s steps. +pub const TOK_TIMESTAMP_BEGIN: u32 = 50_365; /// Language token IDs for the 99 supported languages. /// @@ -228,6 +243,70 @@ impl WhisperTokenizer { Ok(enc.get_ids().to_vec()) } + /// Derive the Whisper non-speech / special suppression token set from the + /// loaded tokenizer. + /// + /// This is a faithful port of openai-whisper's `Tokenizer.non_speech_tokens` + /// plus `_get_suppress_tokens`: the non-speech set is computed by BPE-encoding + /// a fixed list of punctuation / symbol strings and keeping the ids that + /// encode to a single token (general — no hardcoded magic id list). On top of + /// that we always suppress the structural specials (`SOT`, `TRANSCRIBE`, + /// `TRANSLATE`, `NOSPEECH`) so they can never be emitted as content. + /// + /// The returned set is sorted+deduplicated and intended to be applied as a + /// logit mask at every decode step. + pub fn suppress_tokens(&self) -> Vec { + // openai-whisper symbol list (`whisper/tokenizer.py::non_speech_tokens`). + // Each char of the first group is a standalone symbol; the second group + // is space-separated multi-char symbols. + const SYMBOL_CHARS: &str = "\"#()*+/:;<=>@[\\]^_`{|}~「」『』"; + const SYMBOL_WORDS: &str = + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪"; + // "Miscellaneous" musical symbols: each is forced in even when it + // BPE-splits into multiple tokens (matches the reference). + const MISC: &str = "♩♪♫♬♭♮♯"; + + let mut set: std::collections::BTreeSet = std::collections::BTreeSet::new(); + + // Reference seeds the set with the first token of " -" and " '". + for seed in [" -", " '"] { + if let Ok(ids) = self.encode(seed) { + if let Some(&first) = ids.first() { + set.insert(first); + } + } + } + + let misc: std::collections::BTreeSet = MISC.chars().collect(); + let symbols: Vec = SYMBOL_CHARS + .chars() + .map(|c| c.to_string()) + .chain(SYMBOL_WORDS.split(' ').map(str::to_owned)) + .chain(MISC.chars().map(|c| c.to_string())) + .collect(); + + for symbol in &symbols { + let is_misc = symbol.chars().count() == 1 + && symbol.chars().next().is_some_and(|c| misc.contains(&c)); + for variant in [symbol.clone(), format!(" {symbol}")] { + if let Ok(ids) = self.encode(&variant) { + if (ids.len() == 1 || is_misc) && !ids.is_empty() { + if let Some(&first) = ids.first() { + set.insert(first); + } + } + } + } + } + + // Structural specials that must never appear in content. + for t in [TOK_SOT, TOK_TRANSCRIBE, TOK_TRANSLATE, TOK_NOSPEECH] { + set.insert(t); + } + + set.into_iter().collect() + } + /// Decode token IDs back to text, skipping tokens ≥ `TOK_EOT` (specials). pub fn decode(&self, tokens: &[u32]) -> Result { // Filter timestamp and special tokens before decode. diff --git a/crates/rmlx-audio/src/tokenizer_tests.rs b/crates/rmlx-audio/src/tokenizer_tests.rs index 35adce7..ba0680d 100644 --- a/crates/rmlx-audio/src/tokenizer_tests.rs +++ b/crates/rmlx-audio/src/tokenizer_tests.rs @@ -15,8 +15,8 @@ use super::{ /// `WhisperTokenizerFast.from_pretrained` and calling `tokenizer.encode(f"<|{code}|>")`. /// Full table coverage is constrained to languages that are both (a) in the Whisper /// vocabulary and (b) listed in the `language_token` compile-time match. -/// When `RMLX_TEST_MODEL_WHISPER` is set, integration tests in `crates/rmlx-audio/tests/` -/// can verify against a live tokenizer.json. +/// Integration tests in `crates/rmlx-audio/tests/transcribe.rs` verify against a +/// live `tokenizer.json` when the snapshot is present under `RMLX_O_MODELS_ROOT`. #[test] fn language_token_ids() { // Verified against WhisperTokenizerFast.from_pretrained output. @@ -37,14 +37,18 @@ fn language_token_ids() { } /// Special token constant sanity check. +/// +/// large-v3 has 100 language slots (`<|en|>`=50259 … `<|yue|>`=50358), so every +/// special after the language block is shifted up by one vs the v1/v2 layout. +/// Verified against the shipped `tokenizer.json` `added_tokens` table. #[test] fn special_token_constants() { assert_eq!(TOK_EOT, 50_257); assert_eq!(TOK_SOT, 50_258); assert_eq!(TOK_EN, 50_259); - assert_eq!(TOK_TRANSLATE, 50_358); - assert_eq!(TOK_TRANSCRIBE, 50_359); - assert_eq!(TOK_NO_TIMESTAMPS, 50_363); + assert_eq!(TOK_TRANSLATE, 50_359); + assert_eq!(TOK_TRANSCRIBE, 50_360); + assert_eq!(TOK_NO_TIMESTAMPS, 50_364); } /// SOT sequence structure. diff --git a/crates/rmlx-audio/src/transcribe.rs b/crates/rmlx-audio/src/transcribe.rs new file mode 100644 index 0000000..c122199 --- /dev/null +++ b/crates/rmlx-audio/src/transcribe.rs @@ -0,0 +1,593 @@ +// LOC-exempt: the long-form transcription engine (window seek loop + timestamp +// segmentation + previous-text conditioning + subtitle formatting) is one +// cohesive sequential pipeline; splitting the seek loop from the formatters +// would scatter the timestamp arithmetic without cohesion gain. + +//! Long-form Whisper transcription engine. +//! +//! This is the single transcription core shared by the HTTP +//! `POST /v1/audio/transcriptions` route and the `rmlx transcribe` CLI. It +//! replaces the old "first 30 s only" behaviour with a sliding-window seek +//! loop modelled on `openai-whisper` / `mlx_whisper` `transcribe()`: +//! +//! 1. Decode the input container to 16 kHz mono f32 (caller's responsibility; +//! see [`crate::wav::WavDecoder`] + [`resample_to_16k`]). +//! 2. Walk the audio in 30 s windows. For each window, run the decoder in +//! **timestamp mode** with the full [`crate::whisper::DecodeFilters`] chain. +//! 3. Parse the emitted timestamp tokens into segments with real cumulative +//! times, advance the seek position by the last consumed timestamp, and feed +//! the previous window's text back as a prompt (`<|startofprev|>`). +//! 4. Emit multi-segment output (`vtt` / `srt` / `json` / `txt`). +//! +//! Determinism: temperature is fixed at 0 (greedy argmax), so the same audio +//! produces byte-identical output across runs. + +use std::sync::Arc; + +use rmlx_mlx::Device; +use tracing::{debug, info}; + +use crate::mel::{MelExtractor, N_SAMPLES, SAMPLE_RATE}; +use crate::tokenizer::{WhisperTask, WhisperTokenizer, TOK_EOT, TOK_SOT_PREV, TOK_TIMESTAMP_BEGIN}; +use crate::whisper::{DecodeFilters, WhisperError, WhisperModel}; + +/// Seconds of audio represented by one timestamp-token step (Whisper uses 0.02 s). +const TIME_PRECISION: f32 = 0.02; + +/// Max length of the previous-text prompt fed back as `<|startofprev|>` context, +/// derived from the decoder context length at runtime (no fixed literal). +/// +/// Mirrors openai-whisper's `prompt[-(n_text_ctx // 2 - 1):]`: half the context +/// minus one, so the prompt leaves room for the SOT_PREV marker, the SOT prefix, +/// and the per-window generation budget without overrunning `n_text_ctx`. +#[must_use] +fn previous_text_cap(n_text_ctx: usize) -> usize { + (n_text_ctx / 2).saturating_sub(1) +} + +/// Per-window decoder generation budget, derived from `n_text_ctx` at runtime. +/// +/// openai-whisper uses `sample_len = n_text_ctx // 2`, but the hard ceiling is +/// that the decoder position must stay `< n_text_ctx`: the positional-embedding +/// slice `[offset, offset+seq)` would otherwise run off the `[n_text_ctx, n_state]` +/// table and abort the transcription. `offset` starts at `prefix_len` and grows by +/// one per generated token, so the largest row requested is +/// `prefix_len + generated - 1`. Bounding `generated <= n_text_ctx - prefix_len` +/// keeps that `< n_text_ctx`. Returns 0 when the prefix already fills the context. +#[must_use] +fn window_token_budget(n_text_ctx: usize, prefix_len: usize) -> usize { + let headroom = n_text_ctx.saturating_sub(prefix_len); + (n_text_ctx / 2).min(headroom) +} + +/// One transcribed segment with real wall-clock times (seconds from start). +#[derive(Debug, Clone)] +#[allow(clippy::exhaustive_structs, reason = "stable public segment shape")] +pub struct Segment { + /// Segment start time in seconds from the beginning of the audio. + pub start: f32, + /// Segment end time in seconds. + pub end: f32, + /// Decoded text (specials / timestamps stripped, trimmed). + pub text: String, +} + +/// Full transcription result. +#[derive(Debug, Clone)] +#[allow(clippy::exhaustive_structs, reason = "stable public result shape")] +pub struct Transcription { + /// Concatenated full text. + pub text: String, + /// Per-segment breakdown with timestamps. + pub segments: Vec, + /// Resolved language (BCP-47 code or `lang_tok=N` when auto-detected). + pub language: String, + /// Total audio duration in seconds. + pub duration: f32, +} + +/// Options for a transcription run. +#[derive(Debug, Clone)] +#[allow(clippy::exhaustive_structs, reason = "small, stable options bag")] +pub struct TranscribeOptions { + /// Language: a BCP-47 code (`"en"`, `"fr"`, …) or `"auto"` for detection. + pub language: String, + /// Transcribe (same language) or translate (force English). + pub task: WhisperTask, + /// Sampling temperature. 0 = deterministic greedy (the only supported path). + pub temperature: f32, + /// Feed the previous window's text back as a decoder prompt. + pub condition_on_previous_text: bool, +} + +impl Default for TranscribeOptions { + fn default() -> Self { + Self { + language: "auto".to_owned(), + task: WhisperTask::Transcribe, + temperature: 0.0, + condition_on_previous_text: true, + } + } +} + +/// Long-form Whisper transcriber. Construct once, reuse across requests. +pub struct Transcriber { + model: Arc, + tokenizer: Arc, + extractor: MelExtractor, + /// Tokenizer-derived non-speech / special suppression set. + suppress: Vec, + /// SuppressBlank ids (EOT + blank-space token). + blank_ids: Vec, +} + +impl std::fmt::Debug for Transcriber { + /// Print config dims + suppression-set sizes; the model/tokenizer/extractor + /// are opaque (large MLX buffers), so they are summarised, not dumped. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Transcriber") + .field("n_vocab", &self.model.cfg.n_vocab) + .field("n_text_ctx", &self.model.cfg.n_text_ctx) + .field("n_mels", &self.model.cfg.n_mels) + .field("suppress_len", &self.suppress.len()) + .field("blank_ids_len", &self.blank_ids.len()) + .finish_non_exhaustive() + } +} + +impl Transcriber { + /// Build a transcriber from a loaded model + tokenizer. + pub fn new( + model: Arc, + tokenizer: Arc, + ) -> Result { + let extractor = MelExtractor::new(model.cfg.n_mels) + .map_err(|e| WhisperError::Mlx(format!("mel extractor: {e}")))?; + let suppress = tokenizer.suppress_tokens(); + // SuppressBlank: EOT + whatever " " encodes to (general, from the tokenizer). + let mut blank_ids = vec![TOK_EOT]; + if let Ok(space) = tokenizer.encode(" ") { + blank_ids.extend(space); + } + Ok(Self { + model, + tokenizer, + extractor, + suppress, + blank_ids, + }) + } + + /// Transcribe a full 16 kHz mono f32 waveform of any length. + #[allow( + clippy::too_many_lines, + reason = "the seek loop + segment accumulation is one cohesive long-form pass" + )] + pub fn transcribe( + &self, + samples: &[f32], + opts: &TranscribeOptions, + device: Device, + ) -> Result { + let total_samples = samples.len(); + let duration = total_samples as f32 / SAMPLE_RATE as f32; + + // Resolve language once on the first window (auto-detect needs an encode). + let mut resolved_lang: Option<(u32, String)> = None; + + let mut segments: Vec = Vec::new(); + let mut prompt_tokens: Vec = Vec::new(); + let mut seek: usize = 0; // sample offset of the current window start + + let filters = DecodeFilters::new(self.suppress.clone(), self.blank_ids.clone(), true); + + while seek < total_samples { + let window_end = (seek + N_SAMPLES).min(total_samples); + #[allow( + clippy::indexing_slicing, + reason = "seek < total_samples (loop guard) and window_end = min(seek+N, total_samples), so seek..window_end is always in bounds" + )] + let window = &samples[seek..window_end]; + let window_start_time = seek as f32 / SAMPLE_RATE as f32; + // Real audio length of this (un-padded) window, in seconds. + let window_dur = (window_end - seek) as f32 / SAMPLE_RATE as f32; + + // mel + encode (mel pads to 30 s internally). + let mel_frames = self + .extractor + .extract(window) + .map_err(|e| WhisperError::Mlx(format!("mel: {e}")))?; + let encoder_out = self.model.encode_mel(&mel_frames, device)?; + + // Resolve language on first window. + if resolved_lang.is_none() { + let (lang_tok, lang_str) = if opts.language == "auto" { + let t = self + .model + .detect_language(&encoder_out, device) + .unwrap_or(crate::tokenizer::TOK_EN); + (t, format!("lang_tok={t}")) + } else { + ( + crate::tokenizer::language_token(&opts.language), + opts.language.clone(), + ) + }; + debug!(lang_tok, "language resolved"); + resolved_lang = Some((lang_tok, lang_str)); + } + let lang_tok = resolved_lang + .as_ref() + .map_or(crate::tokenizer::TOK_EN, |(t, _)| *t); + + // Build the SOT sequence (timestamp mode → no <|notimestamps|>), with + // optional previous-text prompt. + let sot = self + .tokenizer + .sot_sequence_from_tok(lang_tok, opts.task, true); + let mut full_prefix: Vec = Vec::new(); + if opts.condition_on_previous_text && !prompt_tokens.is_empty() { + full_prefix.push(TOK_SOT_PREV); + full_prefix.extend(prompt_tokens.iter().copied()); + } + full_prefix.extend(sot.iter().copied()); + + // Per-window generation budget, derived from `n_text_ctx` at runtime + // (no fixed literal); bounded so the decoder position stays `< n_text_ctx`. + // greedy_decode additionally refuses any positional row `>= n_text_ctx` + // as a belt-and-suspenders guard. + let max_tokens = window_token_budget(self.model.cfg.n_text_ctx, full_prefix.len()); + + let tokens = match self.model.greedy_decode( + &encoder_out, + &full_prefix, + max_tokens, + opts.temperature, + &filters, + device, + ) { + Ok(t) => t, + Err(WhisperError::Silence) => { + // Whole window is silence — skip ahead a full window. + seek += N_SAMPLES; + continue; + } + Err(e) => return Err(e), + }; + + // Parse timestamp tokens into segments within this window. + let (window_segments, consumed_time) = + self.split_window(&tokens, window_start_time, window_dur)?; + + // Advance seek by the time we actually consumed. + let advance_secs = consumed_time.clamp(0.0, window_dur); + let advance_samples = (advance_secs * SAMPLE_RATE as f32) as usize; + // Guarantee forward progress even when no timestamp was emitted. + let advance_samples = advance_samples.max(1).min(window_end - seek); + // If we consumed essentially nothing but the window is full-size, + // jump a whole window to avoid stalling. + let advance_samples = + if advance_secs <= TIME_PRECISION && window_end - seek >= N_SAMPLES { + N_SAMPLES + } else { + advance_samples + }; + + debug!( + seek, + window_start_time, + window_dur, + n_tokens = tokens.len(), + n_segments = window_segments.len(), + consumed_time, + advance_secs, + advance_samples, + "long-form window done" + ); + + // Update the prompt for the next window from this window's text. + if opts.condition_on_previous_text { + let window_text: String = window_segments + .iter() + .map(|s| s.text.as_str()) + .collect::>() + .join(" "); + if !window_text.trim().is_empty() { + prompt_tokens = self + .tokenizer + .encode(window_text.trim()) + .unwrap_or_default(); + // Cap the previous-text prompt the way openai-whisper does: + // the prefix never overruns n_text_ctx once the SOT_PREV marker + // + SOT prefix + generation budget are added. + let cap = previous_text_cap(self.model.cfg.n_text_ctx); + if prompt_tokens.len() > cap { + let start = prompt_tokens.len() - cap; + prompt_tokens = prompt_tokens.split_off(start); + } + } + } + + segments.extend(window_segments); + seek += advance_samples; + } + + let text = segments + .iter() + .map(|s| s.text.as_str()) + .collect::>() + .join(" ") + .split_whitespace() + .collect::>() + .join(" "); + + let language = resolved_lang.map_or_else(|| "en".to_owned(), |(_, s)| s); + + info!( + n_segments = segments.len(), + duration, "long-form transcription complete" + ); + + Ok(Transcription { + text, + segments, + language, + duration, + }) + } + + /// Split one window's token stream into timestamped segments. + /// + /// Returns `(segments, consumed_time_secs)` where `consumed_time_secs` is the + /// last timestamp boundary used to advance the seek (relative to the window + /// start, 0..30 s). + #[allow( + clippy::unnecessary_wraps, + reason = "returns Result for symmetry with tokenizer.decode error paths" + )] + fn split_window( + &self, + tokens: &[u32], + window_start_time: f32, + window_dur: f32, + ) -> Result<(Vec, f32), WhisperError> { + let mut segments: Vec = Vec::new(); + let mut last_ts_time: Option = None; // relative seconds within window + let mut cur_text: Vec = Vec::new(); + let mut seg_start: Option = None; + let mut consumed_time = 0.0_f32; + + let ts_to_secs = |tok: u32| -> f32 { (tok - TOK_TIMESTAMP_BEGIN) as f32 * TIME_PRECISION }; + + // A segment whose opening timestamp sits past the real (un-padded) audio + // length is in the 30 s zero-pad tail — Whisper hallucinates filler there + // ("you", "thank you", "♪"). Drop those. A small tolerance absorbs the + // 0.02 s timestamp granularity. + let speech_limit = window_dur + 0.5; + let in_speech = |start: f32| start <= speech_limit; + + for &tok in tokens { + if tok == TOK_EOT { + break; + } + if tok >= TOK_TIMESTAMP_BEGIN { + let t = ts_to_secs(tok); + match (seg_start, last_ts_time) { + (None, _) => { + // Opening timestamp of a segment. + seg_start = Some(t); + } + (Some(start), _) => { + // Closing timestamp — flush the accumulated text. + let text = self.tokenizer.decode(&cur_text).unwrap_or_default(); + let trimmed = text.trim(); + if !trimmed.is_empty() && in_speech(start) { + segments.push(Segment { + start: window_start_time + start, + end: window_start_time + t, + text: trimmed.to_owned(), + }); + } + cur_text.clear(); + seg_start = None; + } + } + last_ts_time = Some(t); + consumed_time = t; + } else { + cur_text.push(tok); + } + } + + // Flush a dangling open segment (text with an opening timestamp but no + // closing one — happens when the window cuts mid-utterance). + if let Some(start) = seg_start { + let text = self.tokenizer.decode(&cur_text).unwrap_or_default(); + let trimmed = text.trim(); + if !trimmed.is_empty() && in_speech(start) { + let end = last_ts_time.unwrap_or(window_dur).max(start); + segments.push(Segment { + start: window_start_time + start, + end: window_start_time + end, + text: trimmed.to_owned(), + }); + } + } else if segments.is_empty() && !cur_text.is_empty() { + // No timestamps at all — emit the whole window as one segment. + let text = self.tokenizer.decode(&cur_text).unwrap_or_default(); + let trimmed = text.trim(); + if !trimmed.is_empty() { + segments.push(Segment { + start: window_start_time, + end: window_start_time + window_dur, + text: trimmed.to_owned(), + }); + } + } + + // If no usable timestamp boundary was found, consume the whole window. + if consumed_time <= TIME_PRECISION { + consumed_time = window_dur; + } + + Ok((segments, consumed_time)) + } +} + +// ── Output formatters ─────────────────────────────────────────────────────── + +/// Output format for a transcription. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow( + clippy::exhaustive_enums, + reason = "the supported format set is closed" +)] +pub enum OutputFormat { + /// Plain concatenated text. + Txt, + /// `{"text", "language", "duration", "segments":[…]}` JSON. + Json, + /// SubRip subtitles. + Srt, + /// WebVTT subtitles. + Vtt, +} + +impl OutputFormat { + /// Parse a format string (`txt|json|srt|vtt`), returning an error message on + /// unrecognised input. + pub fn parse(s: &str) -> Result { + match s { + "txt" | "text" => Ok(Self::Txt), + "json" => Ok(Self::Json), + "srt" => Ok(Self::Srt), + "vtt" => Ok(Self::Vtt), + other => Err(format!( + "unsupported format '{other}'; must be one of: txt, json, srt, vtt" + )), + } + } +} + +/// Render a transcription to the requested format string. +#[must_use] +pub fn render(t: &Transcription, fmt: OutputFormat) -> String { + match fmt { + OutputFormat::Txt => t.text.clone(), + OutputFormat::Json => render_json(t), + OutputFormat::Srt => render_srt(&t.segments), + OutputFormat::Vtt => render_vtt(&t.segments), + } +} + +fn render_json(t: &Transcription) -> String { + let segs: Vec = t + .segments + .iter() + .enumerate() + .map(|(i, s)| { + serde_json::json!({ + "id": i, + "start": s.start, + "end": s.end, + "text": s.text, + }) + }) + .collect(); + serde_json::json!({ + "text": t.text, + "language": t.language, + "duration": t.duration, + "segments": segs, + }) + .to_string() +} + +/// Format seconds as `HH:MM:SS,mmm` (SRT) or `HH:MM:SS.mmm` (VTT). +fn fmt_time(secs: f32, comma: bool) -> String { + let total_ms = (secs.max(0.0) * 1000.0).round() as u64; + let ms = total_ms % 1000; + let total_s = total_ms / 1000; + let s = total_s % 60; + let m = (total_s / 60) % 60; + let h = total_s / 3600; + let sep = if comma { ',' } else { '.' }; + format!("{h:02}:{m:02}:{s:02}{sep}{ms:03}") +} + +fn render_srt(segments: &[Segment]) -> String { + use std::fmt::Write as _; + let mut out = String::new(); + for (i, s) in segments.iter().enumerate() { + let _ = write!( + out, + "{}\n{} --> {}\n{}\n\n", + i + 1, + fmt_time(s.start, true), + fmt_time(s.end, true), + s.text + ); + } + out +} + +fn render_vtt(segments: &[Segment]) -> String { + use std::fmt::Write as _; + let mut out = String::from("WEBVTT\n\n"); + for s in segments { + let _ = write!( + out, + "{} --> {}\n{}\n\n", + fmt_time(s.start, false), + fmt_time(s.end, false), + s.text + ); + } + out +} + +// ── 48 kHz / stereo → 16 kHz mono resampler ───────────────────────────────── + +/// Resample mono f32 `samples` from `src_rate` Hz to 16 kHz using linear +/// interpolation. +/// +/// Whisper is forgiving of resampler quality (its mel front-end is robust), so a +/// linear resampler is sufficient and avoids a new dependency. Stereo downmix is +/// already handled by [`crate::wav::WavDecoder`] (channel average → mono). +#[must_use] +pub fn resample_to_16k(samples: &[f32], src_rate: u32) -> Vec { + if src_rate == SAMPLE_RATE || samples.is_empty() { + return samples.to_vec(); + } + let ratio = f64::from(SAMPLE_RATE) / f64::from(src_rate); + let out_len = ((samples.len() as f64) * ratio).round() as usize; + let mut out = Vec::with_capacity(out_len); + let last = samples.len() - 1; + for i in 0..out_len { + // Position in the source signal. + let src_pos = i as f64 / ratio; + let idx = src_pos.floor() as usize; + if idx >= last { + out.push(*samples.last().unwrap_or(&0.0)); + continue; + } + let frac = (src_pos - idx as f64) as f32; + #[allow( + clippy::indexing_slicing, + reason = "idx < last guaranteed by the branch above; idx+1 <= last" + )] + let a = samples[idx]; + #[allow( + clippy::indexing_slicing, + reason = "idx + 1 <= last (idx < last) is in bounds" + )] + let b = samples[idx + 1]; + out.push(a + (b - a) * frac); + } + out +} + +#[cfg(test)] +#[path = "transcribe_tests.rs"] +mod tests; diff --git a/crates/rmlx-audio/src/transcribe_tests.rs b/crates/rmlx-audio/src/transcribe_tests.rs new file mode 100644 index 0000000..7570dd6 --- /dev/null +++ b/crates/rmlx-audio/src/transcribe_tests.rs @@ -0,0 +1,152 @@ +//! Unit tests for the long-form transcription engine helpers (no model needed). + +use super::{ + fmt_time, previous_text_cap, render_srt, render_vtt, resample_to_16k, window_token_budget, + OutputFormat, Segment, +}; + +/// Tokens prepended by the seek loop before the previous-text prompt: the +/// `<|startofprev|>` marker (1) plus the timestamp-mode SOT prefix +/// `[<|sot|>, , ]` (3). +const SOT_PREV_MARKER: usize = 1; +const SOT_PREFIX_LEN: usize = 3; + +/// Regression for the positional-embedding overflow: a window with a FULL +/// previous-text prompt + the largest generation budget must never request a +/// decoder position `>= n_text_ctx`, otherwise the positional-embedding slice +/// `[offset, offset+seq)` runs off the `[n_text_ctx, n_state]` table and aborts +/// the whole transcription. Model-free: it exercises the runtime bound formulas +/// directly for every realistic `n_text_ctx`. +#[test] +fn decode_budget_never_overruns_positional_table() { + // large-v3 is 448; tiny is 448 too (all Whisper variants share n_text_ctx=448), + // but sweep a range of realistic context lengths to keep the bound general. + // Values below ~8 are degenerate (the fixed 4-token SOT_PREV+SOT prefix alone + // exceeds them) and no real Whisper config uses them. + for &n_text_ctx in &[8usize, 16, 64, 256, 447, 448, 449, 1024] { + // Worst case: the previous-text prompt is filled to its cap. + let prompt_cap = previous_text_cap(n_text_ctx); + let prefix_len = SOT_PREV_MARKER + prompt_cap + SOT_PREFIX_LEN; + + // The prefix (SOT_PREV + capped prompt + SOT prefix) must itself fit so the + // prefill rows `[0, prefix_len)` stay on the positional table. + assert!( + prefix_len <= n_text_ctx, + "n_text_ctx={n_text_ctx}: prefix_len={prefix_len} overruns the table on prefill" + ); + + // The decoder offset starts at `prefix_len` (prefill) and grows by one per + // generated token. The single largest positional row ever requested is + // `prefix_len + generated - 1`. + let max_tokens = window_token_budget(n_text_ctx, prefix_len); + let max_offset_row = prefix_len + max_tokens; // prefill rows [0,prefix_len) + max_tokens steps + + assert!( + max_offset_row <= n_text_ctx, + "n_text_ctx={n_text_ctx}: prefix_len={prefix_len} + max_tokens={max_tokens} \ + = {max_offset_row} exceeds n_text_ctx — positional-embedding overflow" + ); + // The largest *row index* requested is `max_offset_row - 1` and must be a + // valid index into the `[n_text_ctx, n_state]` table (i.e. `< n_text_ctx`), + // whenever any token is generated. + if max_tokens > 0 { + assert!( + max_offset_row - 1 < n_text_ctx, + "n_text_ctx={n_text_ctx}: last positional row {} is out of bounds", + max_offset_row - 1 + ); + } + } +} + +/// Even without a previous-text prompt (prefix = SOT prefix only), the generation +/// budget must keep the decoder position in bounds. +#[test] +fn decode_budget_no_prompt_in_bounds() { + for &n_text_ctx in &[4usize, 448, 1024] { + let prefix_len = SOT_PREFIX_LEN; // no SOT_PREV / prompt on the first window + let max_tokens = window_token_budget(n_text_ctx, prefix_len); + assert!(prefix_len + max_tokens <= n_text_ctx, "ctx={n_text_ctx}"); + } +} + +/// The budget collapses to zero (rather than underflowing) when the prefix already +/// fills the context — the seek loop then emits nothing for that window instead of +/// crashing. +#[test] +fn decode_budget_saturates_when_prefix_full() { + assert_eq!(window_token_budget(448, 448), 0); + assert_eq!(window_token_budget(448, 1000), 0); + assert_eq!(previous_text_cap(0), 0); + assert_eq!(previous_text_cap(1), 0); + assert_eq!(previous_text_cap(448), 223); +} + +#[test] +fn output_format_parse() { + assert_eq!(OutputFormat::parse("txt").unwrap(), OutputFormat::Txt); + assert_eq!(OutputFormat::parse("text").unwrap(), OutputFormat::Txt); + assert_eq!(OutputFormat::parse("json").unwrap(), OutputFormat::Json); + assert_eq!(OutputFormat::parse("srt").unwrap(), OutputFormat::Srt); + assert_eq!(OutputFormat::parse("vtt").unwrap(), OutputFormat::Vtt); + assert!(OutputFormat::parse("flac").is_err()); +} + +#[test] +fn time_formatting_srt_vtt() { + // 1h 2m 3.456s. + let secs = 3600.0 + 120.0 + 3.456; + assert_eq!(fmt_time(secs, true), "01:02:03,456"); + assert_eq!(fmt_time(secs, false), "01:02:03.456"); + // Zero. + assert_eq!(fmt_time(0.0, false), "00:00:00.000"); + // Negative clamps to zero. + assert_eq!(fmt_time(-5.0, false), "00:00:00.000"); +} + +#[test] +fn srt_vtt_multi_segment() { + let segs = vec![ + Segment { + start: 0.0, + end: 2.5, + text: "hello world".to_owned(), + }, + Segment { + start: 2.5, + end: 5.0, + text: "second line".to_owned(), + }, + ]; + let srt = render_srt(&segs); + assert!(srt.contains("1\n00:00:00,000 --> 00:00:02,500\nhello world")); + assert!(srt.contains("2\n00:00:02,500 --> 00:00:05,000\nsecond line")); + + let vtt = render_vtt(&segs); + assert!(vtt.starts_with("WEBVTT\n\n")); + assert!(vtt.contains("00:00:00.000 --> 00:00:02.500\nhello world")); + assert!(vtt.contains("00:00:02.500 --> 00:00:05.000\nsecond line")); +} + +#[test] +fn resample_identity_when_already_16k() { + let s = vec![0.1_f32, 0.2, 0.3, 0.4]; + assert_eq!(resample_to_16k(&s, 16_000), s); +} + +#[test] +fn resample_downsamples_48k_to_16k() { + // 48k -> 16k should produce ~1/3 the samples. + let s: Vec = (0..4800).map(|i| i as f32 / 4800.0).collect(); + let out = resample_to_16k(&s, 48_000); + // 4800 * 16000/48000 = 1600. + assert!((out.len() as i64 - 1600).abs() <= 1, "len = {}", out.len()); + // Monotone ramp preserved roughly. + assert!(out.first().copied().unwrap_or(1.0) < 0.05); + assert!(out.last().copied().unwrap_or(0.0) > 0.9); +} + +#[test] +fn resample_empty() { + assert!(resample_to_16k(&[], 48_000).is_empty()); +} diff --git a/crates/rmlx-audio/src/whisper.rs b/crates/rmlx-audio/src/whisper.rs index 5d5bd2b..38d796a 100644 --- a/crates/rmlx-audio/src/whisper.rs +++ b/crates/rmlx-audio/src/whisper.rs @@ -51,14 +51,16 @@ use std::path::Path; use rmlx_mlx::{ - add, argmax, concatenate, conv1d, divide, gelu, matmul, multiply, scalar_f32, softmax, - softmax_precise, sqrt, subtract, sum_axis_keepdims, Array, Device, + add, argmax, concatenate, conv1d, divide, gelu, matmul, multiply, scalar_f32, softmax_precise, + sqrt, subtract, sum_axis_keepdims, Array, Device, }; use thiserror::Error; use tracing::{debug, info, instrument, warn}; use crate::npz::WeightMap; -use crate::tokenizer::{TOK_EOT, TOK_NOSPEECH, TOK_TIMESTAMP_BEGIN}; +use crate::tokenizer::{ + TOK_EN, TOK_EOT, TOK_LANG_LAST, TOK_NOSPEECH, TOK_NO_TIMESTAMPS, TOK_TIMESTAMP_BEGIN, +}; // ── Errors ──────────────────────────────────────────────────────────────────── @@ -705,8 +707,9 @@ impl WhisperModel { /// Detect the spoken language from encoder output. /// - /// Runs a single SOT decoder step and returns the argmax over the 99 - /// language tokens (50259–50357). Falls back to English (50259) on error. + /// Runs a single SOT decoder step and returns the argmax over the 100 + /// language tokens (`<|en|>` … `<|yue|>`, ids `TOK_EN ..= TOK_LANG_LAST`). + /// Falls back to English (`TOK_EN`) on error. /// /// Call after `encode_mel`. The returned token id can be passed directly to /// `WhisperTokenizer::sot_sequence_from_lang_tok`. @@ -723,8 +726,8 @@ impl WhisperModel { .decoder .forward(&sot_arr, encoder_out, 0, &[], &[], device)?; // logits: [1, 1, vocab] → slice language range → argmax - let lang_start: i32 = 50_259; - let lang_end: i32 = 50_358; // exclusive — 99 language tokens + let lang_start: i32 = TOK_EN as i32; + let lang_end: i32 = TOK_LANG_LAST as i32 + 1; // exclusive — 100 language tokens let lang_logits = logits.slice(&[0, 0, lang_start], &[1, 1, lang_end], &[1, 1, 1], device)?; let best = argmax( @@ -738,7 +741,7 @@ impl WhisperModel { .to_bytes() .map_err(|e| WhisperError::Mlx(e.to_string()))?; let Some(b4) = bytes.get(..4) else { - return Ok(50_259_u32); // fallback to English + return Ok(TOK_EN); // fallback to English }; let idx = i32::from_le_bytes(b4.try_into().unwrap_or([0u8; 4])); Ok((lang_start + idx) as u32) @@ -775,6 +778,11 @@ impl WhisperModel { /// `sot_sequence`: initial tokens (SOT + lang + task + no_timestamps). /// `max_tokens`: cap on generated tokens. /// `temperature`: 0 = greedy argmax; > 0 = temperature-scaled softmax + argmax. + /// `filters`: per-step logit suppression set (`DecodeFilters`). + /// + /// Returns the full token sequence **including** any timestamp tokens — the + /// caller (long-form chunker) needs them for segmentation. Special / non-speech + /// content tokens are masked out by `filters`, never emitted. #[allow( clippy::explicit_counter_loop, reason = "offset tracks KV-cache position starting from sot_len; not a pure iteration counter" @@ -789,11 +797,14 @@ impl WhisperModel { sot_sequence: &[u32], max_tokens: usize, temperature: f32, + filters: &DecodeFilters, device: Device, ) -> Result, WhisperError> { let mut self_kvs: Vec<(Array, Array)> = Vec::new(); let mut cross_kvs: Vec<(Array, Array)> = Vec::new(); - let mut output_tokens: Vec = Vec::new(); + // Sampled tokens (the suffix after the SOT prefix) — these are what the + // timestamp rules and the caller operate on. + let mut sampled: Vec = Vec::new(); // Prefill the SOT sequence. let sot_i32: Vec = sot_sequence.iter().map(|&t| t as i32).collect(); @@ -816,40 +827,43 @@ impl WhisperModel { )?; let mut offset = sot_len; - // SuppressBlank: at the very first text-generation position (right after the - // SOT sequence), suppress EOT and the blank-space token. This matches Python's - // SuppressBlank logit filter in mlx_whisper/openai-whisper and prevents the - // model from immediately halting on short audio where EOT has the highest raw - // logit (often the case for 2–3 s clips with no leading silence). - let suppressed_logits = suppress_eot_at_prefill(&last_logits, self.cfg.n_vocab, device)?; - let next_tok = sample_next(&suppressed_logits, temperature, device)?; + // First sampled position: apply SuppressBlank (in addition to the standing + // suppression set + timestamp rules). + let mut logit_vec = logits_to_f32(&last_logits, self.cfg.n_vocab, device)?; + filters.apply(&mut logit_vec, &sampled, true); + let mut next_tok = argmax_f32(&logit_vec, temperature); debug!( next_tok, tok_eot = TOK_EOT, tok_nospeech = TOK_NOSPEECH, tok_ts_begin = TOK_TIMESTAMP_BEGIN, - "prefill first token (after suppress-blank)" + "prefill first token (after filters)" ); if next_tok == TOK_NOSPEECH { return Err(WhisperError::Silence); } - // Mirror the in-loop guard: don't push timestamp tokens from prefill. - if next_tok >= TOK_TIMESTAMP_BEGIN { - debug!( - next_tok, - "prefill produced timestamp after suppress-blank; returning empty transcription" - ); - return Ok(output_tokens); // empty - } - output_tokens.push(next_tok); + sampled.push(next_tok); - for _ in 1..max_tokens { - let last = *output_tokens.last().unwrap_or(&TOK_EOT); - if last == TOK_EOT || last >= TOK_TIMESTAMP_BEGIN { + while sampled.len() < max_tokens { + if next_tok == TOK_EOT { + break; + } + // Belt-and-suspenders: never request a positional-embedding row + // `>= n_text_ctx`. The next decode step would slice the positional + // table at `[offset, offset+1)`; once `offset == n_text_ctx` that row + // is off the `[n_text_ctx, n_state]` table and `forward` would abort + // the whole transcription. The caller already bounds `max_tokens`, but + // this guard makes the decode loop self-contained regardless of caller. + if offset >= self.cfg.n_text_ctx { + debug!( + offset, + n_text_ctx = self.cfg.n_text_ctx, + "stopping decode: positional-embedding ceiling reached" + ); break; } - let tok_i32 = [last as i32]; + let tok_i32 = [next_tok as i32]; let tok_arr = Array::from_i32_slice(&tok_i32, &[1, 1]) .map_err(|e| WhisperError::Mlx(e.to_string()))?; @@ -865,108 +879,217 @@ impl WhisperModel { cross_kvs = new_cross; offset += 1; - let tok = sample_next(&step_logits, temperature, device)?; - output_tokens.push(tok); + let mut lv = logits_to_f32(&step_logits, self.cfg.n_vocab, device)?; + filters.apply(&mut lv, &sampled, false); + next_tok = argmax_f32(&lv, temperature); + sampled.push(next_tok); } - while output_tokens.last() == Some(&TOK_EOT) { - output_tokens.pop(); + // Drop trailing EOT — it is a stop marker, not content. + while sampled.last() == Some(&TOK_EOT) { + sampled.pop(); } - debug!(n_tokens = output_tokens.len(), "greedy decode done"); - Ok(output_tokens) + debug!(n_tokens = sampled.len(), "greedy decode done"); + Ok(sampled) } } -/// Argmax (or temperature-scaled softmax + argmax) on the last token's logits. -fn sample_next(logits: &Array, temperature: f32, device: Device) -> Result { - let flat = if temperature > 0.0 { - let scaled = divide(logits, &scalar_f32(temperature), device)?; - softmax(&scaled.reshape(&[-1], device)?, -1, device)? - } else { - logits.reshape(&[-1], device)? - }; - let idx = argmax(&flat, 0, device)?; - // Materialise the scalar array to extract the value. - // Use eval() (synchronous) — async_eval does not guarantee the data pointer - // is ready when to_bytes() accesses it immediately after scheduling. - idx.eval().map_err(WhisperError::from)?; - let bytes = idx.to_bytes().map_err(WhisperError::from)?; - if bytes.len() < 4 { - return Err(WhisperError::Mlx("argmax returned empty bytes".to_owned())); - } - // MLX argmax returns uint32; decode as i32 (matching existing project pattern) - // then widen to u32. - #[allow(clippy::indexing_slicing, reason = "bounds checked: bytes.len() >= 4")] - Ok(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u32) +/// Per-step logit suppression for Whisper decode. +/// +/// Holds the standing suppression set (non-speech / special tokens derived from +/// the tokenizer) plus a `timestamps` mode flag. Applied to a host-side f32 logit +/// vector at **every** decode step — this is the openai-whisper / `mlx_whisper` +/// `LogitFilter` chain (`SuppressBlank` + `SuppressTokens` + timestamp handling), +/// the absence of which made rMLX decode halt or emit garbage. +#[derive(Debug, Clone)] +pub struct DecodeFilters { + /// Standing suppression set (always masked). + suppress: Vec, + /// `true` => timestamp mode (emit timestamp tokens, apply pairing rules). + /// `false` => `no_timestamps` mode (suppress all timestamp tokens). + timestamps: bool, + /// First sampled token of the SOT prefix that begins generation (always 0 + /// here because `sampled` is the post-prefix suffix). + blank_suppress: Vec, } -/// Suppress EOT (and blank-space) at the first text-generation step. -/// -/// Matches Python's `SuppressBlank` logit filter: when sampling the very first -/// output token after the SOT sequence, EOT (`<|endoftext|>`) and the blank-space -/// token would be spuriously predicted on short audio. Set their logit to -/// `-1e9` so argmax picks a real text token instead. -/// -/// The logits tensor has shape `[1, 1, n_vocab]` (F16 or F32). We materialise, -/// patch the EOT byte offset, and reconstruct — done once per decode sequence. -#[allow( - clippy::too_many_lines, - reason = "suppress_eot_at_prefill is a single linear operation; splitting adds no clarity" -)] -fn suppress_eot_at_prefill( - logits: &Array, - n_vocab: usize, - device: Device, -) -> Result { - use rmlx_mlx::Dtype; +impl DecodeFilters { + /// Build the filter chain. + /// + /// `suppress` is the tokenizer-derived non-speech/special set + /// (`WhisperTokenizer::suppress_tokens`). `blank_ids` are the SuppressBlank + /// ids (EOT + the blank-space token, `tokenizer.encode(" ")`). `timestamps` + /// selects timestamp vs `no_timestamps` mode. + #[must_use] + pub fn new(suppress: Vec, blank_ids: Vec, timestamps: bool) -> Self { + Self { + suppress, + timestamps, + blank_suppress: blank_ids, + } + } - let flat = logits.reshape(&[-1], device)?; - flat.eval().map_err(WhisperError::from)?; + /// Apply the filter chain in-place to a host f32 logit vector of length + /// `n_vocab`. + /// + /// - `sampled`: tokens sampled so far (post-SOT-prefix). + /// - `first_step`: `true` only for the very first sampled position + /// (enables SuppressBlank). + #[allow( + clippy::indexing_slicing, + reason = "all token ids are bounds-checked against logits.len() before indexing" + )] + pub fn apply(&self, logits: &mut [f32], sampled: &[u32], first_step: bool) { + let n = logits.len(); + let mask = |logits: &mut [f32], tok: u32| { + let t = tok as usize; + if t < n { + logits[t] = f32::NEG_INFINITY; + } + }; - let dtype = flat.dtype(); - let mut bytes = flat.to_bytes().map_err(WhisperError::from)?; + // SuppressBlank — first sampled position only. + if first_step { + for &t in &self.blank_suppress { + mask(logits, t); + } + } - // Suppress EOT. Blank-space token for Whisper is typically token id 220 - // (" " as a standalone piece in the BPE vocabulary); suppress it too, matching - // Python's SuppressBlank which sets `mask[tokenizer.encode(" ") + [tokenizer.eot]] = -inf`. - let suppress: &[u32] = &[TOK_EOT, 220]; + // SuppressTokens — non-speech / special set, every step. + for &t in &self.suppress { + mask(logits, t); + } - for &tok in suppress { - if tok as usize >= n_vocab { - continue; + let ts_begin = TOK_TIMESTAMP_BEGIN as usize; + if !self.timestamps { + // no_timestamps mode: suppress every timestamp token. + for l in logits.iter_mut().skip(ts_begin) { + *l = f32::NEG_INFINITY; + } + return; } - match dtype { - Dtype::F32 => { - let byte_off = tok as usize * 4; - if byte_off + 4 <= bytes.len() { - let val: f32 = -1e9; - #[allow( - clippy::indexing_slicing, - reason = "byte_off + 4 <= bytes.len() checked above" - )] - bytes[byte_off..byte_off + 4].copy_from_slice(&val.to_le_bytes()); + + // Timestamp mode: faithful port of openai-whisper `ApplyTimestampRules`. + // <|notimestamps|> can never be sampled here. + mask(logits, TOK_NO_TIMESTAMPS); + + let len = sampled.len(); + let last_was_ts = sampled.last().is_some_and(|&t| t as usize >= ts_begin); + // NOTE: `penultimate_was_timestamp` is TRUE when fewer than 2 tokens have + // been sampled (`len(seq) < 2`) — this is the reference semantics. Getting + // this wrong (treating len<2 as false) strands the decoder into EOT right + // after the opening `<|0.00|>` and yields empty transcripts. + let penult_was_ts = len < 2 + || sampled + .get(len.wrapping_sub(2)) + .is_some_and(|&t| t as usize >= ts_begin); + + if last_was_ts { + if penult_was_ts { + // Timestamps must pair → the next token has to be NON-timestamp. + for l in logits.iter_mut().skip(ts_begin) { + *l = f32::NEG_INFINITY; + } + } else { + // A single open timestamp → the next token cannot be normal text + // (only a closing timestamp or EOT). Mask everything below EOT. + let eot = TOK_EOT as usize; + for l in logits.iter_mut().take(eot.min(n)) { + *l = f32::NEG_INFINITY; } } - Dtype::F16 => { - let byte_off = tok as usize * 2; - if byte_off + 2 <= bytes.len() { - // -inf in float16 = 0xFC00 (sign=1, exp=11111, mantissa=0). - let neg_inf_f16: u16 = 0xFC00_u16; - #[allow( - clippy::indexing_slicing, - reason = "byte_off + 2 <= bytes.len() checked above" - )] - bytes[byte_off..byte_off + 2].copy_from_slice(&neg_inf_f16.to_le_bytes()); + } + + // Monotonic timestamps: forbid timestamps smaller than the last one, and + // force nonzero segment length. + if let Some(&last_ts) = sampled.iter().rev().find(|&&t| t as usize >= ts_begin) { + let last_ts = last_ts as usize; + let upper = if last_was_ts && !penult_was_ts { + last_ts // allow re-emitting the same close timestamp + } else { + last_ts + 1 + }; + let upper = upper.min(n); + for (t, l) in logits.iter_mut().enumerate().take(upper) { + if t >= ts_begin { + *l = f32::NEG_INFINITY; } } - // Whisper weights are F16; F32 is used in tests. Bfloat16 / integer types - // do not appear in the Whisper logit tensor; skip them silently. - Dtype::Bf16 | Dtype::U8 | Dtype::U32 | Dtype::I32 => {} + } + + // Beginning-of-sequence: the first sampled token must be a timestamp. + if sampled.is_empty() { + for l in logits.iter_mut().take(ts_begin.min(n)) { + *l = f32::NEG_INFINITY; + } + } + + // Tie-break: if the cumulative probability mass over all timestamp tokens + // exceeds the best single text-token probability, suppress text entirely + // (force a timestamp). Mirrors the `log_softmax` comparison in the + // reference; computed on the post-mask logits. + let ts_logsumexp = logsumexp(&logits[ts_begin.min(n)..]); + let max_text = logits + .iter() + .take(ts_begin.min(n)) + .copied() + .fold(f32::NEG_INFINITY, f32::max); + if ts_logsumexp > max_text { + for l in logits.iter_mut().take(ts_begin.min(n)) { + *l = f32::NEG_INFINITY; + } } } +} - Array::from_bytes(&bytes, &[n_vocab as i32], dtype) - .map_err(|e| WhisperError::Mlx(e.to_string())) +/// Numerically-stable log-sum-exp over a slice (ignores `-inf` entries). +fn logsumexp(xs: &[f32]) -> f32 { + let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max); + if !max.is_finite() { + return f32::NEG_INFINITY; + } + let sum: f32 = xs.iter().map(|&x| (x - max).exp()).sum(); + max + sum.ln() +} + +/// Materialise a `[1, 1, n_vocab]` (or `[1, n_vocab]`) logit tensor (F16/F32/Bf16) +/// into a host-side f32 vector of length `n_vocab`. +fn logits_to_f32(logits: &Array, n_vocab: usize, device: Device) -> Result, WhisperError> { + use rmlx_mlx::Dtype; + let flat = logits.reshape(&[-1], device)?.astype(Dtype::F32, device)?; + flat.eval().map_err(WhisperError::from)?; + let bytes = flat.to_bytes().map_err(WhisperError::from)?; + let mut out = Vec::with_capacity(n_vocab); + #[allow( + clippy::indexing_slicing, + reason = "chunks_exact(4) yields exactly 4 bytes; take(n_vocab) bounds the count" + )] + for chunk in bytes.chunks_exact(4).take(n_vocab) { + out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); + } + if out.len() < n_vocab { + out.resize(n_vocab, f32::NEG_INFINITY); + } + Ok(out) +} + +/// Greedy argmax over a host f32 logit vector. +/// +/// `temperature` is accepted for parity with the API; at temp == 0 (the only +/// deterministic mode rMLX serves) this is a plain argmax. Temperature scaling +/// changes the softmax distribution but not the argmax, so for greedy decode the +/// result is identical — we keep the param to avoid an API break and document +/// that sampling (temp > 0) is not yet a stochastic path. +fn argmax_f32(logits: &[f32], _temperature: f32) -> u32 { + let mut best_idx = 0usize; + let mut best_val = f32::NEG_INFINITY; + for (i, &v) in logits.iter().enumerate() { + if v > best_val { + best_val = v; + best_idx = i; + } + } + best_idx as u32 } // ── Tests ───────────────────────────────────────────────────────────────────── diff --git a/crates/rmlx-audio/src/whisper_tests.rs b/crates/rmlx-audio/src/whisper_tests.rs index f4944fb..8b47b68 100644 --- a/crates/rmlx-audio/src/whisper_tests.rs +++ b/crates/rmlx-audio/src/whisper_tests.rs @@ -1,13 +1,83 @@ //! Whisper model unit tests. //! -//! Full model tests (load + transcribe) are integration tests under `tests/` -//! and are gated by the `RMLX_TEST_MODEL_WHISPER` env var. The tests here -//! verify config parsing, NPZ parsing logic, and weight-map helpers. +//! Full model tests (load + transcribe) are integration tests under +//! `crates/rmlx-audio/tests/transcribe.rs`. They resolve the Whisper snapshot +//! from `RMLX_O_MODELS_ROOT` auto-discovery (the same convention as +//! `make model-check-full`) and skip gracefully when the model or fixtures are +//! absent — no bespoke env var. The tests here verify config parsing, NPZ +//! parsing logic, and weight-map helpers. -use super::WhisperConfig; +use super::{DecodeFilters, WhisperConfig}; use crate::npz::{extract_npy_dtype, extract_npy_shape}; +use crate::tokenizer::{TOK_EOT, TOK_NO_TIMESTAMPS, TOK_TIMESTAMP_BEGIN}; use rmlx_mlx::Dtype; +const VOCAB: usize = 51_866; + +fn fresh_logits() -> Vec { + // All zeros so we can see exactly which entries get masked to -inf. + vec![0.0_f32; VOCAB] +} + +/// SuppressBlank masks EOT + blank ids at the first step only. +#[test] +fn suppress_blank_first_step_only() { + let blank = vec![TOK_EOT, 220]; + let f = DecodeFilters::new(vec![], blank, false); + let mut l = fresh_logits(); + f.apply(&mut l, &[], true); + assert!(l[TOK_EOT as usize].is_infinite()); + assert!(l[220].is_infinite()); + + // Not the first step: blank ids are no longer suppressed (EOT allowed). + let mut l2 = fresh_logits(); + f.apply(&mut l2, &[5_u32], false); + assert!(l2[TOK_EOT as usize].is_finite()); +} + +/// no_timestamps mode masks every timestamp token and notimestamps. +#[test] +fn no_timestamps_mode_masks_all_timestamps() { + let f = DecodeFilters::new(vec![], vec![TOK_EOT], false); + let mut l = fresh_logits(); + f.apply(&mut l, &[5_u32], false); + assert!(l[TOK_TIMESTAMP_BEGIN as usize].is_infinite()); + assert!(l[VOCAB - 1].is_infinite()); + // A plain text token stays finite. + assert!(l[100].is_finite()); +} + +/// Timestamp mode: the first sampled token must be a timestamp (BOS rule). +#[test] +fn timestamp_mode_bos_forces_timestamp() { + let f = DecodeFilters::new(vec![], vec![TOK_EOT], true); + let mut l = fresh_logits(); + f.apply(&mut l, &[], true); + // All text < timestamp_begin masked. + assert!(l[100].is_infinite()); + assert!(l[TOK_NO_TIMESTAMPS as usize].is_infinite()); + // Timestamps remain available. + assert!(l[TOK_TIMESTAMP_BEGIN as usize].is_finite()); +} + +/// Timestamp mode: right after a single opening timestamp the model must be +/// able to emit TEXT (the penultimate<2 ⇒ treated-as-timestamp branch forces a +/// non-timestamp next). This is the bug that produced empty transcripts. +#[test] +fn timestamp_mode_after_open_ts_allows_text() { + let f = DecodeFilters::new(vec![], vec![TOK_EOT], true); + let mut l = fresh_logits(); + // One opening timestamp sampled. + f.apply(&mut l, &[TOK_TIMESTAMP_BEGIN], false); + // Text tokens must remain selectable (NOT all masked → not forced to EOT). + assert!( + l[100].is_finite(), + "after a single opening timestamp the decoder must be able to emit text" + ); + // And further timestamps are masked (must pair with non-timestamp first). + assert!(l[TOK_TIMESTAMP_BEGIN as usize].is_infinite()); +} + /// Config JSON parses correctly for the large-v3 snapshot. #[test] fn parse_config_large_v3() { @@ -80,5 +150,6 @@ fn npy_shape_scalar() { assert_eq!(shape, Vec::::new()); } -// NOTE: Full smoke-probe test (WhisperModel::load → transcribe 1 s WAV) lives -// in crates/rmlx-audio/tests/smoke.rs and is gated by RMLX_TEST_MODEL_WHISPER. +// NOTE: Full smoke-probe + long-form regression tests (WhisperModel::load → +// transcribe) live in crates/rmlx-audio/tests/transcribe.rs and resolve the +// snapshot from RMLX_O_MODELS_ROOT auto-discovery (skip-if-absent). diff --git a/crates/rmlx-audio/tests/transcribe.rs b/crates/rmlx-audio/tests/transcribe.rs new file mode 100644 index 0000000..fb76d5d --- /dev/null +++ b/crates/rmlx-audio/tests/transcribe.rs @@ -0,0 +1,340 @@ +//! Real-model Whisper transcription integration tests. +//! +//! These tests resolve the Whisper snapshot + tokenizer from `RMLX_O_MODELS_ROOT` +//! auto-discovery (the `make model-check-full` convention) and **skip gracefully** +//! when the model, tokenizer, or fixtures are absent — there is no bespoke env +//! var. Two layers: +//! +//! 1. `say_clip_deterministic` — portable, asset-free: synthesise a known English +//! sentence with macOS `say` + `ffmpeg`, transcribe it, assert it matches the +//! sentence (low WER, case/punct-insensitive) and is byte-identical across two +//! runs at temp=0. Skips when `say`/`ffmpeg` are unavailable. +//! +//! 2. `long_form_regression` — scans the gitignored fixtures dir +//! (`crates/rmlx-audio/tests/fixtures/`) for any `*.{m4a,wav,mp3,…}` paired +//! with a sibling `*.transcript.vtt`; transcribes the FULL file and asserts a +//! normalized WER ≤ threshold. Generic: any user drops their own audio + VTT. +//! +//! Single-MLX discipline: these load the model in-process. Do not run them +//! concurrently with a live `rmlx serve`. + +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::print_stdout, + clippy::print_stderr, + clippy::indexing_slicing, + clippy::bool_to_int_with_if, + clippy::map_unwrap_or, + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::float_cmp, + clippy::missing_panics_doc, + clippy::items_after_statements, + clippy::needless_range_loop +)] + +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use rmlx_audio::tokenizer::{WhisperTask, WhisperTokenizer}; +use rmlx_audio::transcribe::{TranscribeOptions, Transcriber}; +use rmlx_audio::wav::WavDecoder; +use rmlx_audio::whisper::WhisperModel; +use rmlx_mlx::Device; + +// ── Snapshot resolution ───────────────────────────────────────────────────── + +const WHISPER_SLUG: &str = "mlx-community__whisper-large-v3-mlx"; +const TOKENIZER_SLUG: &str = "openai__whisper-large-v3-tokenizer"; + +fn models_root() -> Option { + let root = std::env::var("RMLX_O_MODELS_ROOT").ok()?; + let pb = PathBuf::from(root); + pb.exists().then_some(pb) +} + +fn whisper_paths() -> Option<(PathBuf, PathBuf)> { + let root = models_root()?; + let model = root.join(WHISPER_SLUG); + let tok = root.join(TOKENIZER_SLUG); + if model.join("config.json").exists() && tok.join("tokenizer.json").exists() { + Some((model, tok)) + } else { + None + } +} + +fn load_transcriber() -> Option { + let (model_path, tok_path) = whisper_paths()?; + rmlx_mlx::ensure_gpu_default_stream(); + let model = WhisperModel::load(&model_path).expect("load whisper model"); + let tokenizer = WhisperTokenizer::from_path(&tok_path).expect("load tokenizer"); + Some(Transcriber::new(Arc::new(model), Arc::new(tokenizer)).expect("transcriber")) +} + +// ── WER + normalization ───────────────────────────────────────────────────── + +/// Normalize a transcript for WER: lowercase, strip punctuation, collapse +/// whitespace. Used on BOTH hypothesis and reference. +fn normalize(text: &str) -> Vec { + text.chars() + .map(|c| { + if c.is_alphanumeric() || c.is_whitespace() { + c.to_ascii_lowercase() + } else { + ' ' + } + }) + .collect::() + .split_whitespace() + .map(str::to_owned) + .collect() +} + +/// Word-level edit distance (Levenshtein) over token vectors. +fn edit_distance(a: &[String], b: &[String]) -> usize { + let n = a.len(); + let m = b.len(); + if n == 0 { + return m; + } + if m == 0 { + return n; + } + let mut prev: Vec = (0..=m).collect(); + let mut cur = vec![0usize; m + 1]; + for i in 1..=n { + cur[0] = i; + for j in 1..=m { + let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 }; + cur[j] = (prev[j] + 1).min(cur[j - 1] + 1).min(prev[j - 1] + cost); + } + std::mem::swap(&mut prev, &mut cur); + } + prev[m] +} + +fn wer(reference: &[String], hypothesis: &[String]) -> f64 { + if reference.is_empty() { + return if hypothesis.is_empty() { 0.0 } else { 1.0 }; + } + edit_distance(reference, hypothesis) as f64 / reference.len() as f64 +} + +/// Parse a WEBVTT file, stripping cue numbers, timestamps, and `Speaker Name: ` +/// prefixes; return the concatenated reference text. +fn parse_vtt_reference(path: &Path) -> String { + let raw = std::fs::read_to_string(path).expect("read vtt"); + let mut out: Vec = Vec::new(); + for line in raw.lines() { + let line = line.trim(); + if line.is_empty() + || line == "WEBVTT" + || line.contains("-->") + || line.chars().all(|c| c.is_ascii_digit()) + { + continue; + } + // Strip leading "Speaker Name: " — the speaker label ends at the first + // ": " that precedes alpha content. Only strip when a colon appears + // reasonably early (a name), not mid-sentence. + let text = match line.find(": ") { + Some(idx) if idx < 40 => &line[idx + 2..], + _ => line, + }; + out.push(text.to_owned()); + } + out.join(" ") +} + +// ── say-clip helpers ──────────────────────────────────────────────────────── + +fn have_cmd(cmd: &str) -> bool { + std::process::Command::new("which") + .arg(cmd) + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +/// Synthesise a known sentence to a 16 kHz mono WAV via `say` + `ffmpeg`. +fn synth_say_clip(sentence: &str) -> Option { + if !have_cmd("say") || !have_cmd("ffmpeg") { + return None; + } + let dir = std::env::temp_dir(); + let aiff = dir.join("rmlx_say_clip.aiff"); + let wav = dir.join("rmlx_say_clip.wav"); + let say_ok = std::process::Command::new("say") + .args(["-o", aiff.to_str()?, sentence]) + .status() + .map(|s| s.success()) + .unwrap_or(false); + if !say_ok { + return None; + } + let ff_ok = std::process::Command::new("ffmpeg") + .args([ + "-y", + "-i", + aiff.to_str()?, + "-ac", + "1", + "-ar", + "16000", + "-c:a", + "pcm_s16le", + wav.to_str()?, + ]) + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + ff_ok.then_some(wav) +} + +fn transcribe_file( + t: &Transcriber, + path: &Path, + language: &str, +) -> rmlx_audio::transcribe::Transcription { + let bytes = std::fs::read(path).expect("read audio"); + let (raw, rate) = WavDecoder::decode(&bytes).expect("decode audio"); + let samples = rmlx_audio::transcribe::resample_to_16k(&raw, rate); + let opts = TranscribeOptions { + language: language.to_owned(), + task: WhisperTask::Transcribe, + temperature: 0.0, + condition_on_previous_text: true, + }; + t.transcribe(&samples, &opts, Device::Gpu) + .expect("transcribe") +} + +// ── Tests ─────────────────────────────────────────────────────────────────── + +/// Portable smoke: a known `say` sentence transcribes correctly and is +/// deterministic across two runs at temp=0. +#[test] +fn say_clip_deterministic() { + let Some(t) = load_transcriber() else { + eprintln!("skip say_clip_deterministic: whisper snapshot not present"); + return; + }; + let sentence = "The quick brown fox jumps over the lazy dog."; + let Some(wav) = synth_say_clip(sentence) else { + eprintln!("skip say_clip_deterministic: say/ffmpeg unavailable"); + return; + }; + + let r1 = transcribe_file(&t, &wav, "en"); + let r2 = transcribe_file(&t, &wav, "en"); + + // Determinism: identical text across runs at temp=0. + assert_eq!( + r1.text, r2.text, + "transcription not deterministic at temp=0:\n run1: {}\n run2: {}", + r1.text, r2.text + ); + + let reference = normalize(sentence); + let hyp = normalize(&r1.text); + let w = wer(&reference, &hyp); + println!("say-clip text: {:?} WER={w:.3}", r1.text); + // Threshold 0.25: the sentence content must be correct. Whisper is known to + // append a single filler token ("you", "thank you") on the trailing-silence + // boundary of a very short clip; one such word over a 9-word sentence is + // ~0.11 WER and does not indicate a decode defect. Determinism (above) is the + // stricter property being asserted. + assert!( + w <= 0.25, + "say-clip WER {w:.3} too high (>0.25); got {:?}", + r1.text + ); +} + +/// Long-form regression: transcribe every fixture audio with a sibling VTT and +/// assert normalized WER ≤ threshold on the FULL file. +#[test] +fn long_form_regression() { + let fixtures = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures"); + if !fixtures.exists() { + eprintln!("skip long_form_regression: no fixtures dir"); + return; + } + let Some(t) = load_transcriber() else { + eprintln!("skip long_form_regression: whisper snapshot not present"); + return; + }; + + const AUDIO_EXTS: &[&str] = &["m4a", "wav", "mp3", "flac", "ogg", "aac"]; + let mut ran = 0usize; + for entry in std::fs::read_dir(&fixtures).expect("read fixtures") { + let path = entry.expect("dir entry").path(); + let ext = path + .extension() + .and_then(|e| e.to_str()) + .map(str::to_lowercase) + .unwrap_or_default(); + if !AUDIO_EXTS.contains(&ext.as_str()) { + continue; + } + // Sibling reference: .transcript.vtt (stem strips the audio ext). + let vtt = path.with_extension("transcript.vtt"); + if !vtt.exists() { + continue; + } + + let t0 = std::time::Instant::now(); + let result = transcribe_file(&t, &path, "en"); + let elapsed = t0.elapsed().as_secs_f64(); + let rtf = if result.duration > 0.0 { + elapsed / f64::from(result.duration) + } else { + 0.0 + }; + + let reference = normalize(&parse_vtt_reference(&vtt)); + let hyp = normalize(&result.text); + let w = wer(&reference, &hyp); + + println!( + "== {} ==\n duration={:.1}s segments={} decode={:.1}s RTF={rtf:.3}\n ref_words={} hyp_words={} WER={w:.4}", + path.file_name().unwrap().to_string_lossy(), + result.duration, + result.segments.len(), + elapsed, + reference.len(), + hyp.len(), + ); + // Print a few aligned excerpts. + for i in [0usize, hyp.len() / 3, 2 * hyp.len() / 3] { + let r: String = reference + .iter() + .skip(i) + .take(12) + .cloned() + .collect::>() + .join(" "); + let h: String = hyp + .iter() + .skip(i) + .take(12) + .cloned() + .collect::>() + .join(" "); + println!(" ref@{i}: {r}\n hyp@{i}: {h}"); + } + + assert!( + w <= 0.30, + "{}: normalized WER {w:.4} exceeds 0.30", + path.file_name().unwrap().to_string_lossy() + ); + ran += 1; + } + + if ran == 0 { + eprintln!("skip long_form_regression: no fixture audio+VTT pairs found"); + } +} diff --git a/crates/rmlx-cli/Cargo.toml b/crates/rmlx-cli/Cargo.toml index 26a879c..12d87f4 100644 --- a/crates/rmlx-cli/Cargo.toml +++ b/crates/rmlx-cli/Cargo.toml @@ -20,6 +20,7 @@ name = "fused_qk_realmodel_probe" path = "examples/fused_qk_realmodel_probe.rs" [dependencies] +rmlx-audio = { workspace = true } rmlx-core = { workspace = true } rmlx-kv-quant = { workspace = true } rmlx-kv-ssd = { workspace = true } diff --git a/crates/rmlx-cli/src/commands/mod.rs b/crates/rmlx-cli/src/commands/mod.rs index 0668b8e..e61ecfc 100644 --- a/crates/rmlx-cli/src/commands/mod.rs +++ b/crates/rmlx-cli/src/commands/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod parse; pub(crate) mod preset_table; pub(crate) mod profile; pub(crate) mod serve; +pub(crate) mod transcribe; pub(crate) use baseline::run_baseline; pub(crate) use eval::run_ppl; diff --git a/crates/rmlx-cli/src/commands/transcribe.rs b/crates/rmlx-cli/src/commands/transcribe.rs new file mode 100644 index 0000000..98d913c --- /dev/null +++ b/crates/rmlx-cli/src/commands/transcribe.rs @@ -0,0 +1,140 @@ +//! `rmlx transcribe