diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..02a16fd --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,50 @@ +name: Test + +on: + push: + branches: [ main, dev ] + pull_request: + branches: [ main, dev ] + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run tests with pytest + run: | + pytest tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=src/utcp --cov-report=xml --cov-report=html + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + with: + file: ./coverage.xml + fail_ci_if_error: false + + - name: Upload test results + uses: actions/upload-artifact@v4 + if: always() + with: + name: test-results-${{ matrix.os }}-${{ matrix.python-version }} + path: | + junit/ + htmlcov/ + coverage.xml diff --git a/README.md b/README.md index 4d7d19d..ccd67af 100644 --- a/README.md +++ b/README.md @@ -196,13 +196,51 @@ UTCP supports several authentication methods to secure tool access. The `auth` o #### API Key (`ApiKeyAuth`) -Authentication using a static API key, typically sent in a request header. +Authentication using a static API key that can be sent in different locations. ```json { "auth_type": "api_key", "api_key": "YOUR_SECRET_API_KEY", - "var_name": "X-API-Key" + "var_name": "X-API-Key", + "location": "header" +} +``` + +**Key Fields:** +* `api_key`: Your secret API key +* `var_name`: The name of the parameter (header name, query parameter name, or cookie name) +* `location`: Where to send the API key - `"header"` (default), `"query"`, or `"cookie"` + +**Examples:** + +*Header-based API key (most common):* +```json +{ + "auth_type": "api_key", + "api_key": "sk-1234567890abcdef", + "var_name": "Authorization", + "location": "header" +} +``` + +*Query parameter-based API key:* +```json +{ + "auth_type": "api_key", + "api_key": "abc123def456", + "var_name": "api_key", + "location": "query" +} +``` + +*Cookie-based API key:* +```json +{ + "auth_type": "api_key", + "api_key": "session_token_xyz", + "var_name": "auth_token", + "location": "cookie" } ``` @@ -243,8 +281,8 @@ Providers are at the heart of UTCP's flexibility. They define the communication * `websocket`: WebSocket bidirectional connection (work in progress) * `grpc`: gRPC (Google Remote Procedure Call) (work in progress) * `graphql`: GraphQL query language (work in progress) -* `tcp`: Raw TCP socket (work in progress) -* `udp`: User Datagram Protocol (work in progress) +* `tcp`: Raw TCP socket +* `udp`: User Datagram Protocol * `webrtc`: Web Real-Time Communication (work in progress) * `mcp`: Model Context Protocol (for interoperability) * `text`: Local text file @@ -266,6 +304,11 @@ For connecting to standard RESTful APIs. "url": "https://api.example.com/utcp", "http_method": "POST", "content_type": "application/json", + "headers": { + "User-Agent": "MyApp/1.0" + }, + "body_field": "LLM_generated_param_to_be_sent_as_body", + "header_fields": ["LLM_generated_param_to_be_sent_as_header"], "auth": { "auth_type": "oauth2", "token_url": "https://api.example.com/oauth/token", @@ -275,6 +318,15 @@ For connecting to standard RESTful APIs. } ``` +**Key HttpProvider Fields:** +* `http_method`: HTTP method - `"GET"`, `"POST"`, `"PUT"`, `"DELETE"`, `"PATCH"` (default: `"GET"`) +* `url`: The endpoint URL (supports path parameters with `{param}` syntax) +* `content_type`: Content-Type header for request body (default: `"application/json"`) +* `headers`: Static headers to include in all requests +* `body_field`: Name of the input field to use as request body (default: `"body"`) +* `header_fields`: List of input fields to send as request headers +* `auth`: Authentication configuration + #### Automatic OpenAPI Conversion UTCP simplifies integration with existing web services by automatically converting OpenAPI v3 specifications into UTCP tools. Instead of pointing to a `UtcpManual`, the `url` for an `http` provider can point directly to an OpenAPI JSON specification. The `OpenApiConverter` handles this conversion automatically, making it seamless to integrate thousands of existing APIs. @@ -289,6 +341,53 @@ UTCP simplifies integration with existing web services by automatically converti When the client registers this provider, it will fetch the OpenAPI spec from the URL, convert all defined endpoints into UTCP `Tool` objects, and make them available for searching and calling. +#### URL Path Parameters + +HTTP-based providers (HTTP, SSE, HTTP Stream) support dynamic URL path parameters that can be substituted from tool arguments. This enables integration with RESTful APIs that use path-based resource identification. + +**URL Template Format:** +Path parameters are specified in the URL using curly braces: `{parameter_name}` + +**Example:** +```json +{ + "name": "openlibrary_api", + "provider_type": "http", + "url": "https://openlibrary.org/api/volumes/brief/{key_type}/{value}.json", + "http_method": "GET" +} +``` + +**How it works:** +1. When calling a tool, parameters matching the path parameter names are extracted from the tool arguments +2. These parameters are substituted into the URL template +3. The used parameters are removed from the arguments (so they don't become query parameters) +4. Any remaining arguments become query parameters + +**Example usage:** +```python +# Tool call arguments +arguments = { + "key_type": "isbn", + "value": "9780140328721", + "format": "json" +} + +# Results in URL: https://openlibrary.org/api/volumes/brief/isbn/9780140328721.json?format=json +``` + +**Multiple Path Parameters:** +URLs can contain multiple path parameters: +```json +{ + "url": "https://api.example.com/users/{user_id}/posts/{post_id}/comments/{comment_id}" +} +``` + +**Error Handling:** +- If a required path parameter is missing from the tool arguments, an error is raised +- All path parameters must be provided for the tool call to succeed + ### Server-Sent Events (SSE) Provider For tools that stream data using SSE. The `url` should point to the discovery endpoint. @@ -297,11 +396,33 @@ For tools that stream data using SSE. The `url` should point to the discovery en { "name": "live_updates_service", "provider_type": "sse", - "url": "https://api.example.com/utcp", - "event_type": "message" + "url": "https://api.example.com/stream", + "event_type": "message", + "reconnect": true, + "retry_timeout": 30000, + "headers": { + "Accept": "text/event-stream" + }, + "body_field": null, + "header_fields": ["LLM_generated_param_to_be_sent_as_header"], + "auth": { + "auth_type": "api_key", + "api_key": "your_api_key", + "var_name": "Authorization", + "location": "header" + } } ``` +**Key SSEProvider Fields:** +* `url`: The SSE endpoint URL (supports path parameters) +* `event_type`: Filter for specific SSE event types (optional) +* `reconnect`: Whether to automatically reconnect on disconnect (default: `true`) +* `retry_timeout`: Retry timeout in milliseconds (default: `30000`) +* `headers`: Static headers for the SSE connection +* `body_field`: Input field for connection request body (optional) +* `header_fields`: Input fields to send as headers for initial connection + ### HTTP Stream Provider For tools that use HTTP chunked transfer encoding to stream data. The `url` should point to the discovery endpoint. @@ -310,11 +431,34 @@ For tools that use HTTP chunked transfer encoding to stream data. The `url` shou { "name": "streaming_data_source", "provider_type": "http_stream", - "url": "https://api.example.com/utcp", - "http_method": "GET" + "url": "https://api.example.com/stream", + "http_method": "POST", + "content_type": "application/octet-stream", + "chunk_size": 4096, + "timeout": 60000, + "headers": { + "Accept": "application/octet-stream" + }, + "body_field": "data", + "header_fields": ["LLM_generated_param_to_be_sent_as_header"], + "auth": { + "auth_type": "basic", + "username": "your_username", + "password": "your_password" + } } ``` +**Key StreamableHttpProvider Fields:** +* `http_method`: HTTP method - `"GET"` or `"POST"` (default: `"GET"`) +* `url`: The streaming endpoint URL (supports path parameters) +* `content_type`: Content-Type for streaming data (default: `"application/octet-stream"`, also supports `"application/x-ndjson"`, `"application/json"`) +* `chunk_size`: Size of chunks in bytes (default: `4096`) +* `timeout`: Timeout in milliseconds (default: `60000`) +* `headers`: Static headers for the stream connection +* `body_field`: Input field for request body (optional) +* `header_fields`: Input fields to send as headers + ### CLI Provider For wrapping local command-line tools. @@ -323,13 +467,24 @@ For wrapping local command-line tools. { "name": "my_cli_tool", "provider_type": "cli", - "command_name": "my-command -utcp" + "command_name": "my-command --utcp", + "env_vars": { + "MY_API_KEY": "${API_KEY}", + "DEBUG": "1" + }, + "working_dir": "/path/to/working/directory" } ``` +**Key CliProvider Fields:** +* `command_name`: The command to execute (should support UTCP discovery) +* `env_vars`: Environment variables to set when executing (optional) +* `working_dir`: Working directory for command execution (optional) +* `auth`: Always `null` (CLI tools don't use UTCP auth) + ### WebSocket Provider (work in progress) -For tools that communicate over a WebSocket connection. Tool discovery may need to be handled via a separate HTTP endpoint. +For tools that communicate over a WebSocket connection. ```json { @@ -357,40 +512,150 @@ For connecting to gRPC services. ### GraphQL Provider (work in progress) -For interacting with GraphQL APIs. The `url` should point to the discovery endpoint. +For interacting with GraphQL APIs. ```json { "name": "my_graphql_api", "provider_type": "graphql", - "url": "https://api.example.com/utcp", - "operation_type": "query" + "url": "https://api.example.com/graphql", + "operation_type": "query", + "operation_name": "GetUserData", + "headers": { + "Content-Type": "application/json" + }, + "header_fields": ["LLM_generated_param_to_be_sent_as_header"], + "auth": { + "auth_type": "oauth2", + "token_url": "https://api.example.com/oauth/token", + "client_id": "graphql_client", + "client_secret": "secret_123" + } } ``` -### TCP Provider (work in progress) +**Key GraphQLProvider Fields:** +* `url`: The GraphQL endpoint URL +* `operation_type`: Type of GraphQL operation - `"query"`, `"mutation"`, `"subscription"` (default: `"query"`) +* `operation_name`: Name of the GraphQL operation (optional) +* `headers`: Static headers for GraphQL requests +* `header_fields`: Input fields to send as headers + +### TCP Provider -For raw TCP socket communication. +For TCP socket communication. Supports multiple framing strategies, JSON and text-based request formats, and configurable response handling. +**Basic Example:** ```json { - "name": "raw_tcp_service", + "name": "tcp_service", "provider_type": "tcp", "host": "localhost", - "port": 12345 + "port": 12345, + "timeout": 30000, + "request_data_format": "json", + "framing_strategy": "stream", + "response_byte_format": "utf-8" } ``` -### UDP Provider (work in progress) +**Key TCP Provider Fields:** + +* `host`: The hostname or IP address of the TCP server +* `port`: The TCP port number +* `timeout`: Timeout in milliseconds (default: 30000) +* `request_data_format`: Either `"json"` for structured data or `"text"` for template-based formatting (default: `"json"`) +* `request_data_template`: Template string for text format with `UTCP_ARG_argname_UTCP_ARG` placeholders +* `response_byte_format`: Encoding for response bytes - `"utf-8"`, `"ascii"`, etc., or `null` for raw bytes (default: `"utf-8"`) +* `framing_strategy`: Message framing strategy: `"stream"`, `"length_prefix"`, `"delimiter"`, or `"fixed_length"` (default: `"stream"`) +* `length_prefix_bytes`: For length-prefix framing: 1, 2, 4, or 8 bytes (default: 4) +* `length_prefix_endian`: For length-prefix framing: `"big"` or `"little"` (default: `"big"`) +* `message_delimiter`: For delimiter framing: delimiter string like `"\n"`, `"\r\n"`, `"\x00"` (default: `"\x00"`) +* `fixed_message_length`: For fixed-length framing: exact message length in bytes +* `max_response_size`: For stream framing: maximum bytes to read (default: 65536) + +**Length-Prefix Framing Example:** +```json +{ + "name": "binary_tcp_service", + "provider_type": "tcp", + "host": "192.168.1.50", + "port": 8080, + "framing_strategy": "length_prefix", + "length_prefix_bytes": 4, + "length_prefix_endian": "big", + "request_data_format": "json", + "response_byte_format": "utf-8" +} +``` -For UDP socket communication. +**Delimiter Framing Example:** +```json +{ + "name": "line_based_tcp_service", + "provider_type": "tcp", + "host": "tcp.example.com", + "port": 9999, + "framing_strategy": "delimiter", + "message_delimiter": "\n", + "request_data_format": "text", + "request_data_template": "GET UTCP_ARG_resource_UTCP_ARG", + "response_byte_format": "ascii" +} +``` + +**Fixed-Length Framing Example:** +```json +{ + "name": "fixed_protocol_service", + "provider_type": "tcp", + "host": "legacy.example.com", + "port": 7777, + "framing_strategy": "fixed_length", + "fixed_message_length": 1024, + "request_data_format": "text", + "response_byte_format": null +} +``` + +### UDP Provider + +For UDP socket communication. Supports both JSON and text-based request formats with configurable response handling. ```json { "name": "udp_telemetry_service", "provider_type": "udp", "host": "localhost", - "port": 54321 + "port": 54321, + "timeout": 30000, + "request_data_format": "json", + "number_of_response_datagrams": 1, + "response_byte_format": "utf-8" +} +``` + +**Key UDP Provider Fields:** + +* `host`: The hostname or IP address of the UDP server +* `port`: The UDP port number +* `timeout`: Timeout in milliseconds (default: 30000) +* `request_data_format`: Either `"json"` for structured data or `"text"` for template-based formatting (default: `"json"`) +* `request_data_template`: Template string for text format with `UTCP_ARG_argname_UTCP_ARG` placeholders +* `number_of_response_datagrams`: Number of UDP response packets to expect (default: 0 for no response) +* `response_byte_format`: Encoding for response bytes - `"utf-8"`, `"ascii"`, etc., or `null` for raw bytes (default: `"utf-8"`) + +**Text Format Example:** +```json +{ + "name": "legacy_udp_service", + "provider_type": "udp", + "host": "192.168.1.100", + "port": 9999, + "request_data_format": "text", + "request_data_template": "CMD:UTCP_ARG_command_UTCP_ARG;VALUE:UTCP_ARG_value_UTCP_ARG", + "number_of_response_datagrams": 2, + "response_byte_format": "ascii" } ``` @@ -411,9 +676,10 @@ For peer-to-peer communication using WebRTC. For interoperability with the Model Context Protocol (MCP). This provider can connect to MCP servers via `stdio` or `http`. +**HTTP MCP Server Example:** ```json { - "name": "my_mcp_service", + "name": "my_mcp_http_service", "provider_type": "mcp", "config": { "mcpServers": { @@ -432,40 +698,39 @@ For interoperability with the Model Context Protocol (MCP). This provider can co } ``` -### Text Provider - -For loading tool definitions from a local text file. This is useful for defining a collection of tools that may use various other providers. - -```json -{ - "name": "my_local_tools", - "signaling_server": "wss://signaling.example.com", - "peer_id": "unique-peer-id" -} -``` - -### MCP Provider - -For interoperability with Model Context Protocol (MCP) servers. - +**Stdio MCP Server Example:** ```json { - "name": "my_mcp_server", + "name": "my_mcp_stdio_service", "provider_type": "mcp", "config": { "mcpServers": { - "server_one": { + "local-server": { + "transport": "stdio", "command": "python", - "args": ["-m", "my_mcp_server.main"] + "args": ["-m", "my_mcp_server.main"], + "env": { + "API_KEY": "${MCP_API_KEY}", + "DEBUG": "1" + } } } } } ``` +**Key MCPProvider Fields:** +* `config`: MCP configuration object containing server definitions +* `config.mcpServers`: Dictionary of server name to server configuration +* `auth`: OAuth2 authentication (optional, only for HTTP servers) + +**MCP Server Types:** +* **HTTP**: `{"transport": "http", "url": "server_url"}` +* **Stdio**: `{"transport": "stdio", "command": "cmd", "args": [...], "env": {...}}` + ### Text Provider -For loading tool definitions from a local file. This is useful for defining a collection of tools from different providers in a single place. +For loading tool definitions from a local text file. This is useful for defining a collection of tools that may use various other providers. ```json { @@ -475,13 +740,39 @@ For loading tool definitions from a local file. This is useful for defining a co } ``` +**Key TextProvider Fields:** +* `file_path`: Path to the file containing tool definitions (required) +* `auth`: Always `null` (text files don't require authentication) + +**Use Cases:** +- Define tools that produce static output files +- Create tool collections that reference other providers +- Download manuals from a remote server to allow inspection of tools before calling them and guarantee security for high-risk environments + + + ### Authentication UTCP supports several authentication methods, which can be configured on a per-provider basis: -* **API Key**: `ApiKeyAuth` - Authentication using an API key sent in a header. -* **Basic Auth**: `BasicAuth` - Authentication using a username and password. -* **OAuth2**: `OAuth2Auth` - Authentication using the OAuth2 protocol. +* **API Key**: `ApiKeyAuth` - Authentication using an API key that can be sent in headers, query parameters, or cookies +* **Basic Auth**: `BasicAuth` - Authentication using a username and password +* **OAuth2**: `OAuth2Auth` - Authentication using the OAuth2 client credentials flow with automatic token management + +#### Enhanced Authentication Features + +**Flexible API Key Placement:** +- Headers (most common): `"location": "header"` +- Query parameters: `"location": "query"` +- Cookies: `"location": "cookie"` + +**OAuth2 Automatic Token Management:** +- Supports both body-based and header-based OAuth2 token requests +- Automatic token caching and reuse +- Fallback mechanisms for different OAuth2 server implementations + +**Comprehensive HTTP Transport Support:** +All HTTP-based transports (HTTP, SSE, HTTP Stream) support the full range of authentication methods with proper configuration handling during both tool discovery and tool execution. ## UTCP Client Architecture @@ -537,6 +828,27 @@ for tool in tools: print(tool.name, tool.description) ``` +## Testing + +The UTCP client includes comprehensive test suites for all transport implementations. Tests cover functionality, error handling, different configuration options, and edge cases. + +### Running Tests + +To run all tests: +```bash +python -m pytest +``` + +To run tests for a specific transport (e.g., TCP): +```bash +python -m pytest tests/client/transport_interfaces/test_tcp_transport.py -v +``` + +To run tests with coverage: +```bash +python -m pytest --cov=utcp tests/ +``` + ## Build 1. Create a virtual environment (e.g. `conda create --name utcp python=3.10`) and enable it (`conda activate utcp`) 2. Install required libraries (`pip install -r requirements.txt`) diff --git a/pyproject.toml b/pyproject.toml index f03f362..37532c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp" -version = "0.1.8" +version = "0.2.0" authors = [ { name = "Razvan-Ion Radulescu" }, { name = "Andrei-Stefan Ghiurtu" }, diff --git a/src/utcp/client/openapi_converter.py b/src/utcp/client/openapi_converter.py index b884864..e0e9df3 100644 --- a/src/utcp/client/openapi_converter.py +++ b/src/utcp/client/openapi_converter.py @@ -1,6 +1,7 @@ import json from typing import Any, Dict, List, Optional, Tuple import sys +import uuid from utcp.shared.tool import Tool, ToolInputOutputSchema from utcp.shared.utcp_manual import UtcpManual from urllib.parse import urlparse @@ -15,17 +16,37 @@ class OpenApiConverter: def __init__(self, openapi_spec: Dict[str, Any], spec_url: Optional[str] = None, provider_name: Optional[str] = None): self.spec = openapi_spec self.spec_url = spec_url + # Single counter for all placeholder variables + self.placeholder_counter = 0 # If provider_name is None then get the first word in spec.info.title if provider_name is None: - title = openapi_spec.get("info", {}).get("title", "openapi_provider") + title = openapi_spec.get("info", {}).get("title", "openapi_provider_" + uuid.uuid4().hex) # Replace characters that are invalid for identifiers invalid_chars = " -.,!?'\"\\/()[]{}#@$%^&*+=~`|;:<>" self.provider_name = ''.join('_' if c in invalid_chars else c for c in title) else: self.provider_name = provider_name + + def _increment_placeholder_counter(self) -> int: + """Increments the global counter and returns the new value. + + Returns: + The new counter value after incrementing + """ + self.placeholder_counter += 1 + return self.placeholder_counter + + def _get_placeholder(self, placeholder_name: str) -> str: + """Returns a placeholder string using the current counter value. + + Args: + placeholder_name: The name of the placeholder variable + """ + return f"${{{placeholder_name}_{self.placeholder_counter}}}" def convert(self) -> UtcpManual: """Parses the OpenAPI specification and returns a UtcpManual.""" + self.placeholder_counter = 0 tools = [] servers = self.spec.get("servers") if servers: @@ -123,17 +144,26 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> # For API key auth, use the parameter name from the OpenAPI spec location = scheme.get("in", "header") # Default to header if not specified param_name = scheme.get("name", "Authorization") # Default name + # Use the current counter value for the placeholder + api_key_placeholder = self._get_placeholder("API_KEY") + # Increment the counter after using it + self._increment_placeholder_counter() return ApiKeyAuth( - api_key=f"${{{self.provider_name.upper()}_API_KEY}}", # Placeholder for environment variable + api_key=api_key_placeholder, var_name=param_name, location=location ) elif scheme_type == "basic": # OpenAPI 2.0 format: type: basic + # Use the current counter value for both placeholders + username_placeholder = self._get_placeholder("USERNAME") + password_placeholder = self._get_placeholder("PASSWORD") + # Increment the counter after using it + self._increment_placeholder_counter() return BasicAuth( - username=f"${{{self.provider_name.upper()}_USERNAME}}", - password=f"${{{self.provider_name.upper()}_PASSWORD}}" + username=username_placeholder, + password=password_placeholder ) elif scheme_type == "http": @@ -141,14 +171,23 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> http_scheme = scheme.get("scheme", "").lower() if http_scheme == "basic": # For basic auth, use conventional environment variable names + # Use the current counter value for both placeholders + username_placeholder = self._get_placeholder("USERNAME") + password_placeholder = self._get_placeholder("PASSWORD") + # Increment the counter after using it + self._increment_placeholder_counter() return BasicAuth( - username=f"${{{self.provider_name.upper()}_USERNAME}}", - password=f"${{{self.provider_name.upper()}_PASSWORD}}" + username=username_placeholder, + password=password_placeholder ) elif http_scheme == "bearer": # Treat bearer tokens as API keys + # Use the current counter value for the placeholder + api_key_placeholder = self._get_placeholder("API_KEY") + # Increment the counter after using it + self._increment_placeholder_counter() return ApiKeyAuth( - api_key=f"Bearer ${{{self.provider_name.upper()}_API_KEY}}", + api_key=f"Bearer {api_key_placeholder}", var_name="Authorization", location="header" ) @@ -164,10 +203,15 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> if flow_type in ["authorizationCode", "accessCode", "clientCredentials", "application"]: token_url = flow_config.get("tokenUrl") if token_url: + # Use the current counter value for both placeholders + client_id_placeholder = self._get_placeholder("CLIENT_ID") + client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") + # Increment the counter after using it + self._increment_placeholder_counter() return OAuth2Auth( token_url=token_url, - client_id=f"${{{self.provider_name.upper()}_CLIENT_ID}}", - client_secret=f"${{{self.provider_name.upper()}_CLIENT_SECRET}}", + client_id=client_id_placeholder, + client_secret=client_secret_placeholder, scope=" ".join(flow_config.get("scopes", {}).keys()) or None ) @@ -176,10 +220,15 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> flow_type = scheme.get("flow", "") token_url = scheme.get("tokenUrl") if token_url and flow_type in ["accessCode", "application", "clientCredentials"]: + # Use the current counter value for both placeholders + client_id_placeholder = self._get_placeholder("CLIENT_ID") + client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") + # Increment the counter after using it + self._increment_placeholder_counter() return OAuth2Auth( token_url=token_url, - client_id=f"${{{self.provider_name.upper()}_CLIENT_ID}}", - client_secret=f"${{{self.provider_name.upper()}_CLIENT_SECRET}}", + client_id=client_id_placeholder, + client_secret=client_secret_placeholder, scope=" ".join(scheme.get("scopes", {}).keys()) or None ) @@ -198,7 +247,7 @@ def _create_tool(self, path: str, method: str, operation: Dict[str, Any], base_u outputs = self._extract_outputs(operation) auth = self._extract_auth(operation) - provider_name = self.spec.get("info", {}).get("title", "openapi_provider") + provider_name = self.spec.get("info", {}).get("title", "openapi_provider_" + uuid.uuid4().hex) # Combine base URL and path, ensuring no double slashes full_url = base_url.rstrip('/') + '/' + path.lstrip('/') diff --git a/src/utcp/client/transport_interfaces/tcp_transport.py b/src/utcp/client/transport_interfaces/tcp_transport.py new file mode 100644 index 0000000..216c3f4 --- /dev/null +++ b/src/utcp/client/transport_interfaces/tcp_transport.py @@ -0,0 +1,406 @@ +""" +Transmission Control Protocol (TCP) transport for UTCP client. + +This transport communicates with tools over TCP sockets. +""" +import asyncio +import json +import logging +import socket +import struct +from typing import Dict, Any, List, Optional, Callable, Union + +from utcp.client.client_transport_interface import ClientTransportInterface +from utcp.shared.provider import Provider, TCPProvider +from utcp.shared.tool import Tool + + +class TCPTransport(ClientTransportInterface): + """Transport implementation for TCP-based tool providers. + + This transport communicates with tools over TCP sockets. It supports: + - Tool discovery via TCP messages + - Tool execution by sending TCP packets with arguments + - Multiple framing strategies: length-prefix, delimiter, fixed-length, and stream + - JSON and text-based request formatting + - Template-based argument substitution + - Configurable response byte format (text encoding or raw bytes) + - Connection management for each request + """ + + def __init__(self, logger: Optional[Callable[[str], None]] = None): + """Initialize the TCP transport. + + Args: + logger: Optional logger function for debugging + """ + self._log = logger or (lambda *args, **kwargs: None) + + def _log_info(self, message: str): + """Log informational messages.""" + self._log(f"[TCPTransport] {message}") + + def _log_error(self, message: str): + """Log error messages.""" + logging.error(f"[TCPTransport Error] {message}") + + def _format_tool_call_message( + self, + arguments: Dict[str, Any], + provider: TCPProvider + ) -> str: + """Format a tool call message based on provider configuration. + + Args: + arguments: Arguments for the tool call + provider: The TCPProvider with formatting configuration + + Returns: + Formatted message string + """ + if provider.request_data_format == "json": + return json.dumps(arguments) + elif provider.request_data_format == "text": + # Use template-based formatting + if provider.request_data_template is not None and provider.request_data_template != "": + message = provider.request_data_template + # Replace placeholders with argument values + for arg_name, arg_value in arguments.items(): + placeholder = f"UTCP_ARG_{arg_name}_UTCP_ARG" + if isinstance(arg_value, str): + message = message.replace(placeholder, arg_value) + else: + message = message.replace(placeholder, json.dumps(arg_value)) + return message + else: + # Fallback to simple key=value format + return " ".join([str(v) for k, v in arguments.items()]) + else: + # Default to JSON format + return json.dumps(arguments) + + def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> bytes: + """Encode message with appropriate TCP framing. + + Args: + message: Message to encode + provider: TCPProvider with framing configuration + + Returns: + Framed message bytes + """ + message_bytes = message.encode('utf-8') + + if provider.framing_strategy == "length_prefix": + # Add length prefix before the message + length = len(message_bytes) + if provider.length_prefix_bytes == 1: + length_bytes = struct.pack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}B", length) + elif provider.length_prefix_bytes == 2: + length_bytes = struct.pack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}H", length) + elif provider.length_prefix_bytes == 4: + length_bytes = struct.pack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}I", length) + elif provider.length_prefix_bytes == 8: + length_bytes = struct.pack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}Q", length) + else: + raise ValueError(f"Invalid length_prefix_bytes: {provider.length_prefix_bytes}") + return length_bytes + message_bytes + + elif provider.framing_strategy == "delimiter": + # Add delimiter after the message + delimiter = provider.message_delimiter or "\\x00" + # Handle escape sequences + delimiter = delimiter.encode('utf-8').decode('unicode_escape') + return message_bytes + delimiter.encode('utf-8') + + elif provider.framing_strategy in ("fixed_length", "stream"): + # No additional framing needed + return message_bytes + + else: + raise ValueError(f"Unknown framing strategy: {provider.framing_strategy}") + + def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvider, timeout: float) -> bytes: + """Decode response based on TCP framing strategy. + + Args: + sock: Connected TCP socket + provider: TCPProvider with framing configuration + timeout: Read timeout in seconds + + Returns: + Response message bytes + """ + sock.settimeout(timeout) + + if provider.framing_strategy == "length_prefix": + # Read length prefix first + length_bytes = sock.recv(provider.length_prefix_bytes) + if len(length_bytes) < provider.length_prefix_bytes: + raise Exception(f"Incomplete length prefix: got {len(length_bytes)} bytes, expected {provider.length_prefix_bytes}") + + # Unpack length + if provider.length_prefix_bytes == 1: + length = struct.unpack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}B", length_bytes)[0] + elif provider.length_prefix_bytes == 2: + length = struct.unpack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}H", length_bytes)[0] + elif provider.length_prefix_bytes == 4: + length = struct.unpack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}I", length_bytes)[0] + elif provider.length_prefix_bytes == 8: + length = struct.unpack(f"{'>' if provider.length_prefix_endian == 'big' else '<'}Q", length_bytes)[0] + else: + raise ValueError(f"Invalid length_prefix_bytes: {provider.length_prefix_bytes}") + + # Read the message data + response_data = b"" + while len(response_data) < length: + chunk = sock.recv(length - len(response_data)) + if not chunk: + raise Exception("Connection closed while reading message") + response_data += chunk + + return response_data + + elif provider.framing_strategy == "delimiter": + # Read until delimiter is found + delimiter = provider.message_delimiter or "\\x00" + delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + + response_data = b"" + while True: + chunk = sock.recv(1) + if not chunk: + raise Exception("Connection closed while reading message") + response_data += chunk + + # Check if we've received the delimiter + if response_data.endswith(delimiter): + # Remove delimiter from response + return response_data[:-len(delimiter)] + + elif provider.framing_strategy == "fixed_length": + # Read exactly fixed_message_length bytes + if provider.fixed_message_length is None: + raise ValueError("fixed_message_length must be set for fixed_length framing") + + response_data = b"" + while len(response_data) < provider.fixed_message_length: + chunk = sock.recv(provider.fixed_message_length - len(response_data)) + if not chunk: + raise Exception("Connection closed while reading message") + response_data += chunk + + return response_data + + elif provider.framing_strategy == "stream": + # Read until connection closes or max_response_size is reached + response_data = b"" + while len(response_data) < provider.max_response_size: + try: + chunk = sock.recv(min(4096, provider.max_response_size - len(response_data))) + if not chunk: + # Connection closed + break + response_data += chunk + except socket.timeout: + # Timeout reached + break + + return response_data + + else: + raise ValueError(f"Unknown framing strategy: {provider.framing_strategy}") + + async def _send_tcp_message( + self, + host: str, + port: int, + message: str, + provider: TCPProvider, + timeout: float = 30.0, + response_encoding: Optional[str] = "utf-8" + ) -> Union[str, bytes]: + """Send a TCP message and wait for response. + + Args: + host: Host to connect to + port: Port to connect to + message: Message to send + provider: TCPProvider with framing configuration + timeout: Timeout in seconds + response_encoding: Encoding to decode response bytes. If None, returns raw bytes. + + Returns: + Response message or raw bytes if encoding is None + """ + loop = asyncio.get_event_loop() + + def _send_and_receive(): + """Blocking function to send TCP message and receive response.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + # Set connection timeout + sock.settimeout(timeout) + + # Connect to server + sock.connect((host, port)) + + # Encode message with framing + framed_message = self._encode_message_with_framing(message, provider) + + # Send message + sock.sendall(framed_message) + + # Receive response based on framing strategy + response_bytes = self._decode_response_with_framing(sock, provider, timeout) + + return response_bytes + + except socket.timeout: + raise Exception(f"TCP connection timeout after {timeout} seconds") + except Exception as e: + raise Exception(f"TCP communication error: {e}") + finally: + sock.close() + + try: + # Run blocking socket operations in executor + response_bytes = await loop.run_in_executor(None, _send_and_receive) + + # Return based on encoding preference + if response_encoding is None: + return response_bytes + else: + try: + return response_bytes.decode(response_encoding) + except UnicodeDecodeError as e: + self._log_error(f"Failed to decode response with encoding '{response_encoding}': {e}") + # Return raw bytes as fallback + return response_bytes + + except Exception as e: + self._log_error(f"Error in TCP communication: {e}") + raise + + async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: + """Register a TCP provider and discover its tools. + + Sends a discovery message to the TCP provider and parses the response. + + Args: + manual_provider: The TCPProvider to register + + Returns: + List of tools discovered from the TCP provider + + Raises: + ValueError: If provider is not a TCPProvider + """ + if not isinstance(manual_provider, TCPProvider): + raise ValueError("TCPTransport can only be used with TCPProvider") + + self._log_info(f"Registering TCP provider '{manual_provider.name}'") + + try: + # Send discovery message + discovery_message = json.dumps({ + "type": "utcp" + }) + + response = await self._send_tcp_message( + manual_provider.host, + manual_provider.port, + discovery_message, + manual_provider, + manual_provider.timeout / 1000.0, # Convert ms to seconds + manual_provider.response_byte_format + ) + + # Parse response + try: + # Handle bytes response by trying to decode as UTF-8 for JSON parsing + if isinstance(response, bytes): + response_str = response.decode('utf-8') + else: + response_str = response + + response_data = json.loads(response_str) + + # Check if response contains tools + if isinstance(response_data, dict) and 'tools' in response_data: + tools_data = response_data['tools'] + + # Parse tools + tools = [] + for tool_data in tools_data: + try: + tool = Tool(**tool_data) + tools.append(tool) + except Exception as e: + self._log_error(f"Invalid tool definition in TCP provider '{manual_provider.name}': {e}") + continue + + self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_provider.name}'") + return tools + else: + self._log_info(f"No tools found in TCP provider '{manual_provider.name}' response") + return [] + + except json.JSONDecodeError as e: + self._log_error(f"Invalid JSON response from TCP provider '{manual_provider.name}': {e}") + return [] + + except Exception as e: + self._log_error(f"Error registering TCP provider '{manual_provider.name}': {e}") + return [] + + async def deregister_tool_provider(self, manual_provider: Provider) -> None: + """Deregister a TCP provider. + + This is a no-op for TCP providers since connections are created per request. + + Args: + manual_provider: The provider to deregister + """ + if not isinstance(manual_provider, TCPProvider): + raise ValueError("TCPTransport can only be used with TCPProvider") + + self._log_info(f"Deregistering TCP provider '{manual_provider.name}' (no-op)") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], tool_provider: Provider) -> Any: + """Call a TCP tool. + + Sends a tool call message to the TCP provider and returns the response. + + Args: + tool_name: Name of the tool to call + arguments: Arguments for the tool call + tool_provider: The TCPProvider containing the tool + + Returns: + The response from the TCP tool + + Raises: + ValueError: If provider is not a TCPProvider + """ + if not isinstance(tool_provider, TCPProvider): + raise ValueError("TCPTransport can only be used with TCPProvider") + + self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_provider.name}'") + + try: + tool_call_message = self._format_tool_call_message(arguments, tool_provider) + + response = await self._send_tcp_message( + tool_provider.host, + tool_provider.port, + tool_call_message, + tool_provider, + tool_provider.timeout / 1000.0, # Convert ms to seconds + tool_provider.response_byte_format + ) + return response + + except Exception as e: + self._log_error(f"Error calling TCP tool '{tool_name}': {e}") + raise diff --git a/src/utcp/client/transport_interfaces/udp_transport.py b/src/utcp/client/transport_interfaces/udp_transport.py new file mode 100644 index 0000000..16228e3 --- /dev/null +++ b/src/utcp/client/transport_interfaces/udp_transport.py @@ -0,0 +1,324 @@ +""" +User Datagram Protocol (UDP) transport for UTCP client. + +This transport communicates with tools over UDP sockets. +""" +import asyncio +import json +import logging +import socket +from typing import Dict, Any, List, Optional, Callable, Union + +from utcp.client.client_transport_interface import ClientTransportInterface +from utcp.shared.provider import Provider, UDPProvider +from utcp.shared.tool import Tool + + +class UDPTransport(ClientTransportInterface): + """Transport implementation for UDP-based tool providers. + + This transport communicates with tools over UDP sockets. It supports: + - Tool discovery via UDP messages + - Tool execution by sending UDP packets with arguments + - Multiple response datagrams handling + - JSON and text-based request formatting + - Template-based argument substitution + - Configurable response byte format (text encoding or raw bytes) + - Stateless operation (no persistent connections) + """ + + def __init__(self, logger: Optional[Callable[[str], None]] = None): + """Initialize the UDP transport. + + Args: + logger: Optional logger function for debugging + """ + self._log = logger or (lambda *args, **kwargs: None) + # UDP is stateless, so no connections to manage + + def _log_info(self, message: str): + """Log informational messages.""" + self._log(f"[UDPTransport] {message}") + + def _log_error(self, message: str): + """Log error messages.""" + logging.error(f"[UDPTransport Error] {message}") + + def _format_tool_call_message( + self, + arguments: Dict[str, Any], + provider: UDPProvider + ) -> str: + """Format a tool call message based on provider configuration. + + Args: + arguments: Arguments for the tool call + provider: The UDPProvider with formatting configuration + + Returns: + Formatted message string + """ + if provider.request_data_format == "json": + return json.dumps(arguments) + elif provider.request_data_format == "text": + # Use template-based formatting + if provider.request_data_template is not None and provider.request_data_template != "": + message = provider.request_data_template + # Replace placeholders with argument values + for arg_name, arg_value in arguments.items(): + placeholder = f"UTCP_ARG_{arg_name}_UTCP_ARG" + if isinstance(arg_value, str): + message = message.replace(placeholder, arg_value) + else: + message = message.replace(placeholder, json.dumps(arg_value)) + return message + else: + # Fallback to simple key=value format + return " ".join([str(v) for k, v in arguments.items()]) + else: + # Default to JSON format + return json.dumps(arguments) + + async def _send_udp_message( + self, + host: str, + port: int, + message: str, + timeout: float = 30.0, + num_response_datagrams: int = 1, + response_encoding: Optional[str] = "utf-8" + ) -> Union[str, bytes]: + """Send a UDP message and wait for response(s). + + Args: + host: Host to send message to + port: Port to send message to + message: Message to send + timeout: Timeout in seconds + num_response_datagrams: Number of response datagrams to receive + response_encoding: Encoding to decode response bytes. If None, returns raw bytes. + + Returns: + Response message (concatenated if multiple datagrams) or raw bytes if encoding is None + """ + if num_response_datagrams == 0: + # No response expected - just send and return + await self._send_udp_no_response(host, port, message) + return b"" if response_encoding is None else "" + + # Use simple socket approach with executor for Windows compatibility + loop = asyncio.get_event_loop() + + def _send_and_receive(): + """Blocking function to send UDP message and receive responses.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # Resolve host to IP for comparison + try: + resolved_host_ip = socket.gethostbyname(host) + except socket.gaierror: + resolved_host_ip = host # Fallback to original if resolution fails + + # Send message + message_bytes = message.encode('utf-8') + sock.sendto(message_bytes, (host, port)) + + # Collect responses + response_bytes_list = [] + + for i in range(max(1, num_response_datagrams)): + try: + # Use shorter timeout for subsequent datagrams + current_timeout = timeout if i == 0 else 1.0 + + # Set socket timeout + sock.settimeout(current_timeout) + + # Receive response + data, addr = sock.recvfrom(65535) + + # Verify it's from the expected host (compare with resolved IP) + if addr[0] == host or addr[0] == resolved_host_ip: + response_bytes_list.append(data) + else: + # Got response from wrong host, don't count it + continue + + except socket.timeout: + if i == 0: + # First datagram timed out + raise TimeoutError(f"UDP request timed out after {timeout} seconds") + else: + # Subsequent datagrams timed out, but we have some data + break + + return response_bytes_list + + finally: + sock.close() + + try: + # Run blocking socket operations in executor + response_bytes_list = await loop.run_in_executor(None, _send_and_receive) + + # Concatenate response bytes + combined_bytes = b''.join(response_bytes_list) + + # Return based on encoding preference + if response_encoding is None: + return combined_bytes + else: + try: + return combined_bytes.decode(response_encoding) + except UnicodeDecodeError as e: + self._log_error(f"Failed to decode response with encoding '{response_encoding}': {e}") + # Return raw bytes as fallback + return combined_bytes + + except TimeoutError as e: + self._log_error(str(e)) + raise asyncio.TimeoutError(str(e)) + except Exception as e: + self._log_error(f"Error sending UDP message: {e}") + raise + + async def _send_udp_no_response(self, host: str, port: int, message: str) -> None: + """Send a UDP message without expecting a response.""" + def _send_only(): + """Blocking function to send UDP message only.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + message_bytes = message.encode('utf-8') + sock.sendto(message_bytes, (host, port)) + finally: + sock.close() + + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, _send_only) + except Exception as e: + self._log_error(f"Error sending UDP message (no response): {e}") + raise + + async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: + """Register a UDP provider and discover its tools. + + Sends a discovery message to the UDP provider and parses the response. + + Args: + manual_provider: The UDPProvider to register + + Returns: + List of tools discovered from the UDP provider + + Raises: + ValueError: If provider is not a UDPProvider + """ + if not isinstance(manual_provider, UDPProvider): + raise ValueError("UDPTransport can only be used with UDPProvider") + + self._log_info(f"Registering UDP provider '{manual_provider.name}' at {manual_provider.host}:{manual_provider.port}") + + try: + # Send discovery message + discovery_message = json.dumps({ + "type": "utcp" + }) + + response = await self._send_udp_message( + manual_provider.host, + manual_provider.port, + discovery_message, + manual_provider.timeout / 1000.0, # Convert ms to seconds + manual_provider.number_of_response_datagrams, + manual_provider.response_byte_format + ) + + # Parse response + try: + # Handle bytes response by trying to decode as UTF-8 for JSON parsing + if isinstance(response, bytes): + response_str = response.decode('utf-8') + else: + response_str = response + + response_data = json.loads(response_str) + + # Check if response contains tools + if isinstance(response_data, dict) and 'tools' in response_data: + tools_data = response_data['tools'] + + # Parse tools + tools = [] + for tool_data in tools_data: + try: + tool = Tool(**tool_data) + tools.append(tool) + except Exception as e: + self._log_error(f"Invalid tool definition in UDP provider '{manual_provider.name}': {e}") + continue + + self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_provider.name}'") + return tools + else: + self._log_info(f"No tools found in UDP provider '{manual_provider.name}' response") + return [] + + except json.JSONDecodeError as e: + self._log_error(f"Invalid JSON response from UDP provider '{manual_provider.name}': {e}") + return [] + + except Exception as e: + self._log_error(f"Error registering UDP provider '{manual_provider.name}': {e}") + return [] + + async def deregister_tool_provider(self, manual_provider: Provider) -> None: + """Deregister a UDP provider. + + This is a no-op for UDP providers since they are stateless. + + Args: + manual_provider: The provider to deregister + """ + if not isinstance(manual_provider, UDPProvider): + raise ValueError("UDPTransport can only be used with UDPProvider") + + self._log_info(f"Deregistering UDP provider '{manual_provider.name}' (no-op)") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], tool_provider: Provider) -> Any: + """Call a UDP tool. + + Sends a tool call message to the UDP provider and returns the response. + + Args: + tool_name: Name of the tool to call + arguments: Arguments for the tool call + tool_provider: The UDPProvider containing the tool + + Returns: + The response from the UDP tool + + Raises: + ValueError: If provider is not a UDPProvider + """ + if not isinstance(tool_provider, UDPProvider): + raise ValueError("UDPTransport can only be used with UDPProvider") + + self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_provider.name}'") + + try: + tool_call_message = self._format_tool_call_message(arguments, tool_provider) + + response = await self._send_udp_message( + tool_provider.host, + tool_provider.port, + tool_call_message, + tool_provider.timeout / 1000.0, # Convert ms to seconds + tool_provider.number_of_response_datagrams, + tool_provider.response_byte_format + ) + return response + + except Exception as e: + self._log_error(f"Error calling UDP tool '{tool_name}': {e}") + raise diff --git a/src/utcp/client/utcp_client.py b/src/utcp/client/utcp_client.py index 062fce3..c3e238e 100644 --- a/src/utcp/client/utcp_client.py +++ b/src/utcp/client/utcp_client.py @@ -14,6 +14,8 @@ from utcp.client.transport_interfaces.mcp_transport import MCPTransport from utcp.client.transport_interfaces.text_transport import TextTransport from utcp.client.transport_interfaces.graphql_transport import GraphQLClientTransport +from utcp.client.transport_interfaces.tcp_transport import TCPTransport +from utcp.client.transport_interfaces.udp_transport import UDPTransport from utcp.client.utcp_client_config import UtcpClientConfig, UtcpVariableNotFound from utcp.client.tool_repository import ToolRepository from utcp.client.tool_repositories.in_mem_tool_repository import InMemToolRepository @@ -22,6 +24,7 @@ from utcp.shared.provider import Provider, HttpProvider, CliProvider, SSEProvider, \ StreamableHttpProvider, WebSocketProvider, GRPCProvider, GraphQLProvider, \ TCPProvider, UDPProvider, WebRTCProvider, MCPProvider, TextProvider +from utcp.client.variable_substitutor import DefaultVariableSubstitutor, VariableSubstitutor class UtcpClientInterface(ABC): """ @@ -78,6 +81,32 @@ def search_tools(self, query: str, limit: int = 10) -> List[Tool]: """ pass + @abstractmethod + def get_required_variables_for_manual_and_tools(self, manual_provider: Provider) -> List[str]: + """ + Get the required variables for a manual provider and its tools. + + Args: + manual_provider: The manual provider. + + Returns: + A list of required variables for the manual provider and its tools. + """ + pass + + @abstractmethod + def get_required_variables_for_tool(self, tool_name: str) -> List[str]: + """ + Get the required variables for a registered tool. + + Args: + tool_name: The name of a registered tool. + + Returns: + A list of required variables for the tool. + """ + pass + class UtcpClient(UtcpClientInterface): transports: Dict[str, ClientTransportInterface] = { "http": HttpClientTransport(), @@ -87,15 +116,18 @@ class UtcpClient(UtcpClientInterface): "mcp": MCPTransport(), "text": TextTransport(), "graphql": GraphQLClientTransport(), + "tcp": TCPTransport(), + "udp": UDPTransport(), } - def __init__(self, config: UtcpClientConfig, tool_repository: ToolRepository, search_strategy: ToolSearchStrategy): + def __init__(self, config: UtcpClientConfig, tool_repository: ToolRepository, search_strategy: ToolSearchStrategy, variable_substitutor: VariableSubstitutor): """ Use 'create' class method to create a new instance instead, as it supports loading UtcpClientConfig. """ self.tool_repository = tool_repository self.search_strategy = search_strategy self.config = config + self.variable_substitutor = variable_substitutor @classmethod async def create(cls, config: Optional[Union[Dict[str, Any], UtcpClientConfig]] = None, tool_repository: Optional[ToolRepository] = None, search_strategy: Optional[ToolSearchStrategy] = None) -> 'UtcpClient': @@ -119,7 +151,7 @@ async def create(cls, config: Optional[Union[Dict[str, Any], UtcpClientConfig]] elif isinstance(config, dict): config = UtcpClientConfig.model_validate(config) - client = cls(config, tool_repository, search_strategy) + client = cls(config, tool_repository, search_strategy, DefaultVariableSubstitutor()) # If a providers file is used, configure TextTransport to resolve relative paths from its directory if config.providers_file_path: @@ -129,15 +161,9 @@ async def create(cls, config: Optional[Union[Dict[str, Any], UtcpClientConfig]] if client.config.variables: config_without_vars = client.config.model_copy() config_without_vars.variables = None - client.config.variables = client._replace_vars_in_obj(client.config.variables, config_without_vars) + client.config.variables = client.variable_substitutor.substitute(client.config.variables, config_without_vars) await client.load_providers(config.providers_file_path) - # for provider in providers: - # print(f"Registering provider '{provider.name}' with {len(provider.tools)} tools") - # try: - # await client.register_tool_provider(provider) - # except Exception as e: - # print(f"Error registering provider '{provider.name}': {str(e)}") return client @@ -203,7 +229,6 @@ async def register_single_provider(provider_data=provider_data): provider = provider_class.model_validate(provider_data) # Apply variable substitution and register provider - provider = self._substitute_provider_variables(provider) tools = await self.register_tool_provider(provider) print(f"Successfully registered provider '{provider.name}' with {len(tools)} tools") return provider @@ -221,45 +246,61 @@ async def register_single_provider(provider_data=provider_data): return registered_providers - def _get_variable(self, key: str, config: UtcpClientConfig) -> str: - if config.variables and key in config.variables: - return config.variables[key] - if config.load_variables_from: - for var_loader in config.load_variables_from: - var = var_loader.get(key) - if var: - return var - try: - env_var = os.environ.get(key) - if env_var: - return env_var - except Exception: - pass - - raise UtcpVariableNotFound(key) - - def _replace_vars_in_obj(self, obj: Any, config: UtcpClientConfig) -> Any: - if isinstance(obj, dict): - return {k: self._replace_vars_in_obj(v, config) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._replace_vars_in_obj(elem, config) for elem in obj] - elif isinstance(obj, str): - # Use a regular expression to find all variables in the string, supporting ${VAR} and $VAR formats - def replacer(match): - # The first group that is not None is the one that matched - var_name = next(g for g in match.groups() if g is not None) - return self._get_variable(var_name, config) - - return re.sub(r'\${(\w+)}|\$(\w+)', replacer, obj) - else: - return obj - - def _substitute_provider_variables(self, provider: Provider) -> Provider: + def _substitute_provider_variables(self, provider: Provider, provider_name: Optional[str] = None) -> Provider: provider_dict = provider.model_dump() - processed_dict = self._replace_vars_in_obj(provider_dict, self.config) + processed_dict = self.variable_substitutor.substitute(provider_dict, self.config, provider_name) return provider.__class__(**processed_dict) + async def get_required_variables_for_manual_and_tools(self, manual_provider: Provider) -> List[str]: + """ + Get the required variables for a manual provider and its tools. + + Args: + manual_provider: The provider to validate. + + Returns: + A list of required variables for the provider. + + Raises: + ValueError: If the provider type is not supported. + UtcpVariableNotFound: If a variable is not found in the environment or in the configuration. + """ + manual_provider.name = re.sub(r'[^\w]', '_', manual_provider.name) + variables_for_provider = self.variable_substitutor.find_required_variables(manual_provider.model_dump(), manual_provider.name) + if len(variables_for_provider) > 0: + try: + manual_provider = self._substitute_provider_variables(manual_provider, manual_provider.name) + except UtcpVariableNotFound as e: + return variables_for_provider + return variables_for_provider + if manual_provider.provider_type not in self.transports: + raise ValueError(f"Provider type not supported: {manual_provider.provider_type}") + tools: List[Tool] = await self.transports[manual_provider.provider_type].register_tool_provider(manual_provider) + for tool in tools: + variables_for_provider.extend(self.variable_substitutor.find_required_variables(tool.tool_provider.model_dump(), manual_provider.name)) + return variables_for_provider + + async def get_required_variables_for_tool(self, tool_name: str) -> List[str]: + """ + Get the required variables for a tool. + + Args: + tool_name: The name of the tool to validate. + + Returns: + A list of required variables for the tool. + + Raises: + ValueError: If the provider type is not supported. + UtcpVariableNotFound: If a variable is not found in the environment or in the configuration. + """ + provider_name = tool_name.split(".")[0] + tool = await self.tool_repository.get_tool(tool_name) + if tool is None: + raise ValueError(f"Tool not found: {tool_name}") + return self.variable_substitutor.find_required_variables(tool.tool_provider.model_dump(), provider_name) + async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: """ Register a tool provider. @@ -274,8 +315,11 @@ async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: ValueError: If the provider type is not supported. UtcpVariableNotFound: If a variable is not found in the environment or in the configuration. """ - manual_provider = self._substitute_provider_variables(manual_provider) - manual_provider.name = manual_provider.name.replace(".", "_") + # Replace all non-word characters with underscore + manual_provider.name = re.sub(r'[^\w]', '_', manual_provider.name) + if await self.tool_repository.get_provider(manual_provider.name) is not None: + raise ValueError(f"Provider {manual_provider.name} already registered, please use a different name or deregister the existing provider") + manual_provider = self._substitute_provider_variables(manual_provider, manual_provider.name) if manual_provider.provider_type not in self.transports: raise ValueError(f"Provider type not supported: {manual_provider.provider_type}") tools: List[Tool] = await self.transports[manual_provider.provider_type].register_tool_provider(manual_provider) @@ -316,20 +360,19 @@ async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: ValueError: If the tool is not found. UtcpVariableNotFound: If a variable is not found in the environment or in the configuration. """ - provider_name = tool_name.split(".")[0] - provider = await self.tool_repository.get_provider(provider_name) - if provider is None: - raise ValueError(f"Provider not found: {provider_name}") - tools = await self.tool_repository.get_tools_by_provider(provider_name) - tool = next((t for t in tools if t.name == tool_name), None) + manual_provider_name = tool_name.split(".")[0] + manual_provider = await self.tool_repository.get_provider(manual_provider_name) + if manual_provider is None: + raise ValueError(f"Provider not found: {manual_provider_name}") + tool = await self.tool_repository.get_tool(tool_name) if tool is None: raise ValueError(f"Tool not found: {tool_name}") tool_provider = tool.tool_provider - tool_provider = self._substitute_provider_variables(tool_provider) + tool_provider = self._substitute_provider_variables(tool_provider, manual_provider_name) return await self.transports[tool_provider.provider_type].call_tool(tool_name, arguments, tool_provider) async def search_tools(self, query: str, limit: int = 10) -> List[Tool]: - return await self.search_strategy.search_tools(query, limit) + return await self.search_strategy.search_tools(query, limit) diff --git a/src/utcp/client/variable_substitutor.py b/src/utcp/client/variable_substitutor.py new file mode 100644 index 0000000..2fc04b6 --- /dev/null +++ b/src/utcp/client/variable_substitutor.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractmethod +from utcp.client.utcp_client_config import UtcpClientConfig +from typing import Any +import os +import re +from utcp.client.utcp_client_config import UtcpVariableNotFound +from typing import List, Optional + +class VariableSubstitutor(ABC): + @abstractmethod + def substitute(self, obj: Any, config: UtcpClientConfig, provider_name: Optional[str] = None) -> Any: + pass + + @abstractmethod + def find_required_variables(self, obj: dict | list | str, provider_name: str) -> List[str]: + pass + +class DefaultVariableSubstitutor(VariableSubstitutor): + def _get_variable(self, key: str, config: UtcpClientConfig, provider_name: Optional[str] = None) -> str: + if provider_name: + key = provider_name.replace("_", "!").replace("!", "__") + "_" + key + if config.variables and key in config.variables: + return config.variables[key] + if config.load_variables_from: + for var_loader in config.load_variables_from: + var = var_loader.get(key) + if var: + return var + try: + env_var = os.environ.get(key) + if env_var: + return env_var + except Exception: + pass + + raise UtcpVariableNotFound(key) + + def substitute(self, obj: dict | list | str, config: UtcpClientConfig, provider_name: Optional[str] = None) -> Any: + if isinstance(obj, dict): + return {k: self.substitute(v, config, provider_name) for k, v in obj.items()} + elif isinstance(obj, list): + return [self.substitute(elem, config, provider_name) for elem in obj] + elif isinstance(obj, str): + # Use a regular expression to find all variables in the string, supporting ${VAR} and $VAR formats + def replacer(match): + # The first group that is not None is the one that matched + var_name = next((g for g in match.groups() if g is not None), "") + return self._get_variable(var_name, config, provider_name) + + return re.sub(r'\${(\w+)}|\$(\w+)', replacer, obj) + else: + return obj + + def find_required_variables(self, obj: dict | list | str, provider_name: str) -> List[str]: + if isinstance(obj, dict): + result = [] + for v in obj.values(): + vars = self.find_required_variables(v, provider_name) + result.extend(vars) + return result + elif isinstance(obj, list): + result = [] + for elem in obj: + vars = self.find_required_variables(elem, provider_name) + result.extend(vars) + return result + elif isinstance(obj, str): + # Find all variables in the string, supporting ${VAR} and $VAR formats + variables = [] + pattern = r'\${(\w+)}|\$(\w+)' + + for match in re.finditer(pattern, obj): + # The first group that is not None is the one that matched + var_name = next(g for g in match.groups() if g is not None) + full_var_name = provider_name.replace("_", "!").replace("!", "__") + "_" + var_name + variables.append(full_var_name) + + return variables + else: + return [] diff --git a/src/utcp/shared/provider.py b/src/utcp/shared/provider.py index f7cf494..2a9b339 100644 --- a/src/utcp/shared/provider.py +++ b/src/utcp/shared/provider.py @@ -110,20 +110,82 @@ class GraphQLProvider(Provider): header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") class TCPProvider(Provider): - """Options specific to raw TCP socket tools""" + """Options specific to raw TCP socket tools + + For request data handling: + - If request_data_format is 'json', arguments will be formatted as a JSON object and sent + - If request_data_format is 'text', the request_data_template can contain placeholders + in the format UTCP_ARG_argname_UTCP_ARG which will be replaced with the value of + the argument named 'argname' + For response data handling: + - If response_byte_format is None, raw bytes will be returned + - If response_byte_format is an encoding (e.g., 'utf-8'), bytes will be decoded to text + For TCP stream framing (choose one): + 1. Length-prefix framing: Set framing_strategy='length_prefix' and length_prefix_bytes + 2. Delimiter-based framing: Set framing_strategy='delimiter' and message_delimiter + 3. Fixed-length framing: Set framing_strategy='fixed_length' and fixed_message_length + 4. Stream-based: Set framing_strategy='stream' to read until connection closes + """ provider_type: Literal["tcp"] = "tcp" host: str port: int + request_data_format: Literal["json", "text"] = "json" + request_data_template: Optional[str] = None + response_byte_format: Optional[str] = Field(default="utf-8", description="Encoding to decode response bytes. If None, returns raw bytes.") + # TCP Framing Strategy + framing_strategy: Literal["length_prefix", "delimiter", "fixed_length", "stream"] = Field( + default="stream", + description="Strategy for framing TCP messages" + ) + # Length-prefix framing options + length_prefix_bytes: Literal[1, 2, 4, 8] = Field( + default=4, + description="Number of bytes for length prefix (1, 2, 4, or 8). Used with 'length_prefix' framing." + ) + length_prefix_endian: Literal["big", "little"] = Field( + default="big", + description="Byte order for length prefix. Used with 'length_prefix' framing." + ) + # Delimiter-based framing options + message_delimiter: str = Field( + default='\\x00', + description="Delimiter to detect end of TCP response (e.g., '\\n', '\\r\\n', '\\x00'). Used with 'delimiter' framing." + ) + # Fixed-length framing options + fixed_message_length: Optional[int] = Field( + default=None, + description="Fixed length of each message in bytes. Used with 'fixed_length' framing." + ) + # Stream-based options + max_response_size: int = Field( + default=65536, + description="Maximum bytes to read from TCP stream. Used with 'stream' framing." + ) timeout: int = 30000 auth: None = None class UDPProvider(Provider): - """Options specific to UDP socket tools""" + """Options specific to UDP socket tools + + For request data handling: + - If request_data_format is 'json', arguments will be formatted as a JSON object and sent + - If request_data_format is 'text', the request_data_template can contain placeholders + in the format UTCP_ARG_argname_UTCP_ARG which will be replaced with the value of + the argument named 'argname' + + For response data handling: + - If response_byte_format is None, raw bytes will be returned + - If response_byte_format is an encoding (e.g., 'utf-8'), bytes will be decoded to text + """ provider_type: Literal["udp"] = "udp" host: str port: int + number_of_response_datagrams: int = 1 + request_data_format: Literal["json", "text"] = "json" + request_data_template: Optional[str] = None + response_byte_format: Optional[str] = Field(default="utf-8", description="Encoding to decode response bytes. If None, returns raw bytes.") timeout: int = 30000 auth: None = None @@ -191,4 +253,4 @@ class TextProvider(Provider): TextProvider ], Field(discriminator="provider_type") -] \ No newline at end of file +] diff --git a/src/utcp/shared/tool.py b/src/utcp/shared/tool.py index f1aed22..3994b7c 100644 --- a/src/utcp/shared/tool.py +++ b/src/utcp/shared/tool.py @@ -1,6 +1,5 @@ import inspect from typing import Dict, Any, Optional, List, Set, Tuple, get_type_hints, get_origin, get_args, Union -from typing import get_origin, get_args, List, Dict, Optional, Union, Any from pydantic import BaseModel, Field from utcp.shared.provider import ProviderUnion diff --git a/src/utcp/version.py b/src/utcp/version.py index 1ac5329..675a6b3 100644 --- a/src/utcp/version.py +++ b/src/utcp/version.py @@ -2,7 +2,7 @@ import tomli from pathlib import Path -__version__ = "0.1.8" +__version__ = "0.2.0" try: __version__ = version("utcp") except PackageNotFoundError: diff --git a/tests/client/test_openapi_converter_auth.py b/tests/client/test_openapi_converter_auth.py index 98aa39d..a30a498 100644 --- a/tests/client/test_openapi_converter_auth.py +++ b/tests/client/test_openapi_converter_auth.py @@ -47,7 +47,7 @@ async def test_webscraping_ai_auth_extraction(): assert tool.tool_provider.auth is not None assert isinstance(tool.tool_provider.auth, ApiKeyAuth) assert tool.tool_provider.auth.var_name == "api_key" - assert tool.tool_provider.auth.api_key == "${WEBSCRAPING_AI_API_KEY}" + assert tool.tool_provider.auth.api_key.startswith("${API_KEY_") assert tool.tool_provider.auth.location == "query" diff --git a/tests/client/test_utcp_client.py b/tests/client/test_utcp_client.py new file mode 100644 index 0000000..c1cd9f5 --- /dev/null +++ b/tests/client/test_utcp_client.py @@ -0,0 +1,788 @@ +import pytest +import pytest_asyncio +import asyncio +import json +import os +import tempfile +from typing import Dict, Any, List, Optional +from unittest.mock import MagicMock, AsyncMock, patch + +from utcp.client.utcp_client import UtcpClient, UtcpClientInterface +from utcp.client.utcp_client_config import UtcpClientConfig, UtcpVariableNotFound +from utcp.client.tool_repository import ToolRepository +from utcp.client.tool_repositories.in_mem_tool_repository import InMemToolRepository +from utcp.client.tool_search_strategy import ToolSearchStrategy +from utcp.client.tool_search_strategies.tag_search import TagSearchStrategy +from utcp.client.variable_substitutor import VariableSubstitutor, DefaultVariableSubstitutor +from utcp.shared.tool import Tool, ToolInputOutputSchema +from utcp.shared.provider import ( + Provider, HttpProvider, CliProvider, MCPProvider, TextProvider, + McpConfig, McpStdioServer, McpHttpServer +) +from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth + + +class MockToolRepository(ToolRepository): + """Mock tool repository for testing.""" + + def __init__(self): + self.providers: Dict[str, Provider] = {} + self.tools: Dict[str, Tool] = {} + self.provider_tools: Dict[str, List[Tool]] = {} + + async def save_provider_with_tools(self, provider: Provider, tools: List[Tool]) -> None: + self.providers[provider.name] = provider + self.provider_tools[provider.name] = tools + for tool in tools: + self.tools[tool.name] = tool + + async def remove_provider(self, provider_name: str) -> None: + if provider_name not in self.providers: + raise ValueError(f"Provider not found: {provider_name}") + # Remove tools associated with provider + if provider_name in self.provider_tools: + for tool in self.provider_tools[provider_name]: + if tool.name in self.tools: + del self.tools[tool.name] + del self.provider_tools[provider_name] + del self.providers[provider_name] + + async def remove_tool(self, tool_name: str) -> None: + if tool_name not in self.tools: + raise ValueError(f"Tool not found: {tool_name}") + del self.tools[tool_name] + # Remove from provider_tools + for provider_name, tools in self.provider_tools.items(): + self.provider_tools[provider_name] = [t for t in tools if t.name != tool_name] + + async def get_tool(self, tool_name: str) -> Optional[Tool]: + return self.tools.get(tool_name) + + async def get_tools(self) -> List[Tool]: + return list(self.tools.values()) + + async def get_tools_by_provider(self, provider_name: str) -> Optional[List[Tool]]: + return self.provider_tools.get(provider_name) + + async def get_provider(self, provider_name: str) -> Optional[Provider]: + return self.providers.get(provider_name) + + async def get_providers(self) -> List[Provider]: + return list(self.providers.values()) + + +class MockToolSearchStrategy(ToolSearchStrategy): + """Mock search strategy for testing.""" + + def __init__(self, tool_repository: ToolRepository): + self.tool_repository = tool_repository + + async def search_tools(self, query: str, limit: int = 10) -> List[Tool]: + tools = await self.tool_repository.get_tools() + # Simple mock search: return tools that contain the query in name or description + matched_tools = [ + tool for tool in tools + if query.lower() in tool.name.lower() or query.lower() in tool.description.lower() + ] + return matched_tools[:limit] if limit > 0 else matched_tools + + +class MockTransport: + """Mock transport for testing.""" + + def __init__(self, tools: List[Tool] = None, call_result: Any = "mock_result"): + self.tools = tools or [] + self.call_result = call_result + self.registered_providers = [] + self.deregistered_providers = [] + self.tool_calls = [] + + async def register_tool_provider(self, provider: Provider) -> List[Tool]: + self.registered_providers.append(provider) + return self.tools + + async def deregister_tool_provider(self, provider: Provider) -> None: + self.deregistered_providers.append(provider) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], tool_provider: Provider) -> Any: + self.tool_calls.append((tool_name, arguments, tool_provider)) + return self.call_result + + +@pytest_asyncio.fixture +async def mock_tool_repository(): + """Create a mock tool repository.""" + return MockToolRepository() + + +@pytest_asyncio.fixture +async def mock_search_strategy(mock_tool_repository): + """Create a mock search strategy.""" + return MockToolSearchStrategy(mock_tool_repository) + + +@pytest_asyncio.fixture +async def sample_tools(): + """Create sample tools for testing.""" + http_provider = HttpProvider( + name="test_http_provider", + url="https://api.example.com/tool", + http_method="POST" + ) + + cli_provider = CliProvider( + name="test_cli_provider", + command_name="echo" + ) + + return [ + Tool( + name="http_tool", + description="HTTP test tool", + inputs=ToolInputOutputSchema( + type="object", + properties={"param1": {"type": "string", "description": "Test parameter"}}, + required=["param1"] + ), + outputs=ToolInputOutputSchema( + type="object", + properties={"result": {"type": "string", "description": "Test result"}} + ), + tags=["http", "test"], + tool_provider=http_provider + ), + Tool( + name="cli_tool", + description="CLI test tool", + inputs=ToolInputOutputSchema( + type="object", + properties={"command": {"type": "string", "description": "Command to execute"}}, + required=["command"] + ), + outputs=ToolInputOutputSchema( + type="object", + properties={"output": {"type": "string", "description": "Command output"}} + ), + tags=["cli", "test"], + tool_provider=cli_provider + ) + ] + + +@pytest_asyncio.fixture +async def utcp_client(mock_tool_repository, mock_search_strategy): + """Create a UtcpClient instance with mocked dependencies.""" + config = UtcpClientConfig() + variable_substitutor = DefaultVariableSubstitutor() + + client = UtcpClient(config, mock_tool_repository, mock_search_strategy, variable_substitutor) + + # Clear the repository before each test to ensure clean state + client.tool_repository.providers.clear() + client.tool_repository.tools.clear() + client.tool_repository.provider_tools.clear() + + return client + + +class TestUtcpClientInterface: + """Test the UtcpClientInterface abstract methods.""" + + def test_interface_is_abstract(self): + """Test that UtcpClientInterface cannot be instantiated directly.""" + with pytest.raises(TypeError): + UtcpClientInterface() + + def test_utcp_client_implements_interface(self): + """Test that UtcpClient properly implements the interface.""" + assert issubclass(UtcpClient, UtcpClientInterface) + + +class TestUtcpClient: + """Test the UtcpClient implementation.""" + + @pytest.mark.asyncio + async def test_init(self, mock_tool_repository, mock_search_strategy): + """Test UtcpClient initialization.""" + config = UtcpClientConfig() + variable_substitutor = DefaultVariableSubstitutor() + + client = UtcpClient(config, mock_tool_repository, mock_search_strategy, variable_substitutor) + + assert client.config is config + assert client.tool_repository is mock_tool_repository + assert client.search_strategy is mock_search_strategy + assert client.variable_substitutor is variable_substitutor + + @pytest.mark.asyncio + async def test_create_with_defaults(self): + """Test creating UtcpClient with default parameters.""" + with patch.object(UtcpClient, 'load_providers', new_callable=AsyncMock): + client = await UtcpClient.create() + + assert isinstance(client.config, UtcpClientConfig) + assert isinstance(client.tool_repository, InMemToolRepository) + assert isinstance(client.search_strategy, TagSearchStrategy) + assert isinstance(client.variable_substitutor, DefaultVariableSubstitutor) + + @pytest.mark.asyncio + async def test_create_with_dict_config(self): + """Test creating UtcpClient with dictionary config.""" + config_dict = { + "variables": {"TEST_VAR": "test_value"}, + "providers_file_path": "test_providers.json" + } + + with patch.object(UtcpClient, 'load_providers', new_callable=AsyncMock): + client = await UtcpClient.create(config=config_dict) + + assert client.config.variables == {"TEST_VAR": "test_value"} + assert client.config.providers_file_path == "test_providers.json" + + @pytest.mark.asyncio + async def test_create_with_utcp_config(self): + """Test creating UtcpClient with UtcpClientConfig object.""" + config = UtcpClientConfig( + variables={"TEST_VAR": "test_value"}, + providers_file_path="test_providers.json" + ) + + with patch.object(UtcpClient, 'load_providers', new_callable=AsyncMock): + client = await UtcpClient.create(config=config) + + assert client.config is config + + @pytest.mark.asyncio + async def test_register_tool_provider(self, utcp_client, sample_tools): + """Test registering a tool provider.""" + http_provider = HttpProvider( + name="test_provider", + url="https://api.example.com/tool", + http_method="POST" + ) + + # Mock the transport + mock_transport = MockTransport(sample_tools[:1]) # Return first tool + utcp_client.transports["http"] = mock_transport + + tools = await utcp_client.register_tool_provider(http_provider) + + assert len(tools) == 1 + assert tools[0].name == "test_provider.http_tool" # Should be prefixed + # Check that the registered provider has the expected properties + registered_provider = mock_transport.registered_providers[0] + assert registered_provider.name == "test_provider" + assert registered_provider.url == "https://api.example.com/tool" + assert registered_provider.http_method == "POST" + + # Verify tool was saved in repository + saved_tool = await utcp_client.tool_repository.get_tool("test_provider.http_tool") + assert saved_tool is not None + + @pytest.mark.asyncio + async def test_register_tool_provider_unsupported_type(self, utcp_client): + """Test registering a tool provider with unsupported type.""" + # Create a provider with a supported type but then modify it + provider = HttpProvider( + name="test_provider", + url="https://example.com", + http_method="GET" + ) + + # Simulate an unsupported type by removing it from transports + original_transports = utcp_client.transports.copy() + del utcp_client.transports["http"] + + try: + with pytest.raises(ValueError, match="Provider type not supported: http"): + await utcp_client.register_tool_provider(provider) + finally: + # Restore original transports + utcp_client.transports = original_transports + + @pytest.mark.asyncio + async def test_register_tool_provider_name_sanitization(self, utcp_client, sample_tools): + """Test that provider names are sanitized.""" + provider = HttpProvider( + name="test-provider.with/special@chars", + url="https://api.example.com/tool", + http_method="POST" + ) + + mock_transport = MockTransport(sample_tools[:1]) + utcp_client.transports["http"] = mock_transport + + tools = await utcp_client.register_tool_provider(provider) + + # Name should be sanitized + assert provider.name == "test_provider_with_special_chars" + assert tools[0].name == "test_provider_with_special_chars.http_tool" + + @pytest.mark.asyncio + async def test_deregister_tool_provider(self, utcp_client, sample_tools): + """Test deregistering a tool provider.""" + provider = HttpProvider( + name="test_provider", + url="https://api.example.com/tool", + http_method="POST" + ) + + mock_transport = MockTransport(sample_tools[:1]) + utcp_client.transports["http"] = mock_transport + + # First register the provider + await utcp_client.register_tool_provider(provider) + + # Then deregister it + await utcp_client.deregister_tool_provider("test_provider") + + # Verify provider was removed from repository + saved_provider = await utcp_client.tool_repository.get_provider("test_provider") + assert saved_provider is None + + # Verify transport deregister was called + assert len(mock_transport.deregistered_providers) == 1 + + @pytest.mark.asyncio + async def test_deregister_nonexistent_provider(self, utcp_client): + """Test deregistering a non-existent provider.""" + with pytest.raises(ValueError, match="Provider not found: nonexistent"): + await utcp_client.deregister_tool_provider("nonexistent") + + @pytest.mark.asyncio + async def test_call_tool(self, utcp_client, sample_tools): + """Test calling a tool.""" + provider = HttpProvider( + name="test_provider", + url="https://api.example.com/tool", + http_method="POST" + ) + + mock_transport = MockTransport(sample_tools[:1], "test_result") + utcp_client.transports["http"] = mock_transport + + # Register the provider first + await utcp_client.register_tool_provider(provider) + + # Call the tool + result = await utcp_client.call_tool("test_provider.http_tool", {"param1": "value1"}) + + assert result == "test_result" + assert len(mock_transport.tool_calls) == 1 + assert mock_transport.tool_calls[0][0] == "test_provider.http_tool" + assert mock_transport.tool_calls[0][1] == {"param1": "value1"} + + @pytest.mark.asyncio + async def test_call_tool_nonexistent_provider(self, utcp_client): + """Test calling a tool with nonexistent provider.""" + with pytest.raises(ValueError, match="Provider not found: nonexistent"): + await utcp_client.call_tool("nonexistent.tool", {"param": "value"}) + + @pytest.mark.asyncio + async def test_call_tool_nonexistent_tool(self, utcp_client, sample_tools): + """Test calling a nonexistent tool.""" + provider = HttpProvider( + name="test_provider", + url="https://api.example.com/tool", + http_method="POST" + ) + + mock_transport = MockTransport(sample_tools[:1]) + utcp_client.transports["http"] = mock_transport + + # Register the provider first + await utcp_client.register_tool_provider(provider) + + with pytest.raises(ValueError, match="Tool not found: test_provider.nonexistent"): + await utcp_client.call_tool("test_provider.nonexistent", {"param": "value"}) + + @pytest.mark.asyncio + async def test_search_tools(self, utcp_client, sample_tools): + """Test searching for tools.""" + # Add tools to the search strategy's repository + for i, tool in enumerate(sample_tools): + tool.name = f"provider_{i}.{tool.name}" + await utcp_client.tool_repository.save_provider_with_tools( + tool.tool_provider, [tool] + ) + + # Search for tools + results = await utcp_client.search_tools("http", limit=10) + + # Should find the HTTP tool + assert len(results) == 1 + assert "http" in results[0].name.lower() or "http" in results[0].description.lower() + + @pytest.mark.asyncio + async def test_get_required_variables_for_manual_and_tools(self, utcp_client): + """Test getting required variables for a provider.""" + provider = HttpProvider( + name="test_provider", + url="https://api.example.com/$API_URL", + http_method="POST", + auth=ApiKeyAuth(api_key="$API_KEY", var_name="Authorization") + ) + + # Mock the variable substitutor + mock_substitutor = MagicMock() + mock_substitutor.find_required_variables.return_value = ["API_URL", "API_KEY"] + mock_substitutor.substitute.return_value = provider.model_dump() # Return the original dict + utcp_client.variable_substitutor = mock_substitutor + + variables = await utcp_client.get_required_variables_for_manual_and_tools(provider) + + assert variables == ["API_URL", "API_KEY"] + mock_substitutor.find_required_variables.assert_called_once() + + @pytest.mark.asyncio + async def test_get_required_variables_for_tool(self, utcp_client, sample_tools): + """Test getting required variables for a tool.""" + provider = HttpProvider( + name="test_provider", + url="https://api.example.com/$API_URL", + http_method="POST" + ) + + tool = sample_tools[0] + tool.name = "test_provider.http_tool" + tool.tool_provider = provider + + # Add tool to repository + await utcp_client.tool_repository.save_provider_with_tools(provider, [tool]) + + # Mock the variable substitutor + mock_substitutor = MagicMock() + mock_substitutor.find_required_variables.return_value = ["API_URL"] + utcp_client.variable_substitutor = mock_substitutor + + variables = await utcp_client.get_required_variables_for_tool("test_provider.http_tool") + + assert variables == ["API_URL"] + mock_substitutor.find_required_variables.assert_called_once() + + @pytest.mark.asyncio + async def test_get_required_variables_for_nonexistent_tool(self, utcp_client): + """Test getting required variables for a nonexistent tool.""" + with pytest.raises(ValueError, match="Tool not found: nonexistent.tool"): + await utcp_client.get_required_variables_for_tool("nonexistent.tool") + + +class TestUtcpClientProviderLoading: + """Test provider loading functionality.""" + + @pytest.mark.asyncio + async def test_load_providers_from_file(self, utcp_client): + """Test loading providers from a JSON file.""" + # Create a temporary providers file with array format (as expected by load_providers) + providers_data = [ + { + "name": "http_provider", + "provider_type": "http", + "url": "https://api.example.com/tools", + "http_method": "GET" + }, + { + "name": "cli_provider", + "provider_type": "cli", + "command_name": "echo" + } + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(providers_data, f) + temp_file = f.name + + try: + # Mock the transports + mock_http_transport = MockTransport([]) + mock_cli_transport = MockTransport([]) + utcp_client.transports["http"] = mock_http_transport + utcp_client.transports["cli"] = mock_cli_transport + + # Load providers + providers = await utcp_client.load_providers(temp_file) + + assert len(providers) == 2 + assert len(mock_http_transport.registered_providers) == 1 + assert len(mock_cli_transport.registered_providers) == 1 + + finally: + os.unlink(temp_file) + + @pytest.mark.asyncio + async def test_load_providers_file_not_found(self, utcp_client): + """Test loading providers from a non-existent file.""" + with pytest.raises(FileNotFoundError): + await utcp_client.load_providers("nonexistent.json") + + @pytest.mark.asyncio + async def test_load_providers_invalid_json(self, utcp_client): + """Test loading providers from invalid JSON file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write("invalid json content") + temp_file = f.name + + try: + with pytest.raises(ValueError, match="Invalid JSON in providers file"): + await utcp_client.load_providers(temp_file) + finally: + os.unlink(temp_file) + + @pytest.mark.asyncio + async def test_load_providers_with_variables(self, utcp_client): + """Test loading providers with variable substitution.""" + providers_data = [ + { + "name": "http_provider", + "provider_type": "http", + "url": "$BASE_URL/tools", + "http_method": "GET", + "auth": { + "auth_type": "api_key", + "api_key": "$API_KEY", + "var_name": "Authorization" + } + } + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(providers_data, f) + temp_file = f.name + + try: + # Setup client with variables (need provider prefixed variables) + utcp_client.config.variables = { + "http__provider_BASE_URL": "https://api.example.com", + "http__provider_API_KEY": "secret_key" + } + + # Mock the transport + mock_transport = MockTransport([]) + utcp_client.transports["http"] = mock_transport + + # Load providers + providers = await utcp_client.load_providers(temp_file) + + assert len(providers) == 1 + # Check that the registered provider has substituted values + registered_provider = mock_transport.registered_providers[0] + assert registered_provider.url == "https://api.example.com/tools" + assert registered_provider.auth.api_key == "secret_key" + + finally: + os.unlink(temp_file) + + @pytest.mark.asyncio + async def test_load_providers_missing_variable(self, utcp_client): + """Test loading providers with missing variable.""" + providers_data = [ + { + "name": "http_provider", + "provider_type": "http", + "url": "$MISSING_VAR/tools", + "http_method": "GET" + } + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(providers_data, f) + temp_file = f.name + + try: + # Mock transport to avoid registration issues + utcp_client.transports["http"] = MockTransport([]) + + # The load_providers method catches exceptions and returns empty list + # So we need to check the registration directly which will raise the exception + provider_data = { + "name": "http_provider", + "provider_type": "http", + "url": "$MISSING_VAR/tools", + "http_method": "GET" + } + provider = HttpProvider.model_validate(provider_data) + + with pytest.raises(UtcpVariableNotFound, match="Variable http__provider_MISSING_VAR"): + await utcp_client.register_tool_provider(provider) + finally: + os.unlink(temp_file) + + +class TestUtcpClientTransports: + """Test transport-related functionality.""" + + def test_default_transports_initialized(self, utcp_client): + """Test that default transports are properly initialized.""" + expected_transport_types = [ + "http", "cli", "sse", "http_stream", "mcp", "text", "graphql", "tcp", "udp" + ] + + for transport_type in expected_transport_types: + assert transport_type in utcp_client.transports + assert utcp_client.transports[transport_type] is not None + + @pytest.mark.asyncio + async def test_variable_substitution(self, utcp_client): + """Test variable substitution in providers.""" + provider = HttpProvider( + name="test_provider", + url="$BASE_URL/api", + http_method="POST", + auth=ApiKeyAuth(api_key="$API_KEY", var_name="Authorization") + ) + + # Set up variables with provider prefix + utcp_client.config.variables = { + "test__provider_BASE_URL": "https://api.example.com", + "test__provider_API_KEY": "secret_key" + } + + substituted_provider = utcp_client._substitute_provider_variables(provider, "test_provider") + + assert substituted_provider.url == "https://api.example.com/api" + assert substituted_provider.auth.api_key == "secret_key" + + @pytest.mark.asyncio + async def test_variable_substitution_missing_variable(self, utcp_client): + """Test variable substitution with missing variable.""" + provider = HttpProvider( + name="test_provider", + url="$MISSING_VAR/api", + http_method="POST" + ) + + with pytest.raises(UtcpVariableNotFound, match="Variable test__provider_MISSING_VAR"): + utcp_client._substitute_provider_variables(provider, "test_provider") + + +class TestUtcpClientEdgeCases: + """Test edge cases and error conditions.""" + + @pytest.mark.asyncio + async def test_empty_provider_file(self, utcp_client): + """Test loading an empty provider file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump([], f) # Empty array instead of empty object + temp_file = f.name + + try: + providers = await utcp_client.load_providers(temp_file) + assert providers == [] + finally: + os.unlink(temp_file) + + @pytest.mark.asyncio + async def test_register_provider_with_existing_name(self, utcp_client, sample_tools): + """Test registering a provider with an existing name should raise an error.""" + provider1 = HttpProvider( + name="duplicate_name", + url="https://api.example1.com/tool", + http_method="POST" + ) + provider2 = HttpProvider( + name="duplicate_name", + url="https://api.example2.com/tool", + http_method="GET" + ) + + mock_transport = MockTransport(sample_tools[:1]) + utcp_client.transports["http"] = mock_transport + + # Register first provider + await utcp_client.register_tool_provider(provider1) + + # Attempting to register second provider with same name should raise an error + with pytest.raises(ValueError, match="Provider duplicate_name already registered"): + await utcp_client.register_tool_provider(provider2) + + # Should still have the first provider + saved_provider = await utcp_client.tool_repository.get_provider("duplicate_name") + assert saved_provider.url == "https://api.example1.com/tool" + assert saved_provider.http_method == "POST" + + @pytest.mark.asyncio + async def test_complex_mcp_provider(self, utcp_client): + """Test loading a complex MCP provider configuration.""" + providers_data = [ + { + "name": "mcp_provider", + "provider_type": "mcp", + "config": { + "mcpServers": { + "stdio_server": { + "transport": "stdio", + "command": "python", + "args": ["-m", "test_server"], + "env": {"TEST_VAR": "test_value"} + }, + "http_server": { + "transport": "http", + "url": "http://localhost:8000/mcp" + } + } + } + } + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(providers_data, f) + temp_file = f.name + + try: + # Mock the MCP transport + mock_transport = MockTransport([]) + utcp_client.transports["mcp"] = mock_transport + + providers = await utcp_client.load_providers(temp_file) + + assert len(providers) == 1 + provider = providers[0] + assert isinstance(provider, MCPProvider) + assert len(provider.config.mcpServers) == 2 + assert "stdio_server" in provider.config.mcpServers + assert "http_server" in provider.config.mcpServers + + finally: + os.unlink(temp_file) + + @pytest.mark.asyncio + async def test_text_transport_configuration(self, utcp_client): + """Test TextTransport base path configuration.""" + # Create a temporary directory structure + with tempfile.TemporaryDirectory() as temp_dir: + providers_file = os.path.join(temp_dir, "providers.json") + + with open(providers_file, 'w') as f: + json.dump([], f) # Empty array + + # Create client with providers file path + config = UtcpClientConfig(providers_file_path=providers_file) + + with patch.object(UtcpClient, 'load_providers', new_callable=AsyncMock): + client = await UtcpClient.create(config=config) + + # Check that TextTransport was configured with the correct base path + text_transport = client.transports["text"] + assert hasattr(text_transport, 'base_path') + assert text_transport.base_path == temp_dir + + @pytest.mark.asyncio + async def test_load_providers_wrong_format(self, utcp_client): + """Test loading providers with wrong JSON format (object instead of array).""" + providers_data = { + "http_provider": { + "provider_type": "http", + "url": "https://api.example.com/tools", + "http_method": "GET" + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(providers_data, f) + temp_file = f.name + + try: + with pytest.raises(ValueError, match="Providers file must contain a JSON array at the root level"): + await utcp_client.load_providers(temp_file) + finally: + os.unlink(temp_file) diff --git a/tests/client/transport_interfaces/test_tcp_transport.py b/tests/client/transport_interfaces/test_tcp_transport.py new file mode 100644 index 0000000..c32dcdf --- /dev/null +++ b/tests/client/transport_interfaces/test_tcp_transport.py @@ -0,0 +1,875 @@ +import pytest +import pytest_asyncio +import json +import asyncio +import socket +import struct +import threading +from unittest.mock import MagicMock, patch, AsyncMock + +from utcp.client.transport_interfaces.tcp_transport import TCPTransport +from utcp.shared.provider import TCPProvider +from utcp.shared.tool import Tool, ToolInputOutputSchema + + +class MockTCPServer: + """Mock TCP server for testing.""" + + def __init__(self, host='localhost', port=0, response_delay=0.0): + self.host = host + self.port = port + self.sock = None + self.running = False + self.responses = {} # Map message -> response + self.call_count = 0 + self.server_task = None + self.connections = [] + self.response_delay = response_delay # Delay before sending response (seconds) + + async def start(self): + """Start the mock TCP server.""" + # Create socket and bind + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.sock.bind((self.host, self.port)) + if self.port == 0: # Auto-assign port + self.port = self.sock.getsockname()[1] + + self.sock.listen(5) + self.running = True + + # Start listening task + self.server_task = asyncio.create_task(self._accept_connections()) + + # Give the server a moment to start + await asyncio.sleep(0.1) + + async def stop(self): + """Stop the mock TCP server.""" + self.running = False + if self.server_task: + self.server_task.cancel() + try: + await self.server_task + except asyncio.CancelledError: + pass + + # Close all active connections + for conn in self.connections: + try: + conn.close() + except Exception: + pass + self.connections.clear() + + if self.sock: + self.sock.close() + + async def _accept_connections(self): + """Accept incoming TCP connections.""" + self.sock.setblocking(False) + + while self.running: + try: + conn, addr = await asyncio.get_event_loop().sock_accept(self.sock) + self.connections.append(conn) + # Handle each connection in a separate task + asyncio.create_task(self._handle_connection(conn, addr)) + except asyncio.CancelledError: + break + except Exception as e: + if self.running: + print(f"Mock TCP server accept error: {e}") + await asyncio.sleep(0.01) + + async def _handle_connection(self, conn, addr): + """Handle a single TCP connection.""" + try: + # Read data from client + data = await asyncio.get_event_loop().sock_recv(conn, 4096) + if not data: + return + + self.call_count += 1 + + try: + message = data.decode('utf-8') + except UnicodeDecodeError: + message = data.hex() # Fallback for binary data + + # Get response for this message + response = self.responses.get(message, '{"error": "unknown_message"}') + + # Convert response to bytes + if isinstance(response, str): + response_bytes = response.encode('utf-8') + elif isinstance(response, bytes): + response_bytes = response + elif isinstance(response, dict) or isinstance(response, list): + response_bytes = json.dumps(response).encode('utf-8') + else: + response_bytes = str(response).encode('utf-8') + + # Add delay if configured + if self.response_delay > 0: + await asyncio.sleep(self.response_delay) + + # Send response back + await asyncio.get_event_loop().sock_sendall(conn, response_bytes) + + except Exception as e: + if self.running: + print(f"Mock TCP server connection error: {e}") + finally: + try: + conn.close() + except Exception: + pass + if conn in self.connections: + self.connections.remove(conn) + + def set_response(self, message, response): + """Set a response for a specific message.""" + self.responses[message] = response + + +class MockTCPServerWithFraming(MockTCPServer): + """Mock TCP server that handles different framing strategies.""" + + def __init__(self, host='localhost', port=0, framing_strategy='stream', response_delay=0.0): + super().__init__(host, port, response_delay) + self.framing_strategy = framing_strategy + self.length_prefix_bytes = 4 + self.length_prefix_endian = 'big' + self.message_delimiter = '\n' + self.fixed_message_length = None + + async def _handle_connection(self, conn, addr): + """Handle a single TCP connection with framing.""" + try: + if self.framing_strategy == 'length_prefix': + # Read length prefix first + length_data = await asyncio.get_event_loop().sock_recv(conn, self.length_prefix_bytes) + if not length_data: + return + + if self.length_prefix_bytes == 1: + message_length = struct.unpack(f"{'>' if self.length_prefix_endian == 'big' else '<'}B", length_data)[0] + elif self.length_prefix_bytes == 2: + message_length = struct.unpack(f"{'>' if self.length_prefix_endian == 'big' else '<'}H", length_data)[0] + elif self.length_prefix_bytes == 4: + message_length = struct.unpack(f"{'>' if self.length_prefix_endian == 'big' else '<'}I", length_data)[0] + + # Read the actual message + data = await asyncio.get_event_loop().sock_recv(conn, message_length) + + elif self.framing_strategy == 'delimiter': + # Read until delimiter + data = b'' + delimiter_bytes = self.message_delimiter.encode('utf-8') + while not data.endswith(delimiter_bytes): + chunk = await asyncio.get_event_loop().sock_recv(conn, 1) + if not chunk: + break + data += chunk + # Remove delimiter + data = data[:-len(delimiter_bytes)] + + elif self.framing_strategy == 'fixed_length': + # Read fixed number of bytes + data = await asyncio.get_event_loop().sock_recv(conn, self.fixed_message_length) + + else: # stream + # Read all available data + data = await asyncio.get_event_loop().sock_recv(conn, 4096) + + if not data: + return + + self.call_count += 1 + + try: + message = data.decode('utf-8') + except UnicodeDecodeError: + message = data.hex() + + # Get response for this message + response = self.responses.get(message, '{"error": "unknown_message"}') + + # Convert response to bytes + if isinstance(response, str): + response_bytes = response.encode('utf-8') + elif isinstance(response, bytes): + response_bytes = response + elif isinstance(response, dict) or isinstance(response, list): + response_bytes = json.dumps(response).encode('utf-8') + else: + response_bytes = str(response).encode('utf-8') + + # Add delay if configured + if self.response_delay > 0: + await asyncio.sleep(self.response_delay) + + # Send response with appropriate framing + if self.framing_strategy == 'length_prefix': + # Add length prefix + length = len(response_bytes) + if self.length_prefix_bytes == 1: + length_bytes = struct.pack(f"{'>' if self.length_prefix_endian == 'big' else '<'}B", length) + elif self.length_prefix_bytes == 2: + length_bytes = struct.pack(f"{'>' if self.length_prefix_endian == 'big' else '<'}H", length) + elif self.length_prefix_bytes == 4: + length_bytes = struct.pack(f"{'>' if self.length_prefix_endian == 'big' else '<'}I", length) + + await asyncio.get_event_loop().sock_sendall(conn, length_bytes + response_bytes) + + elif self.framing_strategy == 'delimiter': + # Add delimiter + delimiter_bytes = self.message_delimiter.encode('utf-8') + await asyncio.get_event_loop().sock_sendall(conn, response_bytes + delimiter_bytes) + + else: # stream or fixed_length + await asyncio.get_event_loop().sock_sendall(conn, response_bytes) + + except Exception as e: + if self.running: + print(f"Mock TCP server connection error: {e}") + finally: + try: + conn.close() + except Exception: + pass + if conn in self.connections: + self.connections.remove(conn) + + +@pytest_asyncio.fixture +async def mock_tcp_server(): + """Create a mock TCP server for testing.""" + server = MockTCPServer() + await server.start() + yield server + await server.stop() + + +@pytest_asyncio.fixture +async def mock_tcp_server_length_prefix(): + """Create a mock TCP server with length-prefix framing.""" + server = MockTCPServerWithFraming(framing_strategy='length_prefix') + await server.start() + yield server + await server.stop() + + +@pytest_asyncio.fixture +async def mock_tcp_server_delimiter(): + """Create a mock TCP server with delimiter framing.""" + server = MockTCPServerWithFraming(framing_strategy='delimiter') + await server.start() + yield server + await server.stop() + + +@pytest_asyncio.fixture +async def mock_tcp_server_slow(): + """Create a mock TCP server with a 2-second response delay.""" + server = MockTCPServer(response_delay=2.0) # 2-second delay + await server.start() + yield server + await server.stop() + + +@pytest.fixture +def logger(): + """Create a mock logger.""" + return MagicMock() + + +@pytest.fixture +def tcp_transport(logger): + """Create a TCP transport instance.""" + return TCPTransport(logger=logger) + + +@pytest.fixture +def tcp_provider(mock_tcp_server): + """Create a basic TCP provider for testing.""" + return TCPProvider( + name="test_tcp_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=5000 + ) + + +@pytest.fixture +def text_template_provider(mock_tcp_server): + """Create a TCP provider with text template format.""" + return TCPProvider( + name="text_template_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="text", + request_data_template="ACTION UTCP_ARG_cmd_UTCP_ARG PARAM UTCP_ARG_value_UTCP_ARG", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=5000 + ) + + +@pytest.fixture +def raw_bytes_provider(mock_tcp_server): + """Create a TCP provider that returns raw bytes.""" + return TCPProvider( + name="raw_bytes_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="json", + response_byte_format=None, # Raw bytes + framing_strategy="stream", + timeout=5000 + ) + + +@pytest.fixture +def length_prefix_provider(mock_tcp_server_length_prefix): + """Create a TCP provider with length-prefix framing.""" + return TCPProvider( + name="length_prefix_provider", + host=mock_tcp_server_length_prefix.host, + port=mock_tcp_server_length_prefix.port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="length_prefix", + length_prefix_bytes=4, + length_prefix_endian="big", + timeout=5000 + ) + + +@pytest.fixture +def delimiter_provider(mock_tcp_server_delimiter): + """Create a TCP provider with delimiter framing.""" + return TCPProvider( + name="delimiter_provider", + host=mock_tcp_server_delimiter.host, + port=mock_tcp_server_delimiter.port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="delimiter", + message_delimiter="\n", + timeout=5000 + ) + + +# Test register_tool_provider +@pytest.mark.asyncio +async def test_register_tool_provider(tcp_transport, tcp_provider, mock_tcp_server, logger): + """Test registering a tool provider.""" + # Set up discovery response + discovery_response = { + "tools": [ + { + "name": "test_tool", + "description": "A test tool", + "inputs": { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "First parameter"} + }, + "required": ["param1"] + }, + "outputs": { + "type": "object", + "properties": { + "result": {"type": "string", "description": "Result"} + } + }, + "tool_provider": tcp_provider.model_dump() + } + ] + } + + mock_tcp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register the provider + tools = await tcp_transport.register_tool_provider(tcp_provider) + + # Check results + assert len(tools) == 1 + assert tools[0].name == "test_tool" + assert tools[0].description == "A test tool" + assert mock_tcp_server.call_count == 1 + + # Verify logger was called + logger.assert_called() + + +@pytest.mark.asyncio +async def test_register_tool_provider_empty_response(tcp_transport, tcp_provider, mock_tcp_server): + """Test registering a tool provider with empty response.""" + mock_tcp_server.set_response('{"type": "utcp"}', {"tools": []}) + + tools = await tcp_transport.register_tool_provider(tcp_provider) + + assert len(tools) == 0 + assert mock_tcp_server.call_count == 1 + + +@pytest.mark.asyncio +async def test_register_tool_provider_invalid_json(tcp_transport, tcp_provider, mock_tcp_server): + """Test registering a tool provider with invalid JSON response.""" + mock_tcp_server.set_response('{"type": "utcp"}', "invalid json response") + + tools = await tcp_transport.register_tool_provider(tcp_provider) + + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_register_tool_provider_invalid_provider_type(tcp_transport): + """Test registering a non-TCP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + invalid_provider = HttpProvider(url="http://example.com") + + with pytest.raises(ValueError, match="TCPTransport can only be used with TCPProvider"): + await tcp_transport.register_tool_provider(invalid_provider) + + +# Test deregister_tool_provider +@pytest.mark.asyncio +async def test_deregister_tool_provider(tcp_transport, tcp_provider): + """Test deregistering a tool provider (should be a no-op).""" + # Should not raise any exceptions + await tcp_transport.deregister_tool_provider(tcp_provider) + + +@pytest.mark.asyncio +async def test_deregister_tool_provider_invalid_type(tcp_transport): + """Test deregistering a non-TCP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + invalid_provider = HttpProvider(url="http://example.com") + + with pytest.raises(ValueError, match="TCPTransport can only be used with TCPProvider"): + await tcp_transport.deregister_tool_provider(invalid_provider) + + +# Test call_tool with JSON format +@pytest.mark.asyncio +async def test_call_tool_json_format(tcp_transport, tcp_provider, mock_tcp_server): + """Test calling a tool with JSON format.""" + mock_tcp_server.set_response('{"param1": "value1"}', '{"result": "success"}') + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, tcp_provider) + + assert result == '{"result": "success"}' + assert mock_tcp_server.call_count == 1 + + +@pytest.mark.asyncio +async def test_call_tool_text_template_format(tcp_transport, text_template_provider, mock_tcp_server): + """Test calling a tool with text template format.""" + mock_tcp_server.set_response("ACTION get PARAM data123", '{"result": "template_success"}') + + arguments = {"cmd": "get", "value": "data123"} + result = await tcp_transport.call_tool("test_tool", arguments, text_template_provider) + + assert result == '{"result": "template_success"}' + assert mock_tcp_server.call_count == 1 + + +@pytest.mark.asyncio +async def test_call_tool_text_format_no_template(tcp_transport, mock_tcp_server): + """Test calling a tool with text format but no template.""" + provider = TCPProvider( + name="no_template_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="text", + request_data_template=None, + response_byte_format="utf-8", + framing_strategy="stream", + timeout=5000 + ) + + # Should use fallback format (space-separated values) + mock_tcp_server.set_response("value1 value2", '{"result": "fallback_success"}') + + arguments = {"param1": "value1", "param2": "value2"} + result = await tcp_transport.call_tool("test_tool", arguments, provider) + + assert result == '{"result": "fallback_success"}' + + +@pytest.mark.asyncio +async def test_call_tool_raw_bytes_response(tcp_transport, raw_bytes_provider, mock_tcp_server): + """Test calling a tool that returns raw bytes.""" + binary_response = b'\x01\x02\x03\x04' + mock_tcp_server.set_response('{"param1": "value1"}', binary_response) + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, raw_bytes_provider) + + assert result == binary_response + assert isinstance(result, bytes) + + +@pytest.mark.asyncio +async def test_call_tool_invalid_provider_type(tcp_transport): + """Test calling a tool with non-TCP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + invalid_provider = HttpProvider(url="http://example.com") + + with pytest.raises(ValueError, match="TCPTransport can only be used with TCPProvider"): + await tcp_transport.call_tool("test_tool", {}, invalid_provider) + + +# Test framing strategies +@pytest.mark.asyncio +async def test_call_tool_length_prefix_framing(tcp_transport, length_prefix_provider, mock_tcp_server_length_prefix): + """Test calling a tool with length-prefix framing.""" + mock_tcp_server_length_prefix.set_response('{"param1": "value1"}', '{"result": "length_prefix_success"}') + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, length_prefix_provider) + + assert result == '{"result": "length_prefix_success"}' + + +@pytest.mark.asyncio +async def test_call_tool_delimiter_framing(tcp_transport, delimiter_provider, mock_tcp_server_delimiter): + """Test calling a tool with delimiter framing.""" + mock_tcp_server_delimiter.set_response('{"param1": "value1"}', '{"result": "delimiter_success"}') + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, delimiter_provider) + + assert result == '{"result": "delimiter_success"}' + + +@pytest.mark.asyncio +async def test_call_tool_fixed_length_framing(tcp_transport, mock_tcp_server): + """Test calling a tool with fixed-length framing.""" + provider = TCPProvider( + name="fixed_length_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="fixed_length", + fixed_message_length=20, + timeout=5000 + ) + + # Set up server to handle fixed-length messages + mock_tcp_server.responses['{"param1": "value1"}'] = '{"result": "fixed"}'.ljust(20) # Pad to 20 bytes + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, provider) + + assert '{"result": "fixed"}' in result + + +# Test message formatting +def test_format_tool_call_message_json(tcp_transport): + """Test formatting tool call message with JSON format.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="json" + ) + + arguments = {"param1": "value1", "param2": 123} + result = tcp_transport._format_tool_call_message(arguments, provider) + + assert result == json.dumps(arguments) + + +def test_format_tool_call_message_text_with_template(tcp_transport): + """Test formatting tool call message with text template.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template="ACTION UTCP_ARG_cmd_UTCP_ARG PARAM UTCP_ARG_value_UTCP_ARG" + ) + + arguments = {"cmd": "get", "value": "data123"} + result = tcp_transport._format_tool_call_message(arguments, provider) + + # Should substitute placeholders + assert result == "ACTION get PARAM data123" + + +def test_format_tool_call_message_text_with_complex_values(tcp_transport): + """Test formatting tool call message with complex values in template.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template="DATA UTCP_ARG_obj_UTCP_ARG" + ) + + arguments = {"obj": {"nested": "value", "number": 123}} + result = tcp_transport._format_tool_call_message(arguments, provider) + + # Should JSON-serialize complex values + assert result == 'DATA {"nested": "value", "number": 123}' + + +def test_format_tool_call_message_text_no_template(tcp_transport): + """Test formatting tool call message with text format but no template.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template=None + ) + + arguments = {"param1": "value1", "param2": "value2"} + result = tcp_transport._format_tool_call_message(arguments, provider) + + # Should use fallback format (space-separated values) + assert result == "value1 value2" + + +def test_format_tool_call_message_default_to_json(tcp_transport): + """Test formatting tool call message defaults to JSON for unknown format.""" + # Create a provider with valid format first + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="json" + ) + + # Manually set an invalid format to test the fallback behavior + provider.request_data_format = "unknown" # Invalid format + + arguments = {"param1": "value1"} + result = tcp_transport._format_tool_call_message(arguments, provider) + + # Should default to JSON + assert result == json.dumps(arguments) + + +# Test framing encoding and decoding +def test_encode_message_with_length_prefix_framing(tcp_transport): + """Test encoding message with length-prefix framing.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + framing_strategy="length_prefix", + length_prefix_bytes=4, + length_prefix_endian="big" + ) + + message = "test message" + result = tcp_transport._encode_message_with_framing(message, provider) + + # Should have 4-byte big-endian length prefix + expected_length = len(message.encode('utf-8')) + expected_prefix = struct.pack('>I', expected_length) + + assert result.startswith(expected_prefix) + assert result[4:] == message.encode('utf-8') + + +def test_encode_message_with_delimiter_framing(tcp_transport): + """Test encoding message with delimiter framing.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + framing_strategy="delimiter", + message_delimiter="\n" + ) + + message = "test message" + result = tcp_transport._encode_message_with_framing(message, provider) + + # Should have delimiter appended + assert result == (message + "\n").encode('utf-8') + + +def test_encode_message_with_stream_framing(tcp_transport): + """Test encoding message with stream framing.""" + provider = TCPProvider( + name="test", + host="localhost", + port=1234, + framing_strategy="stream" + ) + + message = "test message" + result = tcp_transport._encode_message_with_framing(message, provider) + + # Should just be the raw message + assert result == message.encode('utf-8') + + +# Test error handling and edge cases +@pytest.mark.asyncio +async def test_call_tool_server_error(tcp_transport, tcp_provider, mock_tcp_server): + """Test handling server errors during tool calls.""" + # Don't set any response, so the server will return an error + arguments = {"param1": "value1"} + + # Call the tool - should get the default error response + result = await tcp_transport.call_tool("test_tool", arguments, tcp_provider) + + # Should receive the default error message + assert '{"error": "unknown_message"}' in result + + +@pytest.mark.asyncio +async def test_register_tool_provider_malformed_tool(tcp_transport, tcp_provider, mock_tcp_server): + """Test registering provider with malformed tool definition.""" + # Set up discovery response with invalid tool + discovery_response = { + "tools": [ + { + "name": "test_tool", + # Missing required fields like inputs, outputs, tool_provider + } + ] + } + + mock_tcp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register the provider - should handle invalid tool gracefully + tools = await tcp_transport.register_tool_provider(tcp_provider) + + # Should return empty list due to invalid tool definition + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_register_tool_provider_bytes_response(tcp_transport, tcp_provider, mock_tcp_server): + """Test registering provider that returns bytes response.""" + # Set up discovery response as JSON but provider returns raw bytes + discovery_response = '{"tools": []}'.encode('utf-8') + + mock_tcp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register the provider - should handle bytes response by decoding + tools = await tcp_transport.register_tool_provider(tcp_provider) + + # Should successfully decode and parse + assert len(tools) == 0 + + +# Test logging functionality +@pytest.mark.asyncio +async def test_logging_calls(tcp_transport, tcp_provider, mock_tcp_server, logger): + """Test that logging functions are called appropriately.""" + # Set up discovery response + discovery_response = {"tools": []} + mock_tcp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register provider + await tcp_transport.register_tool_provider(tcp_provider) + + # Verify logger was called + logger.assert_called() + + # Call tool + mock_tcp_server.set_response('{}', {"result": "test"}) + await tcp_transport.call_tool("test_tool", {}, tcp_provider) + + # Logger should have been called multiple times + assert logger.call_count > 1 + + +# Test timeout handling +@pytest.mark.asyncio +async def test_call_tool_timeout(tcp_transport): + """Test calling a tool with timeout using delimiter framing.""" + # Create a slow server with delimiter framing + slow_server = MockTCPServerWithFraming( + framing_strategy='delimiter', + response_delay=2.0 # 2-second delay + ) + await slow_server.start() + + try: + # Create provider with 1-second timeout, but server has 2-second delay + provider = TCPProvider( + name="timeout_provider", + host=slow_server.host, + port=slow_server.port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="delimiter", + message_delimiter="\n", + timeout=1000 # 1 second timeout, but server delays 2 seconds + ) + + # Set up a response (server will delay 2 seconds before responding) + slow_server.set_response('{"param1": "value1"}', '{"result": "delayed_response"}') + + arguments = {"param1": "value1"} + + # Should timeout because server takes 2 seconds but timeout is 1 second + # Delimiter framing will treat timeout as an error since it expects a complete message + with pytest.raises(Exception): # Expect timeout error + await tcp_transport.call_tool("test_tool", arguments, provider) + finally: + await slow_server.stop() + + +@pytest.mark.asyncio +async def test_call_tool_connection_refused(tcp_transport): + """Test calling a tool when connection is refused.""" + # Use a port that's definitely not listening + provider = TCPProvider( + name="refused_provider", + host="localhost", + port=1, # Port 1 should be refused + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=5000 + ) + + arguments = {"param1": "value1"} + + # Should handle connection error gracefully + with pytest.raises(Exception): # Expect connection refused or similar + await tcp_transport.call_tool("test_tool", arguments, provider) + + +# Test different byte encodings +@pytest.mark.asyncio +async def test_call_tool_different_encodings(tcp_transport, mock_tcp_server): + """Test calling a tool with different response byte encodings.""" + # Test ASCII encoding + provider_ascii = TCPProvider( + name="ascii_provider", + host=mock_tcp_server.host, + port=mock_tcp_server.port, + request_data_format="json", + response_byte_format="ascii", + framing_strategy="stream", + timeout=5000 + ) + + mock_tcp_server.set_response('{"param1": "value1"}', '{"result": "ascii_success"}') + + arguments = {"param1": "value1"} + result = await tcp_transport.call_tool("test_tool", arguments, provider_ascii) + + assert result == '{"result": "ascii_success"}' + assert isinstance(result, str) diff --git a/tests/client/transport_interfaces/test_udp_transport.py b/tests/client/transport_interfaces/test_udp_transport.py new file mode 100644 index 0000000..1bb3b0b --- /dev/null +++ b/tests/client/transport_interfaces/test_udp_transport.py @@ -0,0 +1,625 @@ +import pytest +import pytest_asyncio +import json +import asyncio +import socket +from unittest.mock import MagicMock, patch, AsyncMock + +from utcp.client.transport_interfaces.udp_transport import UDPTransport +from utcp.shared.provider import UDPProvider +from utcp.shared.tool import Tool, ToolInputOutputSchema + + +class MockUDPServer: + """Mock UDP server for testing.""" + + def __init__(self, host='localhost', port=0): + self.host = host + self.port = port + self.sock = None + self.running = False + self.responses = {} # Map message -> response + self.call_count = 0 + self.listen_task = None + + async def start(self): + """Start the mock UDP server.""" + # Create socket and bind + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # Keep it blocking since we're using run_in_executor + self.sock.bind((self.host, self.port)) + if self.port == 0: # Auto-assign port + self.port = self.sock.getsockname()[1] + + self.running = True + + # Start listening task + self.listen_task = asyncio.create_task(self._listen()) + + # Give the server a moment to start + await asyncio.sleep(0.1) + + async def stop(self): + """Stop the mock UDP server.""" + self.running = False + if self.listen_task: + self.listen_task.cancel() + try: + await self.listen_task + except asyncio.CancelledError: + pass + if self.sock: + self.sock.close() + + async def _listen(self): + """Listen for UDP messages and send responses.""" + # Use a blocking approach with short timeout for responsiveness + self.sock.settimeout(0.01) # Very short timeout + + while self.running: + try: + data, addr = self.sock.recvfrom(4096) + self.call_count += 1 + + try: + message = data.decode('utf-8') + except UnicodeDecodeError: + message = data.hex() # Fallback for binary data + + # Get response for this message + response = self.responses.get(message, '{"error": "unknown_message"}') + + # Convert response to bytes + if isinstance(response, str): + response_bytes = response.encode('utf-8') + elif isinstance(response, bytes): + response_bytes = response + elif isinstance(response, dict) or isinstance(response, list): + response_bytes = json.dumps(response).encode('utf-8') + else: + response_bytes = str(response).encode('utf-8') + + # Send response back immediately + self.sock.sendto(response_bytes, addr) + + except socket.timeout: + # Expected timeout, continue loop + await asyncio.sleep(0.001) # Brief async yield + continue + except asyncio.CancelledError: + break + except Exception as e: + if self.running: # Only log if we're still supposed to be running + import traceback + print(f"Mock UDP server error: {e}") + print(f"Traceback: {traceback.format_exc()}") + await asyncio.sleep(0.01) # Brief pause before retrying + + def set_response(self, message, response): + """Set a response for a specific message.""" + self.responses[message] = response + + +@pytest_asyncio.fixture +async def mock_udp_server(): + """Create a mock UDP server for testing.""" + server = MockUDPServer() + await server.start() + yield server + await server.stop() + + +@pytest.fixture +def logger(): + """Create a mock logger.""" + return MagicMock() + + +@pytest.fixture +def udp_transport(logger): + """Create a UDP transport instance.""" + return UDPTransport(logger=logger) + + +@pytest.fixture +def udp_provider(mock_udp_server): + """Create a basic UDP provider for testing.""" + return UDPProvider( + name="test_udp_provider", + host=mock_udp_server.host, + port=mock_udp_server.port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=5000 + ) + + +@pytest.fixture +def text_template_provider(mock_udp_server): + """Create a UDP provider with text template format.""" + return UDPProvider( + name="test_text_template_provider", + host=mock_udp_server.host, + port=mock_udp_server.port, + number_of_response_datagrams=1, + request_data_format="text", + request_data_template="COMMAND UTCP_ARG_action_UTCP_ARG UTCP_ARG_value_UTCP_ARG", + response_byte_format="utf-8", + timeout=5000 + ) + + +@pytest.fixture +def raw_bytes_provider(mock_udp_server): + """Create a UDP provider that returns raw bytes.""" + return UDPProvider( + name="test_raw_bytes_provider", + host=mock_udp_server.host, + port=mock_udp_server.port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format=None, # Return raw bytes + timeout=5000 + ) + + +@pytest.fixture +def multi_datagram_provider(mock_udp_server): + """Create a UDP provider that expects multiple response datagrams.""" + return UDPProvider( + name="test_multi_datagram_provider", + host=mock_udp_server.host, + port=mock_udp_server.port, + number_of_response_datagrams=3, + request_data_format="json", + response_byte_format="utf-8", + timeout=5000 + ) + + +# Test register_tool_provider +@pytest.mark.asyncio +async def test_register_tool_provider(udp_transport, udp_provider, mock_udp_server, logger): + """Test registering a tool provider.""" + # Set up discovery response + discovery_response = { + "tools": [ + { + "name": "test_tool", + "description": "Test tool", + "inputs": { + "type": "object", + "properties": { + "param1": {"type": "string"} + } + }, + "outputs": { + "type": "object", + "properties": { + "result": {"type": "string"} + } + }, + "tags": [], + "tool_provider": { + "provider_type": "udp", + "name": "test_udp_provider", + "host": "localhost", + "port": udp_provider.port + } + } + ] + } + + mock_udp_server.set_response('{"type": "utcp"}', discovery_response) + print(f"Mock UDP server port: {mock_udp_server.port}") + print(f"UDP provider port: {udp_provider.port}") + + # Register the provider + tools = await udp_transport.register_tool_provider(udp_provider) + + # Verify tools were returned + assert len(tools) == 1 + assert tools[0].name == "test_tool" + assert tools[0].description == "Test tool" + + # Verify logger was called + logger.assert_called() + + +@pytest.mark.asyncio +async def test_register_tool_provider_empty_response(udp_transport, udp_provider, mock_udp_server): + """Test registering a tool provider with empty response.""" + # Set up empty discovery response + mock_udp_server.set_response('{"type": "utcp"}', {"tools": []}) + + # Register the provider + tools = await udp_transport.register_tool_provider(udp_provider) + + # Verify no tools were returned + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_register_tool_provider_invalid_json(udp_transport, udp_provider, mock_udp_server): + """Test registering a tool provider with invalid JSON response.""" + # Set up invalid JSON response + mock_udp_server.set_response('{"type": "utcp"}', "invalid json") + + # Register the provider + tools = await udp_transport.register_tool_provider(udp_provider) + + # Verify no tools were returned due to JSON error + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_register_tool_provider_invalid_provider_type(udp_transport): + """Test registering a non-UDP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + http_provider = HttpProvider( + name="test_http_provider", + url="http://example.com" + ) + + with pytest.raises(ValueError, match="UDPTransport can only be used with UDPProvider"): + await udp_transport.register_tool_provider(http_provider) + + +# Test deregister_tool_provider +@pytest.mark.asyncio +async def test_deregister_tool_provider(udp_transport, udp_provider): + """Test deregistering a tool provider (should be a no-op).""" + # This should not raise any exceptions + await udp_transport.deregister_tool_provider(udp_provider) + + +@pytest.mark.asyncio +async def test_deregister_tool_provider_invalid_type(udp_transport): + """Test deregistering a non-UDP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + http_provider = HttpProvider( + name="test_http_provider", + url="http://example.com" + ) + + with pytest.raises(ValueError, match="UDPTransport can only be used with UDPProvider"): + await udp_transport.deregister_tool_provider(http_provider) + + +# Test call_tool with JSON format +@pytest.mark.asyncio +async def test_call_tool_json_format(udp_transport, udp_provider, mock_udp_server): + """Test calling a tool with JSON format.""" + # Set up tool call response + arguments = {"param1": "value1", "param2": 42} + expected_message = json.dumps(arguments) + response = {"result": "success", "data": "processed"} + + mock_udp_server.set_response(expected_message, response) + + # Call the tool + result = await udp_transport.call_tool("test_tool", arguments, udp_provider) + + # Verify response + assert result == json.dumps(response) + assert mock_udp_server.call_count >= 1 + + +@pytest.mark.asyncio +async def test_call_tool_text_template_format(udp_transport, text_template_provider, mock_udp_server): + """Test calling a tool with text template format.""" + # Set up tool call response + arguments = {"action": "get", "value": "data123"} + expected_message = "COMMAND get data123" # Template substitution + response = "SUCCESS: data123 retrieved" + + mock_udp_server.set_response(expected_message, response) + + # Call the tool + result = await udp_transport.call_tool("test_tool", arguments, text_template_provider) + + # Verify response + assert result == response + assert mock_udp_server.call_count >= 1 + + +@pytest.mark.asyncio +async def test_call_tool_text_format_no_template(udp_transport, mock_udp_server): + """Test calling a tool with text format but no template.""" + provider = UDPProvider( + name="test_provider", + host=mock_udp_server.host, + port=mock_udp_server.port, + request_data_format="text", + request_data_template=None, # No template + response_byte_format="utf-8", + number_of_response_datagrams=1 # Expect 1 response + ) + + # Set up tool call response + arguments = {"param1": "value1", "param2": "value2"} + expected_message = "value1 value2" # Fallback format + response = "OK" + + mock_udp_server.set_response(expected_message, response) + + # Call the tool + result = await udp_transport.call_tool("test_tool", arguments, provider) + + # Verify response + assert result == response + + +@pytest.mark.asyncio +async def test_call_tool_raw_bytes_response(udp_transport, raw_bytes_provider, mock_udp_server): + """Test calling a tool that returns raw bytes.""" + # Set up tool call response with raw bytes + arguments = {"param1": "value1"} + expected_message = json.dumps(arguments) + raw_response = b"\x01\x02\x03\x04binary_data" + + mock_udp_server.set_response(expected_message, raw_response) + + # Call the tool + result = await udp_transport.call_tool("test_tool", arguments, raw_bytes_provider) + + # Verify response is raw bytes + assert isinstance(result, bytes) + assert result == raw_response + + +@pytest.mark.asyncio +async def test_call_tool_invalid_provider_type(udp_transport): + """Test calling a tool with non-UDP provider raises ValueError.""" + from utcp.shared.provider import HttpProvider + + http_provider = HttpProvider( + name="test_http_provider", + url="http://example.com" + ) + + with pytest.raises(ValueError, match="UDPTransport can only be used with UDPProvider"): + await udp_transport.call_tool("test_tool", {"param": "value"}, http_provider) + + +# Test multi-datagram support +@pytest.mark.asyncio +async def test_call_tool_multiple_datagrams(udp_transport, multi_datagram_provider, mock_udp_server): + """Test calling a tool that expects multiple response datagrams.""" + # This test is complex because we need to simulate multiple UDP responses + # For now, let's test that the transport handles the configuration correctly + + # Mock the _send_udp_message method to simulate multiple datagram responses + with patch.object(udp_transport, '_send_udp_message') as mock_send: + mock_send.return_value = "part1part2part3" # Concatenated response + + arguments = {"param1": "value1"} + result = await udp_transport.call_tool("test_tool", arguments, multi_datagram_provider) + + # Verify the method was called with correct parameters + mock_send.assert_called_once_with( + multi_datagram_provider.host, + multi_datagram_provider.port, + json.dumps(arguments), + multi_datagram_provider.timeout / 1000.0, + 3, # number_of_response_datagrams + "utf-8" # response_byte_format + ) + + assert result == "part1part2part3" + + +# Test _send_udp_message method directly +@pytest.mark.asyncio +async def test_send_udp_message_single_datagram(udp_transport, mock_udp_server): + """Test sending a UDP message and receiving a single response.""" + # Set up response + message = "test message" + response = "test response" + mock_udp_server.set_response(message, response) + + # Send message + result = await udp_transport._send_udp_message( + mock_udp_server.host, + mock_udp_server.port, + message, + timeout=5.0, + num_response_datagrams=1, + response_encoding="utf-8" + ) + + # Verify response + assert result == response + + +@pytest.mark.asyncio +async def test_send_udp_message_raw_bytes(udp_transport, mock_udp_server): + """Test sending a UDP message and receiving raw bytes.""" + # Set up binary response + message = "test message" + response = b"\x01\x02\x03binary" + mock_udp_server.set_response(message, response) + + # Send message with no encoding (raw bytes) + result = await udp_transport._send_udp_message( + mock_udp_server.host, + mock_udp_server.port, + message, + timeout=5.0, + num_response_datagrams=1, + response_encoding=None + ) + + # Verify response is bytes + assert isinstance(result, bytes) + assert result == response + + +@pytest.mark.asyncio +async def test_send_udp_message_timeout(): + """Test UDP message timeout handling.""" + udp_transport = UDPTransport() + + # Try to send to a non-existent server (should timeout) + with pytest.raises(Exception): # Should raise socket timeout or connection error + await udp_transport._send_udp_message( + "127.0.0.1", + 99999, # Non-existent port + "test message", + timeout=0.1, # Very short timeout + num_response_datagrams=1, + response_encoding="utf-8" + ) + + +# Test _format_tool_call_message method +def test_format_tool_call_message_json(udp_transport): + """Test formatting tool call message with JSON format.""" + provider = UDPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="json" + ) + + arguments = {"param1": "value1", "param2": 42} + result = udp_transport._format_tool_call_message(arguments, provider) + + # Should return JSON string + assert result == json.dumps(arguments) + + # Verify it's valid JSON + parsed = json.loads(result) + assert parsed == arguments + + +def test_format_tool_call_message_text_with_template(udp_transport): + """Test formatting tool call message with text template.""" + provider = UDPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template="ACTION UTCP_ARG_cmd_UTCP_ARG PARAM UTCP_ARG_value_UTCP_ARG" + ) + + arguments = {"cmd": "get", "value": "data123"} + result = udp_transport._format_tool_call_message(arguments, provider) + + # Should substitute placeholders + assert result == "ACTION get PARAM data123" + + +def test_format_tool_call_message_text_with_complex_values(udp_transport): + """Test formatting tool call message with complex values in template.""" + provider = UDPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template="DATA UTCP_ARG_obj_UTCP_ARG" + ) + + arguments = {"obj": {"nested": "value", "number": 123}} + result = udp_transport._format_tool_call_message(arguments, provider) + + # Should JSON-serialize complex values + assert result == 'DATA {"nested": "value", "number": 123}' + + +def test_format_tool_call_message_text_no_template(udp_transport): + """Test formatting tool call message with text format but no template.""" + provider = UDPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="text", + request_data_template=None + ) + + arguments = {"param1": "value1", "param2": "value2"} + result = udp_transport._format_tool_call_message(arguments, provider) + + # Should use fallback format (space-separated values) + assert result == "value1 value2" + + +def test_format_tool_call_message_default_to_json(udp_transport): + """Test formatting tool call message defaults to JSON for unknown format.""" + # Create a provider with valid format first + provider = UDPProvider( + name="test", + host="localhost", + port=1234, + request_data_format="json" + ) + + # Manually set an invalid format to test the fallback behavior + provider.request_data_format = "unknown" # Invalid format + + arguments = {"param1": "value1"} + result = udp_transport._format_tool_call_message(arguments, provider) + + # Should default to JSON + assert result == json.dumps(arguments) + + +# Test error handling and edge cases +@pytest.mark.asyncio +async def test_call_tool_server_error(udp_transport, udp_provider, mock_udp_server): + """Test handling server errors during tool calls.""" + # Don't set any response, so the server will return an error + arguments = {"param1": "value1"} + + # Call the tool - should get the default error response + result = await udp_transport.call_tool("test_tool", arguments, udp_provider) + + # Should receive the default error message + assert '{"error": "unknown_message"}' in result + + +@pytest.mark.asyncio +async def test_register_tool_provider_malformed_tool(udp_transport, udp_provider, mock_udp_server): + """Test registering provider with malformed tool definition.""" + # Set up discovery response with invalid tool + discovery_response = { + "tools": [ + { + "name": "test_tool", + # Missing required fields like inputs, outputs, tool_provider + } + ] + } + + mock_udp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register the provider - should handle invalid tool gracefully + tools = await udp_transport.register_tool_provider(udp_provider) + + # Should return empty list due to invalid tool definition + assert len(tools) == 0 + + +# Test logging functionality +@pytest.mark.asyncio +async def test_logging_calls(udp_transport, udp_provider, mock_udp_server, logger): + """Test that logging functions are called appropriately.""" + # Set up discovery response + discovery_response = {"tools": []} + mock_udp_server.set_response('{"type": "utcp"}', discovery_response) + + # Register provider + await udp_transport.register_tool_provider(udp_provider) + + # Verify logger was called + logger.assert_called() + + # Call tool + mock_udp_server.set_response('{}', {"result": "test"}) + await udp_transport.call_tool("test_tool", {}, udp_provider) + + # Logger should have been called multiple times + assert logger.call_count > 1