diff --git a/openhands-sdk/openhands/sdk/mcp/tool.py b/openhands-sdk/openhands/sdk/mcp/tool.py index 232cf10755..c9e9c98058 100644 --- a/openhands-sdk/openhands/sdk/mcp/tool.py +++ b/openhands-sdk/openhands/sdk/mcp/tool.py @@ -270,7 +270,11 @@ def action_from_arguments(self, arguments: dict[str, Any]) -> MCPToolAction: exclude_fields = set(DiscriminatedUnionMixin.model_fields.keys()) | set( DiscriminatedUnionMixin.model_computed_fields.keys() ) - sanitized = validated.model_dump(exclude_none=True, exclude=exclude_fields) + sanitized = validated.model_dump( + by_alias=True, + exclude_none=True, + exclude=exclude_fields, + ) return MCPToolAction(data=sanitized) @classmethod diff --git a/openhands-sdk/openhands/sdk/tool/schema.py b/openhands-sdk/openhands/sdk/tool/schema.py index 81e9256664..c3df99eef7 100644 --- a/openhands-sdk/openhands/sdk/tool/schema.py +++ b/openhands-sdk/openhands/sdk/tool/schema.py @@ -175,6 +175,12 @@ class Schema(DiscriminatedUnionMixin): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + @classmethod + def _discriminator_field_names(cls) -> set[str]: + return set(DiscriminatedUnionMixin.model_fields.keys()) | set( + DiscriminatedUnionMixin.model_computed_fields.keys() + ) + @classmethod def to_mcp_schema(cls) -> dict[str, Any]: """Convert to JSON schema format compatible with MCP.""" @@ -185,11 +191,18 @@ def to_mcp_schema(cls) -> dict[str, Any]: # Remove discriminator fields from properties (not for LLM) # Need to exclude both regular fields and computed fields (like 'kind') - exclude_fields = set(DiscriminatedUnionMixin.model_fields.keys()) | set( - DiscriminatedUnionMixin.model_computed_fields.keys() - ) + exclude_fields = cls._discriminator_field_names() + explicit_mcp_aliases = { + field_info.alias + for field_name, field_info in cls.model_fields.items() + if field_info.alias and field_info.alias != field_name + } for f in exclude_fields: - if "properties" in result and f in result["properties"]: + if ( + "properties" in result + and f in result["properties"] + and f not in explicit_mcp_aliases + ): result["properties"].pop(f) # Also remove from required if present if "required" in result and f in result["required"]: @@ -213,12 +226,27 @@ def from_mcp_schema( required = set(schema.get("required", []) or []) fields: dict[str, tuple] = {} + discriminator_fields = cls._discriminator_field_names() + used_field_names = set(props.keys()) for fname, spec in props.items(): spec = spec if isinstance(spec, dict) else {} tp = py_type(spec) # Add description if present desc: str | None = spec.get("description") + field_name = fname + field_alias = None + if fname in discriminator_fields: + # MCP tool argument names are user-defined JSON object keys. If + # one collides with OpenHands' internal discriminator (e.g. + # "kind"), keep the external name as an alias and use a safe + # internal field name for Pydantic. + field_alias = fname + field_name = f"mcp_arg_{fname}" + suffix = 2 + while field_name in used_field_names or field_name in fields: + field_name = f"mcp_arg_{fname}_{suffix}" + suffix += 1 # Required → bare type, ellipsis sentinel # Optional → make nullable via `| None`, default None @@ -229,11 +257,15 @@ def from_mcp_schema( anno = tp | None # allow explicit null in addition to omission default = None - fields[fname] = ( + field_kwargs: dict[str, Any] = {} + if desc: + field_kwargs["description"] = desc + if field_alias: + field_kwargs["alias"] = field_alias + + fields[field_name] = ( anno, - Field(default=default, description=desc) - if desc - else Field(default=default), + Field(default=default, **field_kwargs), ) return create_model(model_name, __base__=cls, **fields) # type: ignore[return-value] diff --git a/openhands-sdk/openhands/sdk/utils/models.py b/openhands-sdk/openhands/sdk/utils/models.py index 5c134e0f63..320b9a6644 100644 --- a/openhands-sdk/openhands/sdk/utils/models.py +++ b/openhands-sdk/openhands/sdk/utils/models.py @@ -206,14 +206,21 @@ def _validate_subtype( ) -> Self: if isinstance(data, cls): return data - kind = data.pop("kind", None) if not _is_abstract(cls): + has_kind_alias_field = any( + field_name != "kind" and field_info.alias == "kind" + for field_name, field_info in cls.model_fields.items() + ) + if has_kind_alias_field: + return handler(data) + kind = data.pop("kind", None) # Sanity check: if we're validating a concrete class directly, # the kind (if provided) should match the class name. This should # always be true at this point since resolve_kind() would have # already routed to the correct subclass. assert kind is None or kind == cls.__name__ return handler(data) + kind = data.pop("kind", None) if kind is None: subclasses = _get_checked_concrete_subclasses(cls) if not subclasses: diff --git a/tests/sdk/mcp/test_mcp_tool_validation.py b/tests/sdk/mcp/test_mcp_tool_validation.py index 90a0c11e7b..51f9d6c9a8 100644 --- a/tests/sdk/mcp/test_mcp_tool_validation.py +++ b/tests/sdk/mcp/test_mcp_tool_validation.py @@ -8,9 +8,9 @@ from openhands.sdk.mcp.tool import MCPToolDefinition -def _make_tool_with_schema(schema: dict): +def _make_tool_with_schema(schema: dict, name: str = "fetch"): mcp_tool = mcp.types.Tool( - name="fetch", + name=name, description="Fetch a URL", inputSchema=schema, ) @@ -39,6 +39,44 @@ def test_mcp_action_from_arguments_validates_and_sanitizes(): assert action.data == {"url": "https://example.com"} +def test_mcp_action_from_arguments_preserves_schema_kind_argument(): + tool = _make_tool_with_schema( + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "file_path": {"type": "string"}, + "kind": { + "type": "string", + "description": "Symbol kind hint, such as Function or Method", + }, + "repo": {"type": "string"}, + }, + "required": ["name"], + }, + name="gitnexus_context", + ) + + openai_schema = tool.to_openai_tool()["function"]["parameters"] + assert "kind" in openai_schema["properties"] + + action = tool.action_from_arguments( + { + "name": "executeCommand", + "file_path": "src/vs/workbench/services/commands/common/commandService.ts", + "kind": "Method", + "repo": "vscode-benchmark-repo", + } + ) + + assert action.data == { + "name": "executeCommand", + "file_path": "src/vs/workbench/services/commands/common/commandService.ts", + "kind": "Method", + "repo": "vscode-benchmark-repo", + } + + def test_mcp_action_from_arguments_raises_on_invalid(): tool = _make_tool_with_schema( {