diff --git a/src/dify_plugin/core/entities/plugin/request.py b/src/dify_plugin/core/entities/plugin/request.py index aa5d7149..09583883 100644 --- a/src/dify_plugin/core/entities/plugin/request.py +++ b/src/dify_plugin/core/entities/plugin/request.py @@ -206,8 +206,10 @@ class ModelGetLLMNumTokens(PluginAccessModelRequest, PromptMessageMixin): class ModelInvokeTextEmbeddingRequest(PluginAccessModelRequest): action: ModelActions = ModelActions.InvokeTextEmbedding + model_type: ModelType = ModelType.TEXT_EMBEDDING texts: list[str] + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT class ModelInvokeMultimodalEmbeddingRequest(PluginAccessModelRequest): diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 1aa8668a..0b135a14 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -295,7 +295,8 @@ def invoke_text_embedding( data.model, data.credentials, data.texts, - data.user_id, + user=data.user_id, + input_type=data.input_type, ) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( diff --git a/tests/entities/plugin/test_model_request.py b/tests/entities/plugin/test_model_request.py new file mode 100644 index 00000000..de174026 --- /dev/null +++ b/tests/entities/plugin/test_model_request.py @@ -0,0 +1,33 @@ +from dify_plugin.core.entities.plugin.request import ModelActions, ModelInvokeTextEmbeddingRequest, PluginInvokeType +from dify_plugin.entities.model import EmbeddingInputType, ModelType + + +def test_text_embedding_request_accepts_input_type() -> None: + request = ModelInvokeTextEmbeddingRequest( + type=PluginInvokeType.Model, + action=ModelActions.InvokeTextEmbedding, + user_id="user-id", + provider="provider", + model_type=ModelType.TEXT_EMBEDDING, + model="embedding-model", + credentials={}, + texts=["query text"], + input_type=EmbeddingInputType.QUERY, + ) + + assert request.input_type == EmbeddingInputType.QUERY + + +def test_text_embedding_request_defaults_input_type_to_document() -> None: + request = ModelInvokeTextEmbeddingRequest( + type=PluginInvokeType.Model, + action=ModelActions.InvokeTextEmbedding, + user_id="user-id", + provider="provider", + model="embedding-model", + credentials={}, + texts=["document text"], + ) + + assert request.model_type == ModelType.TEXT_EMBEDDING + assert request.input_type == EmbeddingInputType.DOCUMENT