diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 3570d28c2..5b7c3e8b9 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Literal +import httpx from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.datastructures import FormData, QueryParams from starlette.requests import Request @@ -16,7 +17,11 @@ OAuthAuthorizationServerProvider, construct_redirect_uri, ) -from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError +from mcp.shared.auth import ( + InvalidRedirectUriError, + InvalidScopeError, + OAuthClientInformationFull, +) logger = logging.getLogger(__name__) @@ -166,6 +171,29 @@ async def error_response( client = await self.provider.get_client( auth_request.client_id, ) + if not client: + # Check if `client_id` is a valid URL for Metadata Document + if auth_request.client_id.startswith("https://"): + try: + async with httpx.AsyncClient() as http_client: + response = await http_client.get(auth_request.client_id) + response.raise_for_status() + metadata = response.json() + + if metadata.get("client_id") != auth_request.client_id: + return await error_response( + error="invalid_request", + error_description=f"Client ID '{auth_request.client_id}' \ + doesn't match with metadata document", + ) + + client = OAuthClientInformationFull(**metadata) + + except Exception as e: + return await error_response( + error="invalid_request", + error_description=f"Failed to fetch client metadata from {auth_request.client_id}: {e}", + ) if not client: # For client_id validation errors, return direct error (no redirect) return await error_response( diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 71a9c8b16..0a2bbc313 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -173,6 +173,7 @@ def build_metadata( op_tos_uri=None, introspection_endpoint=None, code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, ) # Add registration endpoint if supported diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 6025ff811..4cd7821d6 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1359,6 +1359,7 @@ def test_build_metadata( "revocation_endpoint": Is(revocation_endpoint), "revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], "code_challenge_methods_supported": ["S256"], + "client_id_metadata_document_supported": True, } ) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 7342013a8..e92f767d6 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -325,6 +325,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): "authorization_code", "refresh_token", ] + assert metadata["client_id_metadata_document_supported"] assert metadata["service_documentation"] == "https://docs.example.com/" @pytest.mark.anyio @@ -1355,6 +1356,163 @@ async def test_none_auth_method_public_client( token_response = response.json() assert "access_token" in token_response + @pytest.mark.anyio + async def test_cimd_authorization_flow( + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge: dict[str, str], + ): + """Test Authorization using Client Identity Metadata (CIMD) flow.""" + client_id_url = "https://example.com/client-metadata" + + client_metadata = { + "client_id": client_id_url, + "client_name": "CIMD Test Client", + "redirect_uris": ["https://client.example.com/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + } + + # Mocking httpx.AsyncClient to intercept the metadata fetch + mock_response = unittest.mock.Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = client_metadata + mock_response.raise_for_status = unittest.mock.Mock() + + mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient) + mock_client_instance.get.return_value = mock_response + + # Setup context manager return + mock_client_instance.__aenter__.return_value = mock_client_instance + mock_client_instance.__aexit__.return_value = None + + with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance): + # 3. Request authorization using the CIMD client_id + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_id_url, # Using URL as client_id + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "cimd_test_state", + }, + ) + + assert response.status_code == 302 + + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + assert query_params["state"][0] == "cimd_test_state" + + mock_client_instance.get.assert_called_with(client_id_url) + + @pytest.mark.anyio + async def test_cimd_authorization_invalid_cimd_url( + self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] + ): + """Test authorization endpoint with invalid CIMD url.""" + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": "http://example.com/client-metadata", # Invalid CIMD url + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "cimd_test_state", + }, + ) + + assert response.status_code == 400 + assert "client id" in response.text.lower() + + @pytest.mark.anyio + async def test_cimd_authorization_invalid_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] + ): + """Test authorization endpoint with invalid client_id.""" + client_id_url = "https://example.com/client-metadata" + + client_metadata = { + "client_id": "https://invalid.com/client-metadata", # Invalid client id, + "client_name": "CIMD Test Client", + "redirect_uris": ["https://client.example.com/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + } + + # Mocking httpx.AsyncClient to intercept the metadata fetch + mock_response = unittest.mock.Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = client_metadata + mock_response.raise_for_status = unittest.mock.Mock() + + mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient) + mock_client_instance.get.return_value = mock_response + + # Setup context manager return + mock_client_instance.__aenter__.return_value = mock_client_instance + mock_client_instance.__aexit__.return_value = None + + with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance): + # 3. Request authorization using the CIMD client_id + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_id_url, # Using URL as client_id + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "cimd_test_state", + }, + ) + + assert response.status_code == 400 + assert "client id" in response.text.lower() + + @pytest.mark.anyio + async def test_cimd_authorization_metadata_fetch_error( + self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str] + ): + """Test authorization endpoint when fetching client metadata fails.""" + client_id_url = "https://example.com/client-metadata" + + # Mocking httpx.AsyncClient to raise an exception + mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient) + mock_client_instance.get.side_effect = httpx.RequestError("Network error") + + # Setup context manager return + mock_client_instance.__aenter__.return_value = mock_client_instance + mock_client_instance.__aexit__.return_value = None + + with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance): + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_id_url, + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "cimd_test_state", + }, + ) + + # verify that we get a 400 error (because we can't fetch metadata to verify, + # and we can't redirect because we don't trust the client yet or don't know its redirect URIs) + assert response.status_code == 400 + assert "invalid_request" in response.text + assert "Failed to fetch client metadata" in response.text + class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint."""