From 9b3b562b1c56146628643df0f2aa63ecdca698c8 Mon Sep 17 00:00:00 2001 From: Kota-Maeda Date: Mon, 8 Jun 2026 19:55:16 +0900 Subject: [PATCH] fix(model): detect audio format from magic bytes instead of hardcoding .mp3 invoke_speech_to_text wrote incoming audio to a NamedTemporaryFile with a hardcoded .mp3 suffix, so non-mp3 content (m4a/AAC, wav, ogg, flac, webm) was labeled .mp3 and rejected by OpenAI/Azure Whisper with "Invalid file format". The model-invoke payload carries only raw bytes (no filename), so detect the container from the leading magic bytes and pick the matching suffix, falling back to .mp3 for unknown content (backward compatible). Adds unit tests for the detection helper and an end-to-end test asserting the dispatch labels the temp file by format and writes the full payload. --- src/dify_plugin/core/plugin_executor.py | 34 ++++++- tests/core/__init__.py | 0 tests/core/test_detect_audio_suffix.py | 47 +++++++++ tests/core/test_invoke_speech_to_text.py | 123 +++++++++++++++++++++++ 4 files changed, 202 insertions(+), 2 deletions(-) create mode 100644 tests/core/__init__.py create mode 100644 tests/core/test_detect_audio_suffix.py create mode 100644 tests/core/test_invoke_speech_to_text.py diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 55594a2b..4912e6a4 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -92,6 +92,34 @@ from dify_plugin.interfaces.tool import Tool +def _detect_audio_suffix(header: bytes) -> str: + """Guess a temp-file suffix from the leading magic bytes of an audio blob. + + The speech-to-text model-invoke payload carries only the raw audio bytes, + with no original filename or extension. OpenAI/Azure Whisper endpoints + determine the audio format from the multipart filename extension, so a + wrong extension makes them reject otherwise-supported formats. We sniff the + container from the header to label the temp file correctly. + + Returns: + The matching suffix (including the leading dot). MP3 content and any + unrecognized header fall through to ``.mp3``, preserving the previous + hardcoded behavior for those cases. + """ + suffix = ".mp3" + if header[:4] == b"RIFF" and header[8:12] == b"WAVE": + suffix = ".wav" + elif header[:4] == b"fLaC": + suffix = ".flac" + elif header[:4] == b"OggS": # covers oga / ogg-opus + suffix = ".ogg" + elif header[4:8] == b"ftyp": # covers m4a / mp4 (AAC) + suffix = ".m4a" + elif header[:4] == b"\x1a\x45\xdf\xa3": # EBML (webm / matroska) + suffix = ".webm" + return suffix + + class PluginExecutor: # noqa: PLR0904 def __init__(self, config: DifyPluginEnv, registration: PluginRegistration) -> None: self.config = config @@ -536,8 +564,10 @@ def invoke_speech_to_text( data.model_type, ) - with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp: - temp.write(binascii.unhexlify(data.file)) + audio_bytes = binascii.unhexlify(data.file) + suffix = _detect_audio_suffix(audio_bytes[:16]) + with tempfile.NamedTemporaryFile(suffix=suffix, mode="wb", delete=True) as temp: + temp.write(audio_bytes) temp.flush() with pathlib.Path(temp.name).open("rb") as f: diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/test_detect_audio_suffix.py b/tests/core/test_detect_audio_suffix.py new file mode 100644 index 00000000..30da0ca4 --- /dev/null +++ b/tests/core/test_detect_audio_suffix.py @@ -0,0 +1,47 @@ +"""Tests for audio format detection used by speech-to-text dispatch. + +These exercise ``_detect_audio_suffix`` with only the leading magic bytes of +each container, so no real audio files are required. +""" + +import pytest + +from dify_plugin.core.plugin_executor import _detect_audio_suffix # noqa: PLC2701 + + +@pytest.mark.parametrize( + ("header", "expected"), + [ + # WAV: "RIFF" .... "WAVE" + (b"RIFF\x24\x08\x00\x00WAVEfmt ", ".wav"), + # FLAC + (b"fLaC\x00\x00\x00\x22", ".flac"), + # Ogg (covers oga / ogg-opus) + (b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", ".ogg"), + # MP4/M4A (AAC): ftyp box at offset 4 + (b"\x00\x00\x00\x20ftypM4A ", ".m4a"), + (b"\x00\x00\x00\x18ftypmp42", ".m4a"), + # WebM / Matroska EBML magic + (b"\x1a\x45\xdf\xa3\x01\x00\x00\x00", ".webm"), + ], +) +def test_detect_audio_suffix_known_formats(header: bytes, expected: str) -> None: + assert _detect_audio_suffix(header) == expected + + +@pytest.mark.parametrize( + "header", + [ + b"", + b"\x00\x00\x00\x00", + b"not an audio header", + # "RIFF" but not "WAVE" (e.g. AVI) must not be misdetected as wav + b"RIFF\x24\x08\x00\x00AVI LIST", + # MP3 (ID3 tag and raw frame sync) intentionally falls through to .mp3, + # matching the previous hardcoded behavior. + b"ID3\x04\x00\x00\x00\x00\x00\x00", + b"\xff\xf3\x90\x64\x00\x00\x00\x00", + ], +) +def test_detect_audio_suffix_falls_back_to_mp3(header: bytes) -> None: + assert _detect_audio_suffix(header) == ".mp3" diff --git a/tests/core/test_invoke_speech_to_text.py b/tests/core/test_invoke_speech_to_text.py new file mode 100644 index 00000000..386374c0 --- /dev/null +++ b/tests/core/test_invoke_speech_to_text.py @@ -0,0 +1,123 @@ +"""End-to-end tests for ``PluginExecutor.invoke_speech_to_text``. + +These drive the real dispatch with a recording fake model to assert that the +temp file handed to the speech2text model is labeled with a suffix that matches +the audio container, and that the full payload (not just the sniffed header) is +written. +""" + +import binascii +import pathlib +from collections.abc import Mapping +from typing import IO + +import pytest + +from dify_plugin.config.config import DifyPluginEnv +from dify_plugin.core.entities.plugin.request import ModelInvokeSpeech2TextRequest +from dify_plugin.core.plugin_executor import PluginExecutor +from dify_plugin.core.runtime import Session +from dify_plugin.entities import I18nObject +from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelType +from dify_plugin.errors.model import InvokeError +from dify_plugin.interfaces.model.speech2text_model import Speech2TextModel + + +def _model_entity() -> AIModelEntity: + return AIModelEntity( + model="whisper", + label=I18nObject(en_us="whisper"), + model_type=ModelType.SPEECH2TEXT, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + parameter_rules=[], + ) + + +class RecordingSpeech2TextModel(Speech2TextModel): + """Captures the temp-file suffix and bytes the executor hands to the model.""" + + model_type = ModelType.SPEECH2TEXT + + def __init__(self) -> None: + super().__init__(model_schemas=[_model_entity()]) + self.captured_suffix: str | None = None + self.captured_bytes: bytes | None = None + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + del model, credentials + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return {} + + def _invoke( + self, + model: str, + credentials: dict, + file: IO[bytes], + user: str | None = None, + ) -> str: + del model, credentials, user + self.captured_suffix = pathlib.Path(file.name).suffix + self.captured_bytes = file.read() + return "transcribed" + + +class _Registration: + def __init__(self, model_instance: object) -> None: + self.model_instance = model_instance + + def get_model_instance(self, provider: str, model_type: ModelType) -> object: + del provider, model_type + return self.model_instance + + +def _request(audio: bytes) -> ModelInvokeSpeech2TextRequest: + return ModelInvokeSpeech2TextRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.SPEECH2TEXT, + model="whisper", + credentials={}, + file=binascii.hexlify(audio).decode("ascii"), + ) + + +# Real container headers padded with trailing bytes so the test also proves the +# entire payload is written, not just the sniffed 16-byte header. +WAV = b"RIFF\x24\x08\x00\x00WAVEfmt " + b"\x00" * 64 +M4A = b"\x00\x00\x00\x20ftypM4A " + b"\x11" * 64 +OGG = b"OggS\x00\x02\x00\x00\x00\x00\x00\x00" + b"\x22" * 64 +UNKNOWN = b"\x00\x01\x02\x03 not a known audio container " + b"\x33" * 32 + + +@pytest.mark.parametrize( + ("audio", "expected_suffix"), + [ + (WAV, ".wav"), + (M4A, ".m4a"), + (OGG, ".ogg"), + (UNKNOWN, ".mp3"), + ], +) +def test_invoke_speech_to_text_labels_temp_file_by_format( + audio: bytes, + expected_suffix: str, +) -> None: + model = RecordingSpeech2TextModel() + executor = PluginExecutor(DifyPluginEnv(), _Registration(model)) + + result = executor.invoke_speech_to_text(Session.empty_session(), _request(audio)) + + assert result == {"result": "transcribed"} + assert model.captured_suffix == expected_suffix + # The full payload reaches the model, not just the sniffed header. + assert model.captured_bytes == audio + + +def test_invoke_speech_to_text_rejects_non_speech2text_model() -> None: + executor = PluginExecutor(DifyPluginEnv(), _Registration(object())) + + with pytest.raises(ValueError, match="not found for provider"): + executor.invoke_speech_to_text(Session.empty_session(), _request(WAV))