diff --git a/.gitignore b/.gitignore index 1928866..d28e8f1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,15 @@ _generate/ *.py[cod] *.egg-info *env* +# install -e artifacts _version.py +libggml* +libwhisper* + +# ignore downloaded source code... really this is just for quickly checking previous versions +pywhispercpp-*.* + + # custom .idea diff --git a/README.md b/README.md index ccda1c6..018f18e 100644 --- a/README.md +++ b/README.md @@ -292,7 +292,10 @@ options: ```python import _pywhispercpp as pwcpp -ctx = pwcpp.whisper_init_from_file('path/to/ggml/model') +ctx = pwcpp.whisper_init_from_file_with_params( + 'path/to/ggml/model', + pwcpp.whisper_context_default_params(), +) ``` # Discussions and contributions @@ -302,4 +305,4 @@ If you have any feedback, or you want to share how you are using this project, f # License -This project is licensed under the same license as [whisper.cpp](https://github.com/ggerganov/whisper.cpp/blob/master/LICENSE) (MIT [License](./LICENSE)). +This project is licensed under the same license as [whisper.cpp](https://github.com/ggerganov/whisper.cpp/blob/master/LICENSE) (MIT [License](./LICENSE)). \ No newline at end of file diff --git a/pywhispercpp/constants.py b/pywhispercpp/constants.py index f56a3e9..bfe582f 100644 --- a/pywhispercpp/constants.py +++ b/pywhispercpp/constants.py @@ -95,6 +95,12 @@ 'type': bool, 'description': "do not use past transcription (if any) as initial prompt for the decoder", 'options': None, + 'default': True + }, + 'no_timestamps': { + 'type': bool, + 'description': "do not generate timestamps", + 'options': None, 'default': False }, 'single_segment': { @@ -164,12 +170,24 @@ 'options': None, 'default': 0 }, + 'debug_mode': { + 'type': bool, + 'description': "enable debug mode in whisper.cpp", + 'options': None, + 'default': False + }, 'audio_ctx': { 'type': int, 'description': "overwrite the audio context size (0 = use default)", 'options': None, 'default': 0 }, + 'tdrz_enable': { + 'type': bool, + 'description': "enable tinydiarize speaker turn detection", + 'options': None, + 'default': False + }, 'initial_prompt': { 'type': str, 'description': "Initial prompt, these are prepended to any existing text context from a previous call", @@ -188,12 +206,24 @@ 'options': None, 'default': 0 }, + 'carry_initial_prompt': { + 'type': bool, + 'description': "always prepend the initial prompt to each decode window", + 'options': None, + 'default': False + }, 'language': { 'type': str, 'description': 'for auto-detection, set to None, "" or "auto"', 'options': None, 'default': "" }, + 'detect_language': { + 'type': bool, + 'description': 'enable automatic language detection during transcription', + 'options': None, + 'default': False + }, 'suppress_blank': { 'type': bool, 'description': 'common decoding parameters', @@ -206,6 +236,18 @@ 'options': None, 'default': False }, + 'suppress_nst': { + 'type': bool, + 'description': 'canonical whisper.cpp name for non-speech token suppression', + 'options': None, + 'default': False + }, + 'suppress_regex': { + 'type': str, + 'description': 'regex pattern used to suppress matching text during decoding', + 'options': None, + 'default': '' + }, 'temperature': { 'type': float, 'description': 'initial decoding temperature', @@ -252,7 +294,7 @@ 'type': dict, 'description': 'greedy', 'options': None, - 'default': {"best_of": -1} + 'default': {"best_of": 5} }, 'beam_search': { 'type': dict, @@ -264,7 +306,7 @@ 'type': bool, 'description': 'calculate the geometric mean of token probabilities for each segment.', 'options': None, - 'default': True + 'default': False }, 'vad': { 'type': bool, diff --git a/pywhispercpp/model.py b/pywhispercpp/model.py index 7f0f2a3..453e152 100644 --- a/pywhispercpp/model.py +++ b/pywhispercpp/model.py @@ -6,20 +6,21 @@ [whisper.cpp](https://github.com/ggerganov/whisper.cpp) API. """ import importlib.metadata +import subprocess +import os import logging import shutil import sys +import tempfile +import wave from pathlib import Path from time import time -from typing import Union, Callable, List, TextIO, Tuple, Optional +from typing import Any, Union, Callable, List, TextIO, Tuple, Optional, Dict, TypedDict + import _pywhispercpp as pw import numpy as np -import pywhispercpp.utils as utils import pywhispercpp.constants as constants -import subprocess -import os -import tempfile -import wave +import pywhispercpp.utils as utils __author__ = "absadiki" __copyright__ = "Copyright 2023, " @@ -29,6 +30,19 @@ logger = logging.getLogger(__name__) +class ContextParams(TypedDict, total=False): + use_gpu: bool + flash_attn: bool + gpu_device: int + dtw_token_timestamps: bool + dtw_aheads_preset: int + dtw_n_top: int + dtw_mem_size: int + + +_CONTEXT_PARAM_KEYS = frozenset(ContextParams.__annotations__) + + class Segment: """ A small class representing a transcription segment @@ -68,34 +82,84 @@ class Model: ``` """ - _new_segment_callback = None + def __init__(self, model: str = 'tiny', - models_dir: str = None, + models_dir: Optional[str] = None, params_sampling_strategy: int = 0, redirect_whispercpp_logs_to: Union[bool, TextIO, str, None] = False, use_openvino: bool = False, - openvino_model_path: str = None, + openvino_model_path: Optional[str] = None, openvino_device: str = 'CPU', - openvino_cache_dir: str = None, + openvino_cache_dir: Optional[str] = None, + context_params: Optional[ContextParams] = None, **params): """ - :param model: The name of the model, one of the [AVAILABLE_MODELS](/pywhispercpp/#pywhispercpp.constants.AVAILABLE_MODELS), - (default to `tiny`), or a direct path to a `ggml` model. - :param models_dir: The directory where the models are stored, or where they will be downloaded if they don't - exist, default to [MODELS_DIR](/pywhispercpp/#pywhispercpp.constants.MODELS_DIR) - :param params_sampling_strategy: 0 -> GREEDY, else BEAM_SEARCH - :param redirect_whispercpp_logs_to: where to redirect the whisper.cpp logs, default to False (no redirection), accepts str file path, sys.stdout, sys.stderr, or use None to redirect to devnull - :param use_openvino: whether to use OpenVINO or not - :param openvino_model_path: path to the OpenVINO model - :param openvino_device: OpenVINO device, default to CPU - :param openvino_cache_dir: OpenVINO cache directory - :param params: keyword arguments for different whisper.cpp parameters, - see [PARAMS_SCHEMA](/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) + :param model: model name, default `tiny`, or a direct path to a ggml model file. + :param models_dir: directory containing model files; if omitted, uses `MODELS_DIR` unless `model` + is already a direct file path. + :param params_sampling_strategy: sampling strategy selector; `0` uses greedy decoding and any + other value uses beam search. + :param redirect_whispercpp_logs_to: log redirection target. Use `False` for no redirection, `None` + for `/dev/null`, a file path string, or `sys.stdout`/`sys.stderr`. + :param use_openvino: whether to initialize the OpenVINO encoder backend. + :param openvino_model_path: path to the OpenVINO model directory or files. + :param openvino_device: OpenVINO device name, default `CPU`. + :param openvino_cache_dir: OpenVINO cache directory. + :param context_params: optional whisper context loader params. Accepted keys are `use_gpu`, + `flash_attn`, `gpu_device`, `dtw_token_timestamps`, + `dtw_aheads_preset`, `dtw_n_top`, and `dtw_mem_size`. Omitted keys inherit + from `whisper_context_default_params()`. + :param params: keyword-only decode parameters matching the public API documented in `model.pyi`. + These values are forwarded to `whisper_full_params` and remain active for future calls. + Supported keys: + - `n_threads`: number of inference threads. Default is `min(4, hardware_concurrency())`. + - `n_max_text_ctx`: max prompt-text tokens carried into the decoder. Default `16384`. + - `offset_ms`: audio start offset in milliseconds. Default `0`. + - `duration_ms`: audio duration to process in milliseconds. Default `0`. + - `translate`: translate output to English. Default `False`. + - `no_context`: disable reuse of past transcription context. Default `True`. + - `no_timestamps`: disable timestamp generation. Default `False`. + - `single_segment`: force a single output segment. Default `False`. + - `print_special`: print special tokens. Default `False`. + - `print_progress`: print progress information. Default `True`. + - `print_realtime`: print realtime output from whisper.cpp. Default `False`. + - `print_timestamps`: print timestamps during realtime output. Default `True`. + - `token_timestamps`: enable token-level timestamps. Default `False`. + - `thold_pt`: token timestamp probability threshold. Default `0.01`. + - `thold_ptsum`: token timestamp sum threshold. Default `0.01`. + - `max_len`: max segment length in characters. Default `0`. + - `split_on_word`: split on words when `max_len` is used. Default `False`. + - `max_tokens`: max tokens per segment. Default `0`. + - `debug_mode`: enable whisper.cpp debug mode. Default `False`. + - `audio_ctx`: override audio context size. Default `0`. + - `tdrz_enable`: enable tinydiarize speaker-turn detection. Default `False`. + - `initial_prompt`: initial text prompt prepended before decoding. Default `None`. + - `prompt_tokens`: explicit prompt token sequence. Default `None`. + - `prompt_n_tokens`: number of prompt tokens. Default `0`. + - `carry_initial_prompt`: prepend the initial prompt to each decode window. Default `False`. + - `language`: language code. Default ``. + - `detect_language`: enable automatic language detection during transcription. Default `False`. + - `suppress_blank`: suppress blank outputs. Default `True`. + - `suppress_non_speech_tokens`: Python alias for `suppress_nst`. Default `False`. + - `suppress_nst`: suppress non-speech tokens. Default `False`. + - `suppress_regex`: regex pattern used to suppress matching text during decoding. Default `''`. + - `temperature`: initial decoding temperature. Default `0.0`. + - `max_initial_ts`: maximum initial timestamp. Default `1.0`. + - `length_penalty`: length penalty. Default `-1.0`. + - `temperature_inc`: fallback temperature increment. Default `0.2`. + - `entropy_thold`: entropy threshold. Default `2.4`. + - `logprob_thold`: logprob threshold. Default `-1.0`. + - `no_speech_thold`: no-speech threshold. Default `0.6`. + - `greedy`: greedy-decoder settings, typically `{"best_of": 5}`. + - `beam_search`: beam-search settings. Default `{"beam_size": -1, "patience": -1.0}`. + - `vad`: enable VAD. Default `False`. + - `vad_model_path`: path to the VAD model. Default `None`. """ self.model_path = utils.resolve_model_path(model, models_dir) self._ctx = None + self._context_params = self._resolve_context_params(context_params) self._sampling_strategy = pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY if params_sampling_strategy == 0 else \ pw.whisper_sampling_strategy.WHISPER_SAMPLING_BEAM_SEARCH self._params = pw.whisper_full_default_params(self._sampling_strategy) @@ -107,29 +171,34 @@ def __init__(self, self.openvino_model_path = openvino_model_path self.openvino_device = openvino_device self.openvino_cache_dir = openvino_cache_dir + # todo... maybe setup default callbacks for segments and abort globaly and/or per model instance? + self._new_segment_callback = None # init the model self._init_model() def transcribe(self, media: Union[str, np.ndarray], - n_processors: int = None, - new_segment_callback: Callable[[Segment], None] = None, + n_processors: Optional[int] = None, + new_segment_callback: Optional[Callable[[Segment], None]] = None, + abort_callback: Optional[Callable[[], bool]] = None, **params) -> List[Segment]: """ Transcribes the media provided as input and returns list of `Segment` objects. Accepts a media_file path (audio/video) or a raw numpy array. :param media: Media file path or a numpy array - :param n_processors: if not None, it will run the transcription on multiple processes - binding to whisper.cpp/whisper_full_parallel - > Split the input audio in chunks and process each chunk separately using whisper_full() - :param new_segment_callback: callback function that will be called when a new segment is generated - :param params: keyword arguments for different whisper.cpp parameters, see ::: constants.PARAMS_SCHEMA + :param n_processors: number of worker processes for `whisper_full_parallel`. If omitted, runs a + single-process `whisper_full()` decode. + :param new_segment_callback: callback invoked for each newly produced `Segment` during decoding. + :param abort_callback: callback function returning True to abort an in-flight transcription early. :param extract_probability: If True, calculates the geometric mean of token probabilities for each segment, providing a confidence score interpretable as a probability in [0, 1]. + :param params: additional keyword-only decode parameters matching the public API documented in + `model.pyi`, with the same supported keys and defaults as `Model.__init__`. + Any overrides applied here remain active for future calls. :return: List of transcription segments """ - if type(media) is np.ndarray: + if isinstance(media, np.ndarray): audio = media else: if not Path(media).exists(): @@ -142,10 +211,15 @@ def transcribe(self, # update params if any self._set_params(params) - # setting up callback - if new_segment_callback: - Model._new_segment_callback = new_segment_callback - pw.assign_new_segment_callback(self._params, Model.__call_new_segment_callback) + # setting up callback. make sure self._new_segment_callback = None when new_segment_callback = None. + # since this is no lonmger bound to the Model but on self + self._new_segment_callback = new_segment_callback + pw.assign_new_segment_callback( + self._params, + self.__call_new_segment_callback if new_segment_callback is not None else None, + ) + + pw.assign_abort_callback(self._params, abort_callback) # run inference start_time = time() @@ -191,7 +265,7 @@ def _get_segments(ctx, start: int, end: int, extract_probability: bool = False) else: avg_prob = np.nan - res.append(Segment(t0, t1, text.strip(), probability=np.float32(avg_prob))) + res.append(Segment(t0, t1, text.strip(), probability=float(avg_prob))) return res def get_params(self) -> dict: @@ -246,7 +320,7 @@ def system_info() -> None: return pw.whisper_print_system_info() @staticmethod - def available_languages() -> list[str]: + def available_languages() -> List[str]: """ Returns a list of supported language codes @@ -258,6 +332,49 @@ def available_languages() -> list[str]: res.append(pw.whisper_lang_str(i)) return res + @staticmethod + def _resolve_context_params(context_params: Optional[ContextParams]): + resolved = pw.whisper_context_default_params() + + if context_params is None: + return resolved + + if not isinstance(context_params, dict): + raise TypeError("context_params must be a ContextParams dict or None") + + unknown_keys = sorted(set(context_params) - _CONTEXT_PARAM_KEYS) + if unknown_keys: + raise TypeError( + f"Unknown context_params keys: {', '.join(unknown_keys)}" + ) + + for key, value in context_params.items(): + setattr(resolved, key, value) + return resolved + + @staticmethod + def _normalize_params(kwargs: dict) -> dict: + normalized = dict(kwargs) + + if 'suppress_non_speech_tokens' in normalized and 'suppress_nst' not in normalized: + normalized['suppress_nst'] = normalized.pop('suppress_non_speech_tokens') + + return normalized + + def _apply_prompt_token_params(self, normalized: dict) -> dict: + if 'prompt_tokens' not in normalized: + return normalized + + prompt_tokens = normalized.pop('prompt_tokens') + normalized.pop('prompt_n_tokens', None) + + if prompt_tokens is None: + self._params.clear_prompt_tokens() + else: + self._params.set_prompt_tokens(prompt_tokens) + + return normalized + def _init_model(self) -> None: """ Private method to initialize the method from the bindings, it will be called automatically from the __init__ @@ -265,7 +382,7 @@ def _init_model(self) -> None: """ logger.info("Initializing the model ...") with utils.redirect_stderr(to=self.redirect_whispercpp_logs_to): - self._ctx = pw.whisper_init_from_file(self.model_path) + self._ctx = pw.whisper_init_from_file_with_params(self.model_path, self._context_params) if self.use_openvino: pw.whisper_ctx_init_openvino_encoder(self._ctx, self.openvino_model_path, self.openvino_device, self.openvino_cache_dir) @@ -277,10 +394,15 @@ def _set_params(self, kwargs: dict) -> None: :param kwargs: dict like object for the different params :return: None """ - for param in kwargs: - setattr(self._params, param, kwargs[param]) + normalized = self._normalize_params(kwargs) + + if 'prompt_tokens' in normalized: + normalized = self._apply_prompt_token_params(normalized) - def _transcribe(self, audio: np.ndarray, n_processors: int = None): + for param, value in normalized.items(): + setattr(self._params, param, value) + + def _transcribe(self, audio: np.ndarray, n_processors: Optional[int] = None): """ Private method to call the whisper.cpp/whisper_full function @@ -297,8 +419,8 @@ def _transcribe(self, audio: np.ndarray, n_processors: int = None): res = Model._get_segments(self._ctx, 0, n, self.extract_probability) return res - @staticmethod - def __call_new_segment_callback(ctx, n_new, user_data) -> None: + + def __call_new_segment_callback(self, ctx, n_new, user_data=None) -> None: """ Internal new_segment_callback, it just calls the user's callback with the `Segment` object :param ctx: whisper.cpp ctx param @@ -310,10 +432,11 @@ def __call_new_segment_callback(ctx, n_new, user_data) -> None: start = n - n_new res = Model._get_segments(ctx, start, n, False) for segment in res: - Model._new_segment_callback(segment) + if self._new_segment_callback is not None: + self._new_segment_callback(segment) @staticmethod - def _load_audio(media_file_path: str) -> np.array: + def _load_audio(media_file_path: str) -> np.ndarray: """ Helper method to return a `np.array` object from a media file If the media file is not a WAV file, it will try to convert it using ffmpeg @@ -369,33 +492,40 @@ def wav_to_np(file_path): finally: os.remove(temp_file_path) - def auto_detect_language(self, media: Union[str, np.ndarray], offset_ms: int = 0, n_threads: int = 4) -> Tuple[Tuple[str, np.float32], dict[str, np.float32]]: + def auto_detect_language(self, media: Union[str, np.ndarray], offset_ms: Optional[int] = None, n_threads: Optional[int] = None) -> Tuple[Tuple[str, np.float32], Dict[str, np.float32]]: """ Automatic language detection using whisper.cpp/whisper_pcm_to_mel and whisper.cpp/whisper_lang_auto_detect :param media: Media file path or a numpy array - :param offset_ms: offset in milliseconds - :param n_threads: number of threads to use + :param offset_ms: offset in milliseconds; when omitted, uses the model's current `offset_ms` + :param n_threads: number of threads to use; when omitted, uses the model's current `n_threads` :return: ((detected_language, probability), probabilities for all languages) """ - if type(media) is np.ndarray: + if isinstance(media, np.ndarray): audio = media else: if not Path(media).exists(): raise FileNotFoundError(media) audio = self._load_audio(media) + if offset_ms is None: + offset_ms = self._params.offset_ms + + if n_threads is None: + n_threads = self._params.n_threads + pw.whisper_pcm_to_mel(self._ctx, audio, len(audio), n_threads) lang_count = self.lang_max_id() + 1 probs = np.zeros(lang_count, dtype=np.float32) auto_detect = pw.whisper_lang_auto_detect(self._ctx, offset_ms, n_threads, probs) langs = self.available_languages() lang_probs = {langs[i]: probs[i] for i in range(lang_count)} - return (langs[auto_detect], probs[auto_detect]), lang_probs + return (langs[auto_detect], np.float32(probs[auto_detect])), lang_probs def __del__(self): """ Free up resources :return: None """ - pw.whisper_free(self._ctx) \ No newline at end of file + if self._ctx is not None: + pw.whisper_free(self._ctx) \ No newline at end of file diff --git a/pywhispercpp/model.pyi b/pywhispercpp/model.pyi index 7e5b43d..35cb735 100644 --- a/pywhispercpp/model.pyi +++ b/pywhispercpp/model.pyi @@ -1,12 +1,21 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, TypedDict, Union +from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, TypedDict, TypeAlias, Union import numpy as np import numpy.typing as npt -AudioArray = npt.NDArray[np.float32] -AudioInput = Union[str, AudioArray] +AudioArray: TypeAlias = npt.NDArray[np.float32] +AudioInput: TypeAlias = Union[str, AudioArray] + +class ContextParams(TypedDict, total=False): + use_gpu: bool + flash_attn: bool + gpu_device: int + dtw_token_timestamps: bool + dtw_aheads_preset: int + dtw_n_top: int + dtw_mem_size: int class GreedyParams(TypedDict): @@ -30,6 +39,7 @@ class Segment: class Model: + model_path: str _new_segment_callback: Optional[Callable[[Segment], None]] def __init__( @@ -42,13 +52,15 @@ class Model: openvino_model_path: Optional[str] = None, openvino_device: str = 'CPU', openvino_cache_dir: Optional[str] = None, + context_params: Optional[ContextParams] = None, *, n_threads: Optional[int] = None, n_max_text_ctx: int = 16384, offset_ms: int = 0, duration_ms: int = 0, translate: bool = False, - no_context: bool = False, + no_context: bool = True, + no_timestamps: bool = False, single_segment: bool = False, print_special: bool = False, print_progress: bool = True, @@ -60,13 +72,19 @@ class Model: max_len: int = 0, split_on_word: bool = False, max_tokens: int = 0, + debug_mode: bool = False, audio_ctx: int = 0, + tdrz_enable: bool = False, initial_prompt: Optional[str] = None, prompt_tokens: Optional[Tuple[Any, ...]] = None, prompt_n_tokens: int = 0, + carry_initial_prompt: bool = False, language: str = '', + detect_language: bool = False, suppress_blank: bool = True, suppress_non_speech_tokens: bool = False, + suppress_nst: bool = False, + suppress_regex: str = '', temperature: float = 0.0, max_initial_ts: float = 1.0, length_penalty: float = -1.0, @@ -74,7 +92,7 @@ class Model: entropy_thold: float = 2.4, logprob_thold: float = -1.0, no_speech_thold: float = 0.6, - greedy: GreedyParams = {'best_of': -1}, + greedy: GreedyParams = {'best_of': 5}, beam_search: BeamSearchParams = {'beam_size': -1, 'patience': -1.0}, vad: bool = False, vad_model_path: Optional[str] = None, @@ -86,13 +104,15 @@ class Model: media: AudioInput, n_processors: Optional[int] = None, new_segment_callback: Optional[Callable[[Segment], None]] = None, + abort_callback: Optional[Callable[[], bool]] = None, *, n_threads: Optional[int] = None, n_max_text_ctx: int = 16384, offset_ms: int = 0, duration_ms: int = 0, translate: bool = False, - no_context: bool = False, + no_context: bool = True, + no_timestamps: bool = False, single_segment: bool = False, print_special: bool = False, print_progress: bool = True, @@ -104,13 +124,19 @@ class Model: max_len: int = 0, split_on_word: bool = False, max_tokens: int = 0, + debug_mode: bool = False, audio_ctx: int = 0, + tdrz_enable: bool = False, initial_prompt: Optional[str] = None, prompt_tokens: Optional[Tuple[Any, ...]] = None, prompt_n_tokens: int = 0, + carry_initial_prompt: bool = False, language: str = '', + detect_language: bool = False, suppress_blank: bool = True, suppress_non_speech_tokens: bool = False, + suppress_nst: bool = False, + suppress_regex: str = '', temperature: float = 0.0, max_initial_ts: float = 1.0, length_penalty: float = -1.0, @@ -118,7 +144,7 @@ class Model: entropy_thold: float = 2.4, logprob_thold: float = -1.0, no_speech_thold: float = 0.6, - greedy: GreedyParams = {'best_of': -1}, + greedy: GreedyParams = {'best_of': 5}, beam_search: BeamSearchParams = {'beam_size': -1, 'patience': -1.0}, extract_probability: bool = False, vad: bool = False, @@ -141,8 +167,8 @@ class Model: def auto_detect_language( self, media: AudioInput, - offset_ms: int = 0, - n_threads: int = 4, + offset_ms: Optional[int] = None, + n_threads: Optional[int] = None, ) -> Tuple[Tuple[str, np.float32], Dict[str, np.float32]]: ... def __del__(self) -> None: ... diff --git a/src/main.cpp b/src/main.cpp index 48341bb..6bc3c00 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -28,10 +28,12 @@ namespace py = pybind11; using namespace pybind11::literals; // to bring in the `_a` literal +inline bool has_python_user_data(const py::object & obj) { + return obj.ptr() != nullptr && obj.ptr() != Py_None; +} + -py::function py_new_segment_callback; -py::function py_encoder_begin_callback; -py::function py_logits_filter_callback; +py::object py_log_callback; // whisper context wrapper, to solve the incomplete type issue @@ -56,24 +58,28 @@ struct whisper_model_loader_wrapper { }; -struct whisper_context_wrapper whisper_init_from_file_wrapper(const char * path_model){ - struct whisper_context_params cparams = whisper_context_default_params(); +struct whisper_context_wrapper whisper_init_from_file_with_params_wrapper( + const char * path_model, + struct whisper_context_params cparams){ struct whisper_context * ctx = whisper_init_from_file_with_params(path_model, cparams); struct whisper_context_wrapper ctw_w; ctw_w.ptr = ctx; return ctw_w; } -struct whisper_context_wrapper whisper_init_from_buffer_wrapper(void * buffer, size_t buffer_size){ - struct whisper_context_params cparams = whisper_context_default_params(); +struct whisper_context_wrapper whisper_init_from_buffer_with_params_wrapper( + void * buffer, + size_t buffer_size, + struct whisper_context_params cparams){ struct whisper_context * ctx = whisper_init_from_buffer_with_params(buffer, buffer_size, cparams); struct whisper_context_wrapper ctw_w; ctw_w.ptr = ctx; return ctw_w; } -struct whisper_context_wrapper whisper_init_wrapper(struct whisper_model_loader_wrapper * loader){ - struct whisper_context_params cparams = whisper_context_default_params(); +struct whisper_context_wrapper whisper_init_with_params_wrapper( + struct whisper_model_loader_wrapper * loader, + struct whisper_context_params cparams){ struct whisper_context * ctx = whisper_init_with_params(loader->ptr, cparams); struct whisper_context_wrapper ctw_w; ctw_w.ptr = ctx; @@ -291,6 +297,69 @@ float whisper_full_get_token_p_wrapper(struct whisper_context_wrapper * ctx, int return whisper_full_get_token_p(ctx->ptr, i_segment, i_token); } +bool whisper_full_get_segment_speaker_turn_next_wrapper(struct whisper_context_wrapper * ctx, int i_segment){ + return whisper_full_get_segment_speaker_turn_next(ctx->ptr, i_segment); +} + +const char * whisper_model_type_readable_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_type_readable(ctx_w->ptr); +} + +int whisper_model_n_vocab_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_vocab(ctx_w->ptr); +} + +int whisper_model_n_audio_ctx_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_audio_ctx(ctx_w->ptr); +} + +int whisper_model_n_audio_state_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_audio_state(ctx_w->ptr); +} + +int whisper_model_n_audio_head_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_audio_head(ctx_w->ptr); +} + +int whisper_model_n_audio_layer_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_audio_layer(ctx_w->ptr); +} + +int whisper_model_n_text_ctx_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_text_ctx(ctx_w->ptr); +} + +int whisper_model_n_text_state_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_text_state(ctx_w->ptr); +} + +int whisper_model_n_text_head_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_text_head(ctx_w->ptr); +} + +int whisper_model_n_text_layer_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_text_layer(ctx_w->ptr); +} + +int whisper_model_n_mels_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_n_mels(ctx_w->ptr); +} + +int whisper_model_ftype_wrapper(struct whisper_context_wrapper * ctx_w){ + return whisper_model_ftype(ctx_w->ptr); +} + +bool _abort_callback(void * user_data); +void _new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); +bool _encoder_begin_callback(struct whisper_context * ctx, struct whisper_state * state, void * user_data); +void _logits_filter_callback( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data); + int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * ctx, const char * model_path, const char * device, const char * cache_dir){ @@ -301,57 +370,96 @@ struct WhisperFullParamsWrapper : public whisper_full_params { std::string initial_prompt_str; std::string suppress_regex_str; std::string vad_model_path_str; + std::vector prompt_token_storage; + + void reset_progress_callback() { + progress_callback_user_data = this; + progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { + (void) ctx; + (void) state; + auto* self = static_cast(user_data); + if (self && self->print_progress) { + if (self->py_progress_callback) { + py::gil_scoped_acquire gil; + if (!has_python_user_data(self->py_progress_callback_user_data)) { + self->py_progress_callback(progress); + } else { + self->py_progress_callback(progress, self->py_progress_callback_user_data); + } + } else { + fprintf(stderr, "Progress: %3d%%\n", progress); + } + } + }; + } + + void sync_prompt_tokens() { + prompt_tokens = prompt_token_storage.empty() ? nullptr : prompt_token_storage.data(); + prompt_n_tokens = prompt_token_storage.size(); + } public: + py::function py_new_segment_callback; + py::object py_new_segment_callback_user_data; + py::function py_encoder_begin_callback; + py::object py_encoder_begin_callback_user_data; py::function py_progress_callback; + py::object py_progress_callback_user_data; + py::function py_logits_filter_callback; + py::object py_logits_filter_callback_user_data; + py::function py_abort_callback; + py::object py_abort_callback_user_data; WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params()) : whisper_full_params(params), initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""), suppress_regex_str(params.suppress_regex ? params.suppress_regex : ""), - vad_model_path_str(params.vad_model_path ? params.vad_model_path : "") + vad_model_path_str(params.vad_model_path ? params.vad_model_path : ""), + prompt_token_storage(), + py_new_segment_callback_user_data(py::none()), + py_encoder_begin_callback_user_data(py::none()), + py_progress_callback_user_data(py::none()), + py_logits_filter_callback_user_data(py::none()), + py_abort_callback(), + py_abort_callback_user_data(py::none()) { initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str(); - // progress callback - progress_callback_user_data = this; - progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { - auto* self = static_cast(user_data); - if(self && self->print_progress){ - if (self->py_progress_callback) { - // call the python callback - py::gil_scoped_acquire gil; - self->py_progress_callback(progress); // Call Python callback - } - else { - fprintf(stderr, "Progress: %3d%%\n", progress); - } // Default message - } - } ; + new_segment_callback_user_data = this; + encoder_begin_callback_user_data = this; + abort_callback_user_data = this; + logits_filter_callback_user_data = this; + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + prompt_token_storage.assign(params.prompt_tokens, params.prompt_tokens + params.prompt_n_tokens); + } + sync_prompt_tokens(); + reset_progress_callback(); } WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other) : whisper_full_params(static_cast(other)), // Copy base struct initial_prompt_str(other.initial_prompt_str), suppress_regex_str(other.suppress_regex_str), vad_model_path_str(other.vad_model_path_str), - py_progress_callback(other.py_progress_callback) { + prompt_token_storage(other.prompt_token_storage), + py_new_segment_callback(other.py_new_segment_callback), + py_new_segment_callback_user_data(other.py_new_segment_callback_user_data), + py_encoder_begin_callback(other.py_encoder_begin_callback), + py_encoder_begin_callback_user_data(other.py_encoder_begin_callback_user_data), + py_progress_callback(other.py_progress_callback), + py_progress_callback_user_data(other.py_progress_callback_user_data), + py_logits_filter_callback(other.py_logits_filter_callback), + py_logits_filter_callback_user_data(other.py_logits_filter_callback_user_data), + py_abort_callback(other.py_abort_callback), + py_abort_callback_user_data(other.py_abort_callback_user_data) { // Reset pointers to new string copies initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str(); - progress_callback_user_data = this; - progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { - auto* self = static_cast(user_data); - if(self && self->print_progress){ - if (self->py_progress_callback) { - // call the python callback - py::gil_scoped_acquire gil; - self->py_progress_callback(progress); // Call Python callback - } - else { - fprintf(stderr, "Progress: %3d%%\n", progress); - } // Default message - } - }; + new_segment_callback_user_data = this; + encoder_begin_callback_user_data = this; + abort_callback_user_data = this; + logits_filter_callback_user_data = this; + sync_prompt_tokens(); + reset_progress_callback(); } void set_initial_prompt(const std::string& prompt) { initial_prompt_str = prompt; @@ -365,6 +473,104 @@ struct WhisperFullParamsWrapper : public whisper_full_params { vad_model_path_str = model_path; vad_model_path = vad_model_path_str.c_str(); } + py::tuple get_prompt_tokens() const { + py::tuple tokens(prompt_token_storage.size()); + for (size_t index = 0; index < prompt_token_storage.size(); ++index) { + tokens[index] = prompt_token_storage[index]; + } + return tokens; + } + void set_prompt_tokens(const std::vector& tokens) { + prompt_token_storage = tokens; + sync_prompt_tokens(); + } + void clear_prompt_tokens() { + prompt_token_storage.clear(); + sync_prompt_tokens(); + } + py::object get_new_segment_callback_user_data() const { + return py_new_segment_callback_user_data; + } + void set_new_segment_callback_user_data(py::object user_data) { + py_new_segment_callback_user_data = std::move(user_data); + new_segment_callback_user_data = this; + } + void set_new_segment_callback(py::function callback) { + py_new_segment_callback = std::move(callback); + new_segment_callback_user_data = this; + new_segment_callback = _new_segment_callback; + } + void clear_new_segment_callback() { + py_new_segment_callback = py::function(); + new_segment_callback = nullptr; + new_segment_callback_user_data = this; + } + py::object get_encoder_begin_callback_user_data() const { + return py_encoder_begin_callback_user_data; + } + void set_encoder_begin_callback_user_data(py::object user_data) { + py_encoder_begin_callback_user_data = std::move(user_data); + encoder_begin_callback_user_data = this; + } + void set_encoder_begin_callback(py::function callback) { + py_encoder_begin_callback = std::move(callback); + encoder_begin_callback_user_data = this; + encoder_begin_callback = _encoder_begin_callback; + } + void clear_encoder_begin_callback() { + py_encoder_begin_callback = py::function(); + encoder_begin_callback = nullptr; + encoder_begin_callback_user_data = this; + } + py::object get_progress_callback_user_data() const { + return py_progress_callback_user_data; + } + void set_progress_callback_user_data(py::object user_data) { + py_progress_callback_user_data = std::move(user_data); + progress_callback_user_data = this; + } + void set_progress_callback(py::function callback) { + py_progress_callback = std::move(callback); + reset_progress_callback(); + } + void clear_progress_callback() { + py_progress_callback = py::function(); + reset_progress_callback(); + } + py::object get_logits_filter_callback_user_data() const { + return py_logits_filter_callback_user_data; + } + void set_logits_filter_callback_user_data(py::object user_data) { + py_logits_filter_callback_user_data = std::move(user_data); + logits_filter_callback_user_data = this; + } + void set_logits_filter_callback(py::function callback) { + py_logits_filter_callback = std::move(callback); + logits_filter_callback_user_data = this; + logits_filter_callback = _logits_filter_callback; + } + void clear_logits_filter_callback() { + py_logits_filter_callback = py::function(); + logits_filter_callback = nullptr; + logits_filter_callback_user_data = this; + } + py::object get_abort_callback_user_data() const { + return py_abort_callback_user_data; + } + void set_abort_callback_user_data(py::object user_data) { + py_abort_callback_user_data = std::move(user_data); + abort_callback_user_data = this; + } + void set_abort_callback(py::function callback) { + py_abort_callback = std::move(callback); + abort_callback_user_data = this; + abort_callback = _abort_callback; + } + void clear_abort_callback() { + py_abort_callback = py::function(); + abort_callback = nullptr; + abort_callback_user_data = this; + } }; WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) { return WhisperFullParamsWrapper(whisper_full_default_params(strategy)); @@ -373,30 +579,72 @@ WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampl // callbacks mechanism void _new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data){ + (void) state; struct whisper_context_wrapper ctx_w; ctx_w.ptr = ctx; - // call the python callback - py::gil_scoped_acquire gil; // Acquire the GIL while in this scope. - py_new_segment_callback(ctx_w, n_new, user_data); + auto * params = static_cast(user_data); + if (!params || !params->py_new_segment_callback) { + return; + } + + py::gil_scoped_acquire gil; + py::function callback = params->py_new_segment_callback; + if (!has_python_user_data(params->py_new_segment_callback_user_data)) { + callback(ctx_w, n_new); + } else { + callback(ctx_w, n_new, params->py_new_segment_callback_user_data); + } }; -void assign_new_segment_callback(struct whisper_full_params *params, py::function f){ - params->new_segment_callback = _new_segment_callback; - py_new_segment_callback = f; +void assign_new_segment_callback(struct whisper_full_params *params_base, py::object callback){ + auto * params = static_cast(params_base); + if (callback.is_none()) { + params->clear_new_segment_callback(); + return; + } + + params->set_new_segment_callback(callback.cast()); +} + +void clear_new_segment_callback(struct whisper_full_params *params_base) { + auto * params = static_cast(params_base); + params->clear_new_segment_callback(); }; bool _encoder_begin_callback(struct whisper_context * ctx, struct whisper_state * state, void * user_data){ + (void) state; struct whisper_context_wrapper ctx_w; ctx_w.ptr = ctx; - // call the python callback - py::object result_py = py_encoder_begin_callback(ctx_w, user_data); + auto * params = static_cast(user_data); + if (!params || !params->py_encoder_begin_callback) { + return false; + } + + py::gil_scoped_acquire gil; + py::function callback = params->py_encoder_begin_callback; + py::object result_py; + if (!has_python_user_data(params->py_encoder_begin_callback_user_data)) { + result_py = callback(ctx_w); + } else { + result_py = callback(ctx_w, params->py_encoder_begin_callback_user_data); + } bool res = result_py.cast(); return res; } -void assign_encoder_begin_callback(struct whisper_full_params *params, py::function f){ - params->encoder_begin_callback = _encoder_begin_callback; - py_encoder_begin_callback = f; +void assign_encoder_begin_callback(struct whisper_full_params *params_base, py::object callback){ + auto * params = static_cast(params_base); + if (callback.is_none()) { + params->clear_encoder_begin_callback(); + return; + } + + params->set_encoder_begin_callback(callback.cast()); +} + +void clear_encoder_begin_callback(struct whisper_full_params *params_base) { + auto * params = static_cast(params_base); + params->clear_encoder_begin_callback(); } void _logits_filter_callback( @@ -406,15 +654,102 @@ void _logits_filter_callback( int n_tokens, float * logits, void * user_data){ + (void) state; + (void) tokens; struct whisper_context_wrapper ctx_w; ctx_w.ptr = ctx; - // call the python callback - py_logits_filter_callback(ctx_w, n_tokens, logits, user_data); + auto * params = static_cast(user_data); + if (!params || !params->py_logits_filter_callback) { + return; + } + + py::gil_scoped_acquire gil; + py::function callback = params->py_logits_filter_callback; + if (!has_python_user_data(params->py_logits_filter_callback_user_data)) { + callback(ctx_w, n_tokens, logits); + } else { + callback(ctx_w, n_tokens, logits, params->py_logits_filter_callback_user_data); + } +} + +void assign_logits_filter_callback(struct whisper_full_params *params_base, py::object callback){ + auto * params = static_cast(params_base); + if (callback.is_none()) { + params->clear_logits_filter_callback(); + return; + } + + params->set_logits_filter_callback(callback.cast()); +} + +void clear_logits_filter_callback(struct whisper_full_params *params_base) { + auto * params = static_cast(params_base); + params->clear_logits_filter_callback(); +} + +void assign_progress_callback(whisper_full_params *params_base, py::object callback) { + auto * params = static_cast(params_base); + if (callback.is_none()) { + params->clear_progress_callback(); + return; + } + + params->set_progress_callback(callback.cast()); +} + +void clear_progress_callback(whisper_full_params *params_base) { + auto * params = static_cast(params_base); + params->clear_progress_callback(); +} + +bool _abort_callback(void * user_data) { + auto * params = static_cast(user_data); + if (!params || !params->py_abort_callback) { + return false; + } + + py::gil_scoped_acquire gil; + py::function callback = params->py_abort_callback; + py::object result_py; + if (!has_python_user_data(params->py_abort_callback_user_data)) { + result_py = callback(); + } else { + result_py = callback(params->py_abort_callback_user_data); + } + return result_py.cast(); } -void assign_logits_filter_callback(struct whisper_full_params *params, py::function f){ - params->logits_filter_callback = _logits_filter_callback; - py_logits_filter_callback = f; +void assign_abort_callback(whisper_full_params *params_base, py::object callback){ + auto * params = static_cast(params_base); + if (callback.is_none()) { + params->clear_abort_callback(); + return; + } + + params->set_abort_callback(callback.cast()); +} + +void clear_abort_callback(whisper_full_params *params_base) { + auto * params = static_cast(params_base); + params->clear_abort_callback(); +} + +void whisper_log_set_wrapper(py::object callback) { + if (callback.is_none()) { + py_log_callback = py::none(); + whisper_log_set(nullptr, nullptr); + return; + } + + py_log_callback = callback.cast(); + whisper_log_set( + [](enum ggml_log_level level, const char * text, void * user_data) { + (void) user_data; + py::gil_scoped_acquire gil; + py::function log_callback = py_log_callback.cast(); + log_callback(py::int_(static_cast(level)), py::str(text ? text : "")); + }, + nullptr); } py::dict get_greedy(whisper_full_params * params){ @@ -532,7 +867,34 @@ PYBIND11_MODULE(_pywhispercpp, m) { m.attr("WHISPER_HOP_LENGTH") = WHISPER_HOP_LENGTH; m.attr("WHISPER_CHUNK_SIZE") = WHISPER_CHUNK_SIZE; + py::enum_(m, "whisper_alignment_heads_preset") + .value("WHISPER_AHEADS_NONE", whisper_alignment_heads_preset::WHISPER_AHEADS_NONE) + .value("WHISPER_AHEADS_N_TOP_MOST", whisper_alignment_heads_preset::WHISPER_AHEADS_N_TOP_MOST) + .value("WHISPER_AHEADS_CUSTOM", whisper_alignment_heads_preset::WHISPER_AHEADS_CUSTOM) + .value("WHISPER_AHEADS_TINY_EN", whisper_alignment_heads_preset::WHISPER_AHEADS_TINY_EN) + .value("WHISPER_AHEADS_TINY", whisper_alignment_heads_preset::WHISPER_AHEADS_TINY) + .value("WHISPER_AHEADS_BASE_EN", whisper_alignment_heads_preset::WHISPER_AHEADS_BASE_EN) + .value("WHISPER_AHEADS_BASE", whisper_alignment_heads_preset::WHISPER_AHEADS_BASE) + .value("WHISPER_AHEADS_SMALL_EN", whisper_alignment_heads_preset::WHISPER_AHEADS_SMALL_EN) + .value("WHISPER_AHEADS_SMALL", whisper_alignment_heads_preset::WHISPER_AHEADS_SMALL) + .value("WHISPER_AHEADS_MEDIUM_EN", whisper_alignment_heads_preset::WHISPER_AHEADS_MEDIUM_EN) + .value("WHISPER_AHEADS_MEDIUM", whisper_alignment_heads_preset::WHISPER_AHEADS_MEDIUM) + .value("WHISPER_AHEADS_LARGE_V1", whisper_alignment_heads_preset::WHISPER_AHEADS_LARGE_V1) + .value("WHISPER_AHEADS_LARGE_V2", whisper_alignment_heads_preset::WHISPER_AHEADS_LARGE_V2) + .value("WHISPER_AHEADS_LARGE_V3", whisper_alignment_heads_preset::WHISPER_AHEADS_LARGE_V3) + .value("WHISPER_AHEADS_LARGE_V3_TURBO", whisper_alignment_heads_preset::WHISPER_AHEADS_LARGE_V3_TURBO) + .export_values(); + py::class_(m, "whisper_context"); + py::class_(m, "whisper_context_params") + .def(py::init<>()) + .def_readwrite("use_gpu", &whisper_context_params::use_gpu) + .def_readwrite("flash_attn", &whisper_context_params::flash_attn) + .def_readwrite("gpu_device", &whisper_context_params::gpu_device) + .def_readwrite("dtw_token_timestamps", &whisper_context_params::dtw_token_timestamps) + .def_readwrite("dtw_aheads_preset", &whisper_context_params::dtw_aheads_preset) + .def_readwrite("dtw_n_top", &whisper_context_params::dtw_n_top) + .def_readwrite("dtw_mem_size", &whisper_context_params::dtw_mem_size); py::class_(m, "whisper_token") .def(py::init<>()); py::class_(m,"whisper_token_data") @@ -545,20 +907,23 @@ PYBIND11_MODULE(_pywhispercpp, m) { .def_readwrite("ptsum", &whisper_token_data::ptsum) .def_readwrite("t0", &whisper_token_data::t0) .def_readwrite("t1", &whisper_token_data::t1) + .def_readwrite("t_dtw", &whisper_token_data::t_dtw) .def_readwrite("vlen", &whisper_token_data::vlen); py::class_(m,"whisper_model_loader") .def(py::init<>()); - DEF_RELEASE_GIL("whisper_init_from_file", &whisper_init_from_file_wrapper, "Various functions for loading a ggml whisper model.\n" - "Allocate (almost) all memory needed for the model.\n" - "Return NULL on failure"); - DEF_RELEASE_GIL("whisper_init_from_buffer", &whisper_init_from_buffer_wrapper, "Various functions for loading a ggml whisper model.\n" - "Allocate (almost) all memory needed for the model.\n" - "Return NULL on failure"); - DEF_RELEASE_GIL("whisper_init", &whisper_init_wrapper, "Various functions for loading a ggml whisper model.\n" - "Allocate (almost) all memory needed for the model.\n" - "Return NULL on failure"); + m.def("whisper_context_default_params", &whisper_context_default_params, + "Return the default context parameters used during model initialization."); + DEF_RELEASE_GIL("whisper_init_from_file_with_params", &whisper_init_from_file_with_params_wrapper, "Various functions for loading a ggml whisper model.\n" + "Allocate (almost) all memory needed for the model.\n" + "Return NULL on failure"); + DEF_RELEASE_GIL("whisper_init_from_buffer_with_params", &whisper_init_from_buffer_with_params_wrapper, "Various functions for loading a ggml whisper model.\n" + "Allocate (almost) all memory needed for the model.\n" + "Return NULL on failure"); + DEF_RELEASE_GIL("whisper_init_with_params", &whisper_init_with_params_wrapper, "Various functions for loading a ggml whisper model.\n" + "Allocate (almost) all memory needed for the model.\n" + "Return NULL on failure"); m.def("whisper_free", &whisper_free_wrapper, "Frees all memory allocated by the model."); @@ -694,11 +1059,7 @@ PYBIND11_MODULE(_pywhispercpp, m) { << "progress_callback=" << (self.progress_callback ? "(function pointer)" : "None") << ", " << "encoder_begin_callback=" << (self.encoder_begin_callback ? "(function pointer)" : "None") << ", " << "abort_callback=" << (self.abort_callback ? "(function pointer)" : "None") << ", " - << "logits_filter_callback=" << (self.logits_filter_callback ? "(function pointer)" : "None") << ", " - << "grammar_rules=" << (self.grammar_rules ? "(whisper_grammar_element **)" : "None") << ", " - << "n_grammar_rules=" << self.n_grammar_rules << ", " - << "i_start_rule=" << self.i_start_rule << ", " - << "grammar_penalty=" << self.grammar_penalty + << "logits_filter_callback=" << (self.logits_filter_callback ? "(function pointer)" : "None") << ")"; return oss.str(); }); @@ -712,10 +1073,23 @@ PYBIND11_MODULE(_pywhispercpp, m) { .def_readwrite("duration_ms", &WhisperFullParamsWrapper::duration_ms) .def_readwrite("translate", &WhisperFullParamsWrapper::translate) .def_readwrite("no_context", &WhisperFullParamsWrapper::no_context) + .def_readwrite("no_timestamps", &WhisperFullParamsWrapper::no_timestamps) .def_readwrite("single_segment", &WhisperFullParamsWrapper::single_segment) .def_readwrite("print_special", &WhisperFullParamsWrapper::print_special) .def_readwrite("print_progress", &WhisperFullParamsWrapper::print_progress) .def_readwrite("progress_callback", &WhisperFullParamsWrapper::py_progress_callback) + .def("set_progress_callback", + [](WhisperFullParamsWrapper &self, py::object callback) { + if (callback.is_none()) { + self.clear_progress_callback(); + } else { + self.set_progress_callback(callback.cast()); + } + }, + py::arg("callback") = py::none(), + "Assign a progress callback that receives progress updates.") + .def("clear_progress_callback", &WhisperFullParamsWrapper::clear_progress_callback, + "Clear any previously assigned progress callback while preserving default progress behavior.") .def_readwrite("print_realtime", &WhisperFullParamsWrapper::print_realtime) .def_readwrite("print_timestamps", &WhisperFullParamsWrapper::print_timestamps) .def_readwrite("token_timestamps", &WhisperFullParamsWrapper::token_timestamps) @@ -724,7 +1098,9 @@ PYBIND11_MODULE(_pywhispercpp, m) { .def_readwrite("max_len", &WhisperFullParamsWrapper::max_len) .def_readwrite("split_on_word", &WhisperFullParamsWrapper::split_on_word) .def_readwrite("max_tokens", &WhisperFullParamsWrapper::max_tokens) + .def_readwrite("debug_mode", &WhisperFullParamsWrapper::debug_mode) .def_readwrite("audio_ctx", &WhisperFullParamsWrapper::audio_ctx) + .def_readwrite("tdrz_enable", &WhisperFullParamsWrapper::tdrz_enable) .def_property("suppress_regex", [](WhisperFullParamsWrapper &self) { return py::str(self.suppress_regex ? self.suppress_regex : ""); @@ -740,8 +1116,60 @@ PYBIND11_MODULE(_pywhispercpp, m) { self.set_initial_prompt(initial_prompt); } ) - .def_readwrite("prompt_tokens", &WhisperFullParamsWrapper::prompt_tokens) + .def_property("prompt_tokens", + [](WhisperFullParamsWrapper &self) { + return self.get_prompt_tokens(); + }, + [](WhisperFullParamsWrapper &self, py::object tokens) { + if (tokens.is_none()) { + self.clear_prompt_tokens(); + } else { + self.set_prompt_tokens(tokens.cast>()); + } + }) + .def("set_prompt_tokens", &WhisperFullParamsWrapper::set_prompt_tokens, + py::arg("tokens"), + "Assign prompt tokens from a Python sequence.") + .def("clear_prompt_tokens", &WhisperFullParamsWrapper::clear_prompt_tokens, + "Clear any previously assigned prompt tokens.") + .def("set_new_segment_callback", + [](WhisperFullParamsWrapper &self, py::object callback) { + if (callback.is_none()) { + self.clear_new_segment_callback(); + } else { + self.set_new_segment_callback(callback.cast()); + } + }, + py::arg("callback") = py::none(), + "Assign a new-segment callback.") + .def("clear_new_segment_callback", &WhisperFullParamsWrapper::clear_new_segment_callback, + "Clear any previously assigned new-segment callback.") + .def("set_encoder_begin_callback", + [](WhisperFullParamsWrapper &self, py::object callback) { + if (callback.is_none()) { + self.clear_encoder_begin_callback(); + } else { + self.set_encoder_begin_callback(callback.cast()); + } + }, + py::arg("callback") = py::none(), + "Assign an encoder-begin callback.") + .def("clear_encoder_begin_callback", &WhisperFullParamsWrapper::clear_encoder_begin_callback, + "Clear any previously assigned encoder-begin callback.") + .def("set_abort_callback", + [](WhisperFullParamsWrapper &self, py::object callback) { + if (callback.is_none()) { + self.clear_abort_callback(); + } else { + self.set_abort_callback(callback.cast()); + } + }, + py::arg("callback") = py::none(), + "Assign an abort callback that returns True to stop processing.") + .def("clear_abort_callback", &WhisperFullParamsWrapper::clear_abort_callback, + "Clear any previously assigned abort callback.") .def_readwrite("prompt_n_tokens", &WhisperFullParamsWrapper::prompt_n_tokens) + .def_readwrite("carry_initial_prompt", &WhisperFullParamsWrapper::carry_initial_prompt) .def_property("language", [](WhisperFullParamsWrapper &self) { return py::str(self.language); @@ -754,7 +1182,9 @@ PYBIND11_MODULE(_pywhispercpp, m) { self.language = ""; //defaults to auto-detect } }) + .def_readwrite("detect_language", &WhisperFullParamsWrapper::detect_language) .def_readwrite("suppress_blank", &WhisperFullParamsWrapper::suppress_blank) + .def_readwrite("suppress_nst", &WhisperFullParamsWrapper::suppress_nst) .def_readwrite("temperature", &WhisperFullParamsWrapper::temperature) .def_readwrite("max_initial_ts", &WhisperFullParamsWrapper::max_initial_ts) .def_readwrite("length_penalty", &WhisperFullParamsWrapper::length_penalty) @@ -767,9 +1197,33 @@ PYBIND11_MODULE(_pywhispercpp, m) { [](WhisperFullParamsWrapper &self, py::dict dict) {self.greedy.best_of = dict["best_of"].cast();}) .def_property("beam_search", [](WhisperFullParamsWrapper &self) {return py::dict("beam_size"_a=self.beam_search.beam_size, "patience"_a=self.beam_search.patience);}, [](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast(); self.beam_search.patience = dict["patience"].cast();}) - .def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data) - .def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data) - .def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data) + .def_property("new_segment_callback_user_data", + &WhisperFullParamsWrapper::get_new_segment_callback_user_data, + &WhisperFullParamsWrapper::set_new_segment_callback_user_data) + .def_property("progress_callback_user_data", + &WhisperFullParamsWrapper::get_progress_callback_user_data, + &WhisperFullParamsWrapper::set_progress_callback_user_data) + .def_property("encoder_begin_callback_user_data", + &WhisperFullParamsWrapper::get_encoder_begin_callback_user_data, + &WhisperFullParamsWrapper::set_encoder_begin_callback_user_data) + .def_property("abort_callback_user_data", + &WhisperFullParamsWrapper::get_abort_callback_user_data, + &WhisperFullParamsWrapper::set_abort_callback_user_data) + .def_property("logits_filter_callback_user_data", + &WhisperFullParamsWrapper::get_logits_filter_callback_user_data, + &WhisperFullParamsWrapper::set_logits_filter_callback_user_data) + .def("set_logits_filter_callback", + [](WhisperFullParamsWrapper &self, py::object callback) { + if (callback.is_none()) { + self.clear_logits_filter_callback(); + } else { + self.set_logits_filter_callback(callback.cast()); + } + }, + py::arg("callback") = py::none(), + "Assign a logits-filter callback.") + .def("clear_logits_filter_callback", &WhisperFullParamsWrapper::clear_logits_filter_callback, + "Clear any previously assigned logits-filter callback.") .def_readwrite("vad", &WhisperFullParamsWrapper::vad) .def_property("vad_model_path", [](WhisperFullParamsWrapper &self) { @@ -799,6 +1253,8 @@ PYBIND11_MODULE(_pywhispercpp, m) { m.def("whisper_full_lang_id", &whisper_full_lang_id_wrapper, "Language id associated with the current context"); m.def("whisper_full_get_segment_t0", &whisper_full_get_segment_t0_wrapper, "Get the start time of the specified segment"); m.def("whisper_full_get_segment_t1", &whisper_full_get_segment_t1_wrapper, "Get the end time of the specified segment"); + m.def("whisper_full_get_segment_speaker_turn_next", &whisper_full_get_segment_speaker_turn_next_wrapper, + "Get whether the next segment is predicted as a speaker turn."); m.def("whisper_full_get_segment_text", &whisper_full_get_segment_text_wrapper, "Get the text of the specified segment"); m.def("whisper_full_n_tokens", &whisper_full_n_tokens_wrapper, "Get number of tokens in the specified segment."); @@ -812,6 +1268,18 @@ PYBIND11_MODULE(_pywhispercpp, m) { m.def("whisper_full_get_token_p", &whisper_full_get_token_p_wrapper, "Get the probability of the specified token in the specified segment."); m.def("whisper_ctx_init_openvino_encoder", &whisper_ctx_init_openvino_encoder_wrapper, "Given a context, enable use of OpenVINO for encode inference."); + m.def("whisper_model_type_readable", &whisper_model_type_readable_wrapper, "Return the readable model type string."); + m.def("whisper_model_n_vocab", &whisper_model_n_vocab_wrapper, "Return the model vocabulary size."); + m.def("whisper_model_n_audio_ctx", &whisper_model_n_audio_ctx_wrapper, "Return the audio context size baked into the model."); + m.def("whisper_model_n_audio_state", &whisper_model_n_audio_state_wrapper, "Return the number of audio state units in the model."); + m.def("whisper_model_n_audio_head", &whisper_model_n_audio_head_wrapper, "Return the number of audio attention heads in the model."); + m.def("whisper_model_n_audio_layer", &whisper_model_n_audio_layer_wrapper, "Return the number of audio layers in the model."); + m.def("whisper_model_n_text_ctx", &whisper_model_n_text_ctx_wrapper, "Return the text context size baked into the model."); + m.def("whisper_model_n_text_state", &whisper_model_n_text_state_wrapper, "Return the number of text state units in the model."); + m.def("whisper_model_n_text_head", &whisper_model_n_text_head_wrapper, "Return the number of text attention heads in the model."); + m.def("whisper_model_n_text_layer", &whisper_model_n_text_layer_wrapper, "Return the number of text layers in the model."); + m.def("whisper_model_n_mels", &whisper_model_n_mels_wrapper, "Return the number of mel bins used by the model."); + m.def("whisper_model_ftype", &whisper_model_ftype_wrapper, "Return the model file type identifier."); //////////////////////////////////////////////////////////////////////////// @@ -823,14 +1291,66 @@ PYBIND11_MODULE(_pywhispercpp, m) { // Helper mechanism to set callbacks from python // The only difference from the C-Style API - m.def("assign_new_segment_callback", &assign_new_segment_callback, "Assigns a new_segment_callback, takes instance and a callable function with the same parameters which are defined in the interface", - py::arg("params"), py::arg("callback")); + m.def("assign_new_segment_callback", + [](whisper_full_params * params, py::object callback) { + assign_new_segment_callback(params, callback); + }, + "Assign a new-segment callback.", + py::arg("params"), py::arg("callback") = py::none()); - m.def("assign_encoder_begin_callback", &assign_encoder_begin_callback, "Assigns an encoder_begin_callback, takes instance and a callable function with the same parameters which are defined in the interface", - py::arg("params"), py::arg("callback")); + m.def("clear_new_segment_callback", &clear_new_segment_callback, + "Clear any previously assigned new-segment callback.", + py::arg("params")); - m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback, takes instance and a callable function with the same parameters which are defined in the interface", - py::arg("params"), py::arg("callback")); + m.def("assign_encoder_begin_callback", + [](whisper_full_params * params, py::object callback) { + assign_encoder_begin_callback(params, callback); + }, + "Assign an encoder-begin callback.", + py::arg("params"), py::arg("callback") = py::none()); + + m.def("clear_encoder_begin_callback", &clear_encoder_begin_callback, + "Clear any previously assigned encoder-begin callback.", + py::arg("params")); + + m.def("assign_logits_filter_callback", + [](whisper_full_params * params, py::object callback) { + assign_logits_filter_callback(params, callback); + }, + "Assign a logits-filter callback.", + py::arg("params"), py::arg("callback") = py::none()); + + m.def("clear_logits_filter_callback", &clear_logits_filter_callback, + "Clear any previously assigned logits-filter callback.", + py::arg("params")); + + m.def("assign_progress_callback", + [](whisper_full_params * params, py::object callback) { + assign_progress_callback(params, callback); + }, + "Assign a progress callback that receives progress updates.", + py::arg("params"), py::arg("callback") = py::none()); + + m.def("clear_progress_callback", &clear_progress_callback, + "Clear any previously assigned progress callback while preserving default progress behavior.", + py::arg("params")); + + m.def("assign_abort_callback", + [](whisper_full_params * params, py::object callback) { + assign_abort_callback(params, callback); + }, + "Assign an abort callback that returns True to stop processing.", + py::arg("params"), py::arg("callback") = py::none()); + + m.def("clear_abort_callback", &clear_abort_callback, "Clear any previously assigned abort callback.", + py::arg("params")); + + m.def("whisper_log_set", + [](py::object callback) { + whisper_log_set_wrapper(callback); + }, + "Assign a Python log callback or None to restore the default logger.", + py::arg("callback") = py::none()); // VAD py::class_(m,"whisper_vad_params") diff --git a/tests/test_c_api.py b/tests/test_c_api.py index b415395..6138c1c 100644 --- a/tests/test_c_api.py +++ b/tests/test_c_api.py @@ -11,7 +11,10 @@ class TestCAPI(TestCase): model_file = './whisper.cpp/models/for-tests-ggml-tiny.en.bin' def test_whisper_init_from_file(self): - ctx = pw.whisper_init_from_file(self.model_file) + ctx = pw.whisper_init_from_file_with_params( + self.model_file, + pw.whisper_context_default_params(), + ) self.assertIsInstance(ctx, pw.whisper_context) def test_whisper_lang_str(self): diff --git a/tests/test_model.py b/tests/test_model.py index f38200f..b68f8a6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -8,6 +8,7 @@ from pathlib import Path from unittest import TestCase +import _pywhispercpp as pw from pywhispercpp.model import Model, Segment if __name__ == '__main__': @@ -44,6 +45,44 @@ def test_auto_detect_language(self): detected_language, probs = self.model.auto_detect_language(str(self.audio_file)) return self.assertIsInstance(detected_language, tuple) and self.assertEqual(detected_language[0], 'en') + def test_context_params_dict_init(self): + model = Model( + "tiny", + models_dir=str(WHISPER_CPP_DIR/'models'), + context_params={"use_gpu": False, "flash_attn": False}, + ) + self.assertIsInstance(model, Model) + + def test_compat_alias_for_non_speech_tokens(self): + model = Model( + "tiny", + models_dir=str(WHISPER_CPP_DIR/'models'), + suppress_non_speech_tokens=True, + ) + self.assertTrue(model.get_params()["suppress_nst"]) + + def test_prompt_token_helper_exists(self): + params = pw.whisper_full_default_params( + pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY + ) + params.set_prompt_tokens((1, 2, 3)) + self.assertEqual(params.prompt_n_tokens, 3) + + def test_model_metadata_bindings(self): + self.assertIsInstance(pw.whisper_model_type_readable(self.model._ctx), str) + self.assertGreater(pw.whisper_model_n_vocab(self.model._ctx), 0) + self.assertGreater(pw.whisper_model_n_audio_ctx(self.model._ctx), 0) + self.assertGreater(pw.whisper_model_n_text_ctx(self.model._ctx), 0) + + def test_speaker_turn_accessor_smoke(self): + self.model.transcribe(str(self.audio_file)) + segment_count = pw.whisper_full_n_segments(self.model._ctx) + self.assertGreater(segment_count, 0) + self.assertIsInstance( + pw.whisper_full_get_segment_speaker_turn_next(self.model._ctx, 0), + bool, + ) + if __name__ == '__main__': unittest.main() diff --git a/whisper.cpp b/whisper.cpp index 4979e04..9386f23 160000 --- a/whisper.cpp +++ b/whisper.cpp @@ -1 +1 @@ -Subproject commit 4979e04f5dcaccb36057e059bbaed8a2f5288315 +Subproject commit 9386f239401074690479731c1e41683fbbeac557