-
Notifications
You must be signed in to change notification settings - Fork 140
fix(model): detect audio format from magic bytes instead of hardcoding .mp3 #348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Kota-Maeda
wants to merge
1
commit into
langgenius:main
Choose a base branch
from
Kota-Maeda:fix/speech2text-audio-format-detection
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+202
−2
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| """Tests for audio format detection used by speech-to-text dispatch. | ||
|
|
||
| These exercise ``_detect_audio_suffix`` with only the leading magic bytes of | ||
| each container, so no real audio files are required. | ||
| """ | ||
|
|
||
| import pytest | ||
|
|
||
| from dify_plugin.core.plugin_executor import _detect_audio_suffix # noqa: PLC2701 | ||
|
|
||
|
|
||
# (leading magic bytes, suffix the detector is expected to report).
_KNOWN_FORMAT_CASES = [
    # WAV: "RIFF" .... "WAVE"
    (b"RIFF\x24\x08\x00\x00WAVEfmt ", ".wav"),
    # FLAC
    (b"fLaC\x00\x00\x00\x22", ".flac"),
    # Ogg (covers oga / ogg-opus)
    (b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", ".ogg"),
    # MP4/M4A (AAC): ftyp box at offset 4
    (b"\x00\x00\x00\x20ftypM4A ", ".m4a"),
    (b"\x00\x00\x00\x18ftypmp42", ".m4a"),
    # WebM / Matroska EBML magic
    (b"\x1a\x45\xdf\xa3\x01\x00\x00\x00", ".webm"),
]


@pytest.mark.parametrize(("header", "expected"), _KNOWN_FORMAT_CASES)
def test_detect_audio_suffix_known_formats(header: bytes, expected: str) -> None:
    """Each recognized container header is mapped to its matching suffix."""
    detected = _detect_audio_suffix(header)
    assert detected == expected
|
|
||
|
|
||
# Headers for which the detector is expected to answer ".mp3".
_MP3_FALLBACK_HEADERS = [
    b"",
    b"\x00\x00\x00\x00",
    b"not an audio header",
    # "RIFF" but not "WAVE" (e.g. AVI) must not be misdetected as wav
    b"RIFF\x24\x08\x00\x00AVI LIST",
    # MP3 (ID3 tag and raw frame sync) intentionally falls through to .mp3,
    # matching the previous hardcoded behavior.
    b"ID3\x04\x00\x00\x00\x00\x00\x00",
    b"\xff\xf3\x90\x64\x00\x00\x00\x00",
]


@pytest.mark.parametrize("header", _MP3_FALLBACK_HEADERS)
def test_detect_audio_suffix_falls_back_to_mp3(header: bytes) -> None:
    """Empty, unrecognized, and MP3 headers all yield the ``.mp3`` suffix."""
    detected = _detect_audio_suffix(header)
    assert detected == ".mp3"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| """End-to-end tests for ``PluginExecutor.invoke_speech_to_text``. | ||
|
|
||
| These drive the real dispatch with a recording fake model to assert that the | ||
| temp file handed to the speech2text model is labeled with a suffix that matches | ||
| the audio container, and that the full payload (not just the sniffed header) is | ||
| written. | ||
| """ | ||
|
|
||
| import binascii | ||
| import pathlib | ||
| from collections.abc import Mapping | ||
| from typing import IO | ||
|
|
||
| import pytest | ||
|
|
||
| from dify_plugin.config.config import DifyPluginEnv | ||
| from dify_plugin.core.entities.plugin.request import ModelInvokeSpeech2TextRequest | ||
| from dify_plugin.core.plugin_executor import PluginExecutor | ||
| from dify_plugin.core.runtime import Session | ||
| from dify_plugin.entities import I18nObject | ||
| from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelType | ||
| from dify_plugin.errors.model import InvokeError | ||
| from dify_plugin.interfaces.model.speech2text_model import Speech2TextModel | ||
|
|
||
|
|
||
def _model_entity() -> AIModelEntity:
    """Build the predefined ``whisper`` speech2text schema used by the fake model."""
    model_name = "whisper"
    display_label = I18nObject(en_us=model_name)
    return AIModelEntity(
        model=model_name,
        label=display_label,
        model_type=ModelType.SPEECH2TEXT,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={},
        parameter_rules=[],
    )
|
|
||
|
|
||
class RecordingSpeech2TextModel(Speech2TextModel):
    """Fake speech2text model that records what is passed to ``_invoke``.

    After an invocation, ``captured_suffix`` holds the suffix of the file's
    name and ``captured_bytes`` holds everything read from the file.
    """

    model_type = ModelType.SPEECH2TEXT

    def __init__(self) -> None:
        schemas = [_model_entity()]
        super().__init__(model_schemas=schemas)
        # Both remain None until ``_invoke`` has been called.
        self.captured_suffix: str | None = None
        self.captured_bytes: bytes | None = None

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """Accept anything; both arguments are deliberately unused."""
        del model, credentials

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """Return an empty error mapping."""
        return {}

    def _invoke(
        self,
        model: str,
        credentials: dict,
        file: IO[bytes],
        user: str | None = None,
    ) -> str:
        """Record the file's name suffix and full contents, then return a fixed transcript."""
        del model, credentials, user
        file_path = pathlib.Path(file.name)
        self.captured_suffix = file_path.suffix
        self.captured_bytes = file.read()
        return "transcribed"
|
|
||
|
|
||
class _Registration:
    """Stand-in registration that always hands back one fixed model instance."""

    def __init__(self, model_instance: object) -> None:
        self.model_instance = model_instance

    def get_model_instance(self, provider: str, model_type: ModelType) -> object:
        """Return the stored instance regardless of ``provider`` and ``model_type``."""
        del provider, model_type
        stored = self.model_instance
        return stored
|
|
||
|
|
||
def _request(audio: bytes) -> ModelInvokeSpeech2TextRequest:
    """Wrap ``audio`` in a speech2text request, with the payload hex-encoded as ASCII text."""
    hex_payload = binascii.hexlify(audio).decode("ascii")
    return ModelInvokeSpeech2TextRequest(
        user_id="user-1",
        provider="provider",
        model_type=ModelType.SPEECH2TEXT,
        model="whisper",
        credentials={},
        file=hex_payload,
    )
|
|
||
|
|
||
# Container headers followed by filler bytes, so the test also proves the
# entire payload is written, not just the sniffed 16-byte header.
WAV = b"".join((b"RIFF\x24\x08\x00\x00WAVEfmt ", bytes(64)))
M4A = b"".join((b"\x00\x00\x00\x20ftypM4A ", b"\x11" * 64))
OGG = b"".join((b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", b"\x22" * 64))
UNKNOWN = b"".join((b"\x00\x01\x02\x03 not a known audio container ", b"\x33" * 32))
|
|
||
|
|
||
@pytest.mark.parametrize(
    ("audio", "expected_suffix"),
    [
        (WAV, ".wav"),
        (M4A, ".m4a"),
        (OGG, ".ogg"),
        (UNKNOWN, ".mp3"),
    ],
)
def test_invoke_speech_to_text_labels_temp_file_by_format(
    audio: bytes,
    expected_suffix: str,
) -> None:
    """The file given to the model has the expected suffix and the whole payload."""
    recorder = RecordingSpeech2TextModel()
    registration = _Registration(recorder)
    executor = PluginExecutor(DifyPluginEnv(), registration)

    outcome = executor.invoke_speech_to_text(Session.empty_session(), _request(audio))

    assert outcome == {"result": "transcribed"}
    assert recorder.captured_suffix == expected_suffix
    # The full payload reaches the model, not just the sniffed header.
    assert recorder.captured_bytes == audio
|
|
||
|
|
||
def test_invoke_speech_to_text_rejects_non_speech2text_model() -> None:
    """Dispatch raises ``ValueError`` when the registered instance is a plain object."""
    plain_object = object()
    registration = _Registration(plain_object)
    executor = PluginExecutor(DifyPluginEnv(), registration)

    with pytest.raises(ValueError, match="not found for provider"):
        executor.invoke_speech_to_text(Session.empty_session(), _request(WAV))
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On Windows platforms, attempting to open a file created by `tempfile.NamedTemporaryFile` a second time (via `pathlib.Path(temp.name).open("rb")`) while the temporary file object is still open will raise a `PermissionError`.

To ensure cross-platform compatibility (especially for Windows users/developers), we can use a custom context manager wrapper that leverages `tempfile.TemporaryDirectory` to manage the lifecycle of the temporary file, writing and closing the file immediately so it can be safely opened again.