Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,9 @@ scripts/release_e2e/stage6_perf/last_run.json

# Repo-local model-snapshot dir (RMLX_O_MODELS_ROOT fallback). Drop snapshots here.
/models/

# Local audio test fixtures (real recordings + reference transcripts). The
# long-form regression test (crates/rmlx-audio/tests/transcribe.rs) scans this
# dir for `*.{m4a,wav,mp3,…}` paired with a sibling `*.transcript.vtt`. Drop
# your own audio + VTT here; nothing in it is tracked.
/crates/rmlx-audio/tests/fixtures/
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 11 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,17 @@ parking_lot = "0.12"
# Default features fine (SIMD acceleration on aarch64 is always-on in 6.x).
rustfft = "6.4"

# Pure-Rust audio decode for WAV + MP3.
# wav — RIFF/WAVE container format reader (symphonia-format-riff).
# pcm — PCM codec required to decode the audio inside WAV files
# (symphonia-codec-pcm; NOT included by wav alone).
# mp3 — MPEG Layer 3 codec (symphonia-bundle-mp3/mp3).
# default-features=false drops AIFF/OGG/FLAC/AAC/MKV/ADPCM from the default
# registry so only the three features we need are compiled in.
symphonia = { version = "0.6", default-features = false, features = ["wav", "pcm", "mp3"] }
# Pure-Rust audio decode for WAV + MP3 + AAC/MP4 (`.m4a`).
# wav — RIFF/WAVE container format reader (symphonia-format-riff).
# pcm — PCM codec required to decode the audio inside WAV files
# (symphonia-codec-pcm; NOT included by wav alone).
# mp3 — MPEG Layer 3 codec (symphonia-bundle-mp3/mp3).
# isomp4 — ISO-BMFF / MP4 container reader (needed for `.m4a`, `.mp4`).
# aac — AAC codec (the audio inside most `.m4a` recordings).
# These are all features of the existing `symphonia` dep — no new crate. They
# let `rmlx transcribe` / the audio endpoint ingest real meeting recordings
# (`.m4a`) directly. default-features=false still drops AIFF/OGG/FLAC/MKV/ADPCM.
symphonia = { version = "0.6", default-features = false, features = ["wav", "pcm", "mp3", "isomp4", "aac"] }

# Testing
tempfile = "3"
Expand Down
4 changes: 4 additions & 0 deletions crates/rmlx-audio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@
pub mod mel;
pub mod npz;
pub mod tokenizer;
pub mod transcribe;
pub mod tts;
pub mod vad;
pub mod wav;
pub mod whisper;

pub use transcribe::{
render, resample_to_16k, OutputFormat, Segment, TranscribeOptions, Transcriber, Transcription,
};
pub use vad::{voiced_segments, SileroVad, VadState};
pub use wav::{WavDecoder, WavEncoder};
16 changes: 14 additions & 2 deletions crates/rmlx-audio/src/npz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,11 @@ pub fn parse_npy_array(name: &str, data: &[u8]) -> Result<Array, NpzError> {
let raw = &data[header_end..];
let n_elems: usize = shape.iter().product();
let elem_bytes = match dtype {
Dtype::U8 => 1,
Dtype::F16 => 2,
Dtype::F32 => 4,
_ => {
Dtype::F32 | Dtype::U32 | Dtype::I32 => 4,
Dtype::Bf16 => {
// extract_npy_dtype never yields Bf16, but the match must be total.
return Err(npy_err(&format!("unsupported dtype {dtype:?}")));
}
};
Expand All @@ -625,6 +627,12 @@ pub fn parse_npy_array(name: &str, data: &[u8]) -> Result<Array, NpzError> {
// ── NPY header field extractors (pub for tests) ───────────────────────────────

/// Extract the numpy dtype descriptor from a `.npy` header string.
///
/// Whisper `weights.npz` carries float weights (`f2`/`f4`) plus a small
/// `alignment_heads` mask. That mask is stored as a NumPy boolean (`b1`) or
/// small-int array; older parsers hard-errored on it with
/// "cannot parse dtype". We map booleans / 1-byte ints to `U8` and 4-byte ints
/// to `U32`/`I32` so the mask loads instead of aborting the whole archive.
pub fn extract_npy_dtype(header: &str) -> Option<Dtype> {
let start = header.find("'descr'")?;
let rest = &header[start + 7..];
Expand All @@ -635,6 +643,10 @@ pub fn extract_npy_dtype(header: &str) -> Option<Dtype> {
match s {
"f2" => Some(Dtype::F16),
"f4" => Some(Dtype::F32),
// boolean + 1-byte ints → U8 (alignment_heads mask).
"b1" | "u1" | "i1" => Some(Dtype::U8),
"u4" => Some(Dtype::U32),
"i4" => Some(Dtype::I32),
_ => None,
}
}
Expand Down
13 changes: 13 additions & 0 deletions crates/rmlx-audio/src/npz_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ fn npy_dtype_f4() {
assert_eq!(extract_npy_dtype(hdr), Some(Dtype::F32));
}

#[test]
fn npy_dtype_bool_and_ints() {
// alignment_heads ships as a boolean (`b1`) mask — must parse, not error.
let hdr_bool = "{'descr': '|b1', 'fortran_order': False, 'shape': (20, 32), }";
assert_eq!(extract_npy_dtype(hdr_bool), Some(Dtype::U8));
let hdr_u1 = "{'descr': '|u1', 'fortran_order': False, 'shape': (10,), }";
assert_eq!(extract_npy_dtype(hdr_u1), Some(Dtype::U8));
let hdr_i4 = "{'descr': '<i4', 'fortran_order': False, 'shape': (4,), }";
assert_eq!(extract_npy_dtype(hdr_i4), Some(Dtype::I32));
let hdr_u4 = "{'descr': '<u4', 'fortran_order': False, 'shape': (4,), }";
assert_eq!(extract_npy_dtype(hdr_u4), Some(Dtype::U32));
}

#[test]
fn npy_shape_1d() {
let hdr = "{'descr': '<f2', 'fortran_order': False, 'shape': (1280,), }";
Expand Down
105 changes: 92 additions & 13 deletions crates/rmlx-audio/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,29 @@
//!
//! ## Special token IDs (locked to `openai/whisper-large-v3`)
//!
//! All values verified against `WhisperTokenizerFast.from_pretrained` output:
//! All values verified against the shipped `tokenizer.json` `added_tokens`
//! table. **large-v3 has 100 language slots** (`<|en|>`=50259 … `<|yue|>`=50358),
//! one more than v1/v2 — this shifts every special after the language block up
//! by one vs the v2 layout. Getting these wrong made the decoder emit the wrong
//! task token and treat `<|notimestamps|>` as the timestamp-begin sentinel,
//! producing empty / garbage transcripts.
//!
//! | Token | ID |
//! |---|---|
//! | `<|endoftext|>` (eot) | 50 257 |
//! | `<|startoftranscript|>` (sot) | 50 258 |
//! | `<|en|>` | 50 259 |
//! | `<|translate|>` | 50 358 |
//! | `<|transcribe|>` | 50 359 |
//! | `<|nospeech|>` | 50 362 |
//! | `<|notimestamps|>` | 50 363 |
//! | `<|0.00|>` (timestamp_begin) | 50 364 |
//! | `<|yue|>` (last language) | 50 358 |
//! | `<|translate|>` | 50 359 |
//! | `<|transcribe|>` | 50 360 |
//! | `<|startoflm|>` | 50 361 |
//! | `<|startofprev|>` | 50 362 |
//! | `<|nospeech|>` | 50 363 |
//! | `<|notimestamps|>` | 50 364 |
//! | `<|0.00|>` (timestamp_begin) | 50 365 |
//!
//! Total vocabulary: 51 866 tokens (50 257 base GPT-2 + 1 609 added).
//! Total vocabulary: 51 866 tokens (50 257 base GPT-2 + 1 609 added; the last
//! timestamp `<|30.00|>` is 51 865).
//!
//! ## Tokenizer loading
//!
Expand All @@ -49,16 +58,22 @@ pub const TOK_EOT: u32 = 50_257;
pub const TOK_SOT: u32 = 50_258;
/// English language token (`<|en|>`).
pub const TOK_EN: u32 = 50_259;
/// Last language token (`<|yue|>`). large-v3 has 100 languages (50259..=50358).
pub const TOK_LANG_LAST: u32 = 50_358;
/// Translate task token (`<|translate|>`).
pub const TOK_TRANSLATE: u32 = 50_358;
pub const TOK_TRANSLATE: u32 = 50_359;
/// Transcribe task token (`<|transcribe|>`).
pub const TOK_TRANSCRIBE: u32 = 50_359;
pub const TOK_TRANSCRIBE: u32 = 50_360;
/// Start-of-LM token (`<|startoflm|>`).
pub const TOK_SOT_LM: u32 = 50_361;
/// Start-of-previous-context token (`<|startofprev|>`), for prompt conditioning.
pub const TOK_SOT_PREV: u32 = 50_362;
/// No-speech token (`<|nospeech|>`).
pub const TOK_NOSPEECH: u32 = 50_362;
pub const TOK_NOSPEECH: u32 = 50_363;
/// No-timestamps token (`<|notimestamps|>`).
pub const TOK_NO_TIMESTAMPS: u32 = 50_363;
/// First timestamp token `<|0.00|>`.
pub const TOK_TIMESTAMP_BEGIN: u32 = 50_364;
pub const TOK_NO_TIMESTAMPS: u32 = 50_364;
/// First timestamp token `<|0.00|>`. Timestamps run 50365..=51865 in 0.02 s steps.
pub const TOK_TIMESTAMP_BEGIN: u32 = 50_365;

/// Language token IDs for the 99 supported languages.
///
Expand Down Expand Up @@ -228,6 +243,70 @@ impl WhisperTokenizer {
Ok(enc.get_ids().to_vec())
}

/// Derive the Whisper non-speech / special suppression token set from the
/// loaded tokenizer.
///
/// This is a faithful port of openai-whisper's `Tokenizer.non_speech_tokens`
/// plus `_get_suppress_tokens`: the non-speech set is computed by BPE-encoding
/// a fixed list of punctuation / symbol strings and keeping the ids that
/// encode to a single token (general — no hardcoded magic id list). On top of
/// that we always suppress the structural specials (`SOT`, `TRANSCRIBE`,
/// `TRANSLATE`, `NOSPEECH`) so they can never be emitted as content.
///
/// The returned set is sorted+deduplicated and intended to be applied as a
/// logit mask at every decode step.
pub fn suppress_tokens(&self) -> Vec<u32> {
// openai-whisper symbol list (`whisper/tokenizer.py::non_speech_tokens`).
// Each char of the first group is a standalone symbol; the second group
// is space-separated multi-char symbols.
const SYMBOL_CHARS: &str = "\"#()*+/:;<=>@[\\]^_`{|}~「」『』";
const SYMBOL_WORDS: &str =
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪";
// "Miscellaneous" musical symbols: each is forced in even when it
// BPE-splits into multiple tokens (matches the reference).
const MISC: &str = "♩♪♫♬♭♮♯";

let mut set: std::collections::BTreeSet<u32> = std::collections::BTreeSet::new();

// Reference seeds the set with the first token of " -" and " '".
for seed in [" -", " '"] {
if let Ok(ids) = self.encode(seed) {
if let Some(&first) = ids.first() {
set.insert(first);
}
}
}

let misc: std::collections::BTreeSet<char> = MISC.chars().collect();
let symbols: Vec<String> = SYMBOL_CHARS
.chars()
.map(|c| c.to_string())
.chain(SYMBOL_WORDS.split(' ').map(str::to_owned))
.chain(MISC.chars().map(|c| c.to_string()))
.collect();

for symbol in &symbols {
let is_misc = symbol.chars().count() == 1
&& symbol.chars().next().is_some_and(|c| misc.contains(&c));
for variant in [symbol.clone(), format!(" {symbol}")] {
if let Ok(ids) = self.encode(&variant) {
if (ids.len() == 1 || is_misc) && !ids.is_empty() {
if let Some(&first) = ids.first() {
set.insert(first);
}
}
}
}
}

// Structural specials that must never appear in content.
for t in [TOK_SOT, TOK_TRANSCRIBE, TOK_TRANSLATE, TOK_NOSPEECH] {
set.insert(t);
}

set.into_iter().collect()
}

/// Decode token IDs back to text, skipping tokens ≥ `TOK_EOT` (specials).
pub fn decode(&self, tokens: &[u32]) -> Result<String, TokenizerError> {
// Filter timestamp and special tokens before decode.
Expand Down
14 changes: 9 additions & 5 deletions crates/rmlx-audio/src/tokenizer_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use super::{
/// `WhisperTokenizerFast.from_pretrained` and calling `tokenizer.encode(f"<|{code}|>")`.
/// Full table coverage is constrained to languages that are both (a) in the Whisper
/// vocabulary and (b) listed in the `language_token` compile-time match.
/// When `RMLX_TEST_MODEL_WHISPER` is set, integration tests in `crates/rmlx-audio/tests/`
/// can verify against a live tokenizer.json.
/// Integration tests in `crates/rmlx-audio/tests/transcribe.rs` verify against a
/// live `tokenizer.json` when the snapshot is present under `RMLX_O_MODELS_ROOT`.
#[test]
fn language_token_ids() {
// Verified against WhisperTokenizerFast.from_pretrained output.
Expand All @@ -37,14 +37,18 @@ fn language_token_ids() {
}

/// Special token constant sanity check.
///
/// large-v3 has 100 language slots (`<|en|>`=50259 … `<|yue|>`=50358), so every
/// special after the language block is shifted up by one vs the v1/v2 layout.
/// Verified against the shipped `tokenizer.json` `added_tokens` table.
#[test]
fn special_token_constants() {
assert_eq!(TOK_EOT, 50_257);
assert_eq!(TOK_SOT, 50_258);
assert_eq!(TOK_EN, 50_259);
assert_eq!(TOK_TRANSLATE, 50_358);
assert_eq!(TOK_TRANSCRIBE, 50_359);
assert_eq!(TOK_NO_TIMESTAMPS, 50_363);
assert_eq!(TOK_TRANSLATE, 50_359);
assert_eq!(TOK_TRANSCRIBE, 50_360);
assert_eq!(TOK_NO_TIMESTAMPS, 50_364);
}

/// SOT sequence structure.
Expand Down
Loading
Loading