diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 55594a2b..4912e6a4 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -92,6 +92,34 @@ from dify_plugin.interfaces.tool import Tool +def _detect_audio_suffix(header: bytes) -> str: + """Guess a temp-file suffix from the leading magic bytes of an audio blob. + + The speech-to-text model-invoke payload carries only the raw audio bytes, + with no original filename or extension. OpenAI/Azure Whisper endpoints + determine the audio format from the multipart filename extension, so a + wrong extension makes them reject otherwise-supported formats. We sniff the + container from the header to label the temp file correctly. + + Returns: + The matching suffix (including the leading dot). MP3 content and any + unrecognized header fall through to ``.mp3``, preserving the previous + hardcoded behavior for those cases. + """ + suffix = ".mp3" + if header[:4] == b"RIFF" and header[8:12] == b"WAVE": + suffix = ".wav" + elif header[:4] == b"fLaC": + suffix = ".flac" + elif header[:4] == b"OggS": # covers oga / ogg-opus + suffix = ".ogg" + elif header[4:8] == b"ftyp": # covers m4a / mp4 (AAC) + suffix = ".m4a" + elif header[:4] == b"\x1a\x45\xdf\xa3": # EBML (webm / matroska) + suffix = ".webm" + return suffix + + class PluginExecutor: # noqa: PLR0904 def __init__(self, config: DifyPluginEnv, registration: PluginRegistration) -> None: self.config = config @@ -536,8 +564,10 @@ def invoke_speech_to_text( data.model_type, ) - with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp: - temp.write(binascii.unhexlify(data.file)) + audio_bytes = binascii.unhexlify(data.file) + suffix = _detect_audio_suffix(audio_bytes[:16]) + with tempfile.NamedTemporaryFile(suffix=suffix, mode="wb", delete=True) as temp: + temp.write(audio_bytes) temp.flush() with pathlib.Path(temp.name).open("rb") as f: diff --git 
@pytest.mark.parametrize(
    ("header", "expected"),
    [
        # WAV: "RIFF" at offset 0 and "WAVE" at offset 8.
        (b"RIFF\x24\x08\x00\x00WAVEfmt ", ".wav"),
        # FLAC.
        (b"fLaC\x00\x00\x00\x22", ".flac"),
        # Ogg (the same magic covers oga / ogg-opus).
        (b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", ".ogg"),
        # MP4/M4A (AAC): the "ftyp" tag sits at offset 4, after a 4-byte prefix.
        (b"\x00\x00\x00\x20ftypM4A ", ".m4a"),
        (b"\x00\x00\x00\x18ftypmp42", ".m4a"),
        # WebM / Matroska: EBML magic.
        (b"\x1a\x45\xdf\xa3\x01\x00\x00\x00", ".webm"),
    ],
)
def test_detect_audio_suffix_known_formats(header: bytes, expected: str) -> None:
    """Every recognized container header maps to its own suffix."""
    detected = _detect_audio_suffix(header)
    assert detected == expected
def _model_entity() -> AIModelEntity:
    """Build the minimal predefined speech-to-text schema used by the fake model.

    Returns:
        An ``AIModelEntity`` named "whisper" with no model properties and no
        parameter rules.
    """
    model_name = "whisper"
    display_label = I18nObject(en_us="whisper")
    schema = AIModelEntity(
        model=model_name,
        label=display_label,
        model_type=ModelType.SPEECH2TEXT,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={},
        parameter_rules=[],
    )
    return schema
class _Registration:
    """Fake plugin registration that resolves every lookup to one fixed object."""

    def __init__(self, model_instance: object) -> None:
        """Remember the object that every model lookup should return."""
        self.model_instance = model_instance

    def get_model_instance(self, provider: str, model_type: ModelType) -> object:
        """Return the stored object, ignoring which provider/type was requested."""
        # The lookup key is irrelevant for this fake; drop both arguments
        # explicitly so they are not flagged as unused.
        del provider
        del model_type
        return self.model_instance
def test_invoke_speech_to_text_rejects_non_speech2text_model() -> None:
    """A registered instance that is not a speech2text model raises ValueError."""
    env = DifyPluginEnv()
    # A bare ``object()`` stands in for a model of the wrong kind.
    registration = _Registration(object())
    executor = PluginExecutor(env, registration)

    # Session and request are built inside the block, so the expectation
    # covers exactly the same statements as a single inline call would.
    with pytest.raises(ValueError, match="not found for provider"):
        executor.invoke_speech_to_text(Session.empty_session(), _request(WAV))