Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/dify_plugin/core/plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,34 @@
from dify_plugin.interfaces.tool import Tool


def _detect_audio_suffix(header: bytes) -> str:
"""Guess a temp-file suffix from the leading magic bytes of an audio blob.

The speech-to-text model-invoke payload carries only the raw audio bytes,
with no original filename or extension. OpenAI/Azure Whisper endpoints
determine the audio format from the multipart filename extension, so a
wrong extension makes them reject otherwise-supported formats. We sniff the
container from the header to label the temp file correctly.

Returns:
The matching suffix (including the leading dot). MP3 content and any
unrecognized header fall through to ``.mp3``, preserving the previous
hardcoded behavior for those cases.
"""
suffix = ".mp3"
if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
suffix = ".wav"
elif header[:4] == b"fLaC":
suffix = ".flac"
elif header[:4] == b"OggS": # covers oga / ogg-opus
suffix = ".ogg"
elif header[4:8] == b"ftyp": # covers m4a / mp4 (AAC)
suffix = ".m4a"
elif header[:4] == b"\x1a\x45\xdf\xa3": # EBML (webm / matroska)
suffix = ".webm"
return suffix


class PluginExecutor: # noqa: PLR0904
def __init__(self, config: DifyPluginEnv, registration: PluginRegistration) -> None:
self.config = config
Expand Down Expand Up @@ -536,8 +564,10 @@ def invoke_speech_to_text(
data.model_type,
)

with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
temp.write(binascii.unhexlify(data.file))
audio_bytes = binascii.unhexlify(data.file)
suffix = _detect_audio_suffix(audio_bytes[:16])
with tempfile.NamedTemporaryFile(suffix=suffix, mode="wb", delete=True) as temp:
temp.write(audio_bytes)
temp.flush()
Comment on lines +567 to 571

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

On Windows platforms, attempting to open a file created by tempfile.NamedTemporaryFile a second time (via pathlib.Path(temp.name).open("rb")) while the temporary file object is still open will raise a PermissionError.

To ensure cross-platform compatibility (especially for Windows users/developers), we can use a custom context manager wrapper that leverages tempfile.TemporaryDirectory to manage the lifecycle of the temporary file, writing and closing the file immediately so it can be safely opened again.

        class WinSafeTempFile:
            def __init__(self, suffix: str) -> None:
                self._dir = tempfile.TemporaryDirectory()
                self.name = str(pathlib.Path(self._dir.name) / f"temp{suffix}")
            def __enter__(self) -> "WinSafeTempFile":
                return self
            def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                self._dir.cleanup()
            def write(self, b: bytes) -> None:
                pathlib.Path(self.name).write_bytes(b)
            def flush(self) -> None:
                pass

        audio_bytes = binascii.unhexlify(data.file)
        suffix = _detect_audio_suffix(audio_bytes[:16])
        with WinSafeTempFile(suffix=suffix) as temp:
            temp.write(audio_bytes)
            temp.flush()


with pathlib.Path(temp.name).open("rb") as f:
Expand Down
Empty file added tests/core/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/core/test_detect_audio_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Tests for audio format detection used by speech-to-text dispatch.

These exercise ``_detect_audio_suffix`` with only the leading magic bytes of
each container, so no real audio files are required.
"""

import pytest

from dify_plugin.core.plugin_executor import _detect_audio_suffix # noqa: PLC2701


@pytest.mark.parametrize(
("header", "expected"),
[
# WAV: "RIFF" .... "WAVE"
(b"RIFF\x24\x08\x00\x00WAVEfmt ", ".wav"),
# FLAC
(b"fLaC\x00\x00\x00\x22", ".flac"),
# Ogg (covers oga / ogg-opus)
(b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", ".ogg"),
# MP4/M4A (AAC): ftyp box at offset 4
(b"\x00\x00\x00\x20ftypM4A ", ".m4a"),
(b"\x00\x00\x00\x18ftypmp42", ".m4a"),
# WebM / Matroska EBML magic
(b"\x1a\x45\xdf\xa3\x01\x00\x00\x00", ".webm"),
],
)
def test_detect_audio_suffix_known_formats(header: bytes, expected: str) -> None:
assert _detect_audio_suffix(header) == expected


@pytest.mark.parametrize(
"header",
[
b"",
b"\x00\x00\x00\x00",
b"not an audio header",
# "RIFF" but not "WAVE" (e.g. AVI) must not be misdetected as wav
b"RIFF\x24\x08\x00\x00AVI LIST",
# MP3 (ID3 tag and raw frame sync) intentionally falls through to .mp3,
# matching the previous hardcoded behavior.
b"ID3\x04\x00\x00\x00\x00\x00\x00",
b"\xff\xf3\x90\x64\x00\x00\x00\x00",
],
)
def test_detect_audio_suffix_falls_back_to_mp3(header: bytes) -> None:
assert _detect_audio_suffix(header) == ".mp3"
123 changes: 123 additions & 0 deletions tests/core/test_invoke_speech_to_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""End-to-end tests for ``PluginExecutor.invoke_speech_to_text``.

These drive the real dispatch with a recording fake model to assert that the
temp file handed to the speech2text model is labeled with a suffix that matches
the audio container, and that the full payload (not just the sniffed header) is
written.
"""

import binascii
import pathlib
from collections.abc import Mapping
from typing import IO

import pytest

from dify_plugin.config.config import DifyPluginEnv
from dify_plugin.core.entities.plugin.request import ModelInvokeSpeech2TextRequest
from dify_plugin.core.plugin_executor import PluginExecutor
from dify_plugin.core.runtime import Session
from dify_plugin.entities import I18nObject
from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelType
from dify_plugin.errors.model import InvokeError
from dify_plugin.interfaces.model.speech2text_model import Speech2TextModel


def _model_entity() -> AIModelEntity:
return AIModelEntity(
model="whisper",
label=I18nObject(en_us="whisper"),
model_type=ModelType.SPEECH2TEXT,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
parameter_rules=[],
)


class RecordingSpeech2TextModel(Speech2TextModel):
"""Captures the temp-file suffix and bytes the executor hands to the model."""

model_type = ModelType.SPEECH2TEXT

def __init__(self) -> None:
super().__init__(model_schemas=[_model_entity()])
self.captured_suffix: str | None = None
self.captured_bytes: bytes | None = None

def validate_credentials(self, model: str, credentials: Mapping) -> None:
del model, credentials

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {}

def _invoke(
self,
model: str,
credentials: dict,
file: IO[bytes],
user: str | None = None,
) -> str:
del model, credentials, user
self.captured_suffix = pathlib.Path(file.name).suffix
self.captured_bytes = file.read()
return "transcribed"


class _Registration:
def __init__(self, model_instance: object) -> None:
self.model_instance = model_instance

def get_model_instance(self, provider: str, model_type: ModelType) -> object:
del provider, model_type
return self.model_instance


def _request(audio: bytes) -> ModelInvokeSpeech2TextRequest:
return ModelInvokeSpeech2TextRequest(
user_id="user-1",
provider="provider",
model_type=ModelType.SPEECH2TEXT,
model="whisper",
credentials={},
file=binascii.hexlify(audio).decode("ascii"),
)


# Real container headers padded with trailing bytes so the test also proves the
# entire payload is written, not just the sniffed 16-byte header.
WAV = b"RIFF\x24\x08\x00\x00WAVEfmt " + b"\x00" * 64
M4A = b"\x00\x00\x00\x20ftypM4A " + b"\x11" * 64
OGG = b"OggS\x00\x02\x00\x00\x00\x00\x00\x00" + b"\x22" * 64
UNKNOWN = b"\x00\x01\x02\x03 not a known audio container " + b"\x33" * 32


@pytest.mark.parametrize(
("audio", "expected_suffix"),
[
(WAV, ".wav"),
(M4A, ".m4a"),
(OGG, ".ogg"),
(UNKNOWN, ".mp3"),
],
)
def test_invoke_speech_to_text_labels_temp_file_by_format(
audio: bytes,
expected_suffix: str,
) -> None:
model = RecordingSpeech2TextModel()
executor = PluginExecutor(DifyPluginEnv(), _Registration(model))

result = executor.invoke_speech_to_text(Session.empty_session(), _request(audio))

assert result == {"result": "transcribed"}
assert model.captured_suffix == expected_suffix
# The full payload reaches the model, not just the sniffed header.
assert model.captured_bytes == audio


def test_invoke_speech_to_text_rejects_non_speech2text_model() -> None:
executor = PluginExecutor(DifyPluginEnv(), _Registration(object()))

with pytest.raises(ValueError, match="not found for provider"):
executor.invoke_speech_to_text(Session.empty_session(), _request(WAV))