Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/mcp/server/auth/handlers/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, Literal

import httpx
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request
Expand All @@ -16,7 +17,11 @@
OAuthAuthorizationServerProvider,
construct_redirect_uri,
)
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
from mcp.shared.auth import (
InvalidRedirectUriError,
InvalidScopeError,
OAuthClientInformationFull,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -166,6 +171,29 @@ async def error_response(
client = await self.provider.get_client(
auth_request.client_id,
)
if not client:
# Check if `client_id` is a valid URL for Metadata Document
if auth_request.client_id.startswith("https://"):
try:
async with httpx.AsyncClient() as http_client:
response = await http_client.get(auth_request.client_id)
response.raise_for_status()
metadata = response.json()

if metadata.get("client_id") != auth_request.client_id:
return await error_response(
error="invalid_request",
error_description=f"Client ID '{auth_request.client_id}' \
doesn't match with metadata document",
)

client = OAuthClientInformationFull(**metadata)

except Exception as e:
return await error_response(
error="invalid_request",
error_description=f"Failed to fetch client metadata from {auth_request.client_id}: {e}",
)
if not client:
# For client_id validation errors, return direct error (no redirect)
return await error_response(
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def build_metadata(
op_tos_uri=None,
introspection_endpoint=None,
code_challenge_methods_supported=["S256"],
client_id_metadata_document_supported=True,
)

# Add registration endpoint if supported
Expand Down
1 change: 1 addition & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,7 @@ def test_build_metadata(
"revocation_endpoint": Is(revocation_endpoint),
"revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"code_challenge_methods_supported": ["S256"],
"client_id_metadata_document_supported": True,
}
)

Expand Down
158 changes: 158 additions & 0 deletions tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
"authorization_code",
"refresh_token",
]
assert metadata["client_id_metadata_document_supported"]
assert metadata["service_documentation"] == "https://docs.example.com/"

@pytest.mark.anyio
Expand Down Expand Up @@ -1355,6 +1356,163 @@ async def test_none_auth_method_public_client(
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_cimd_authorization_flow(
self,
test_client: httpx.AsyncClient,
mock_oauth_provider: MockOAuthProvider,
pkce_challenge: dict[str, str],
):
"""Test Authorization using Client Identity Metadata (CIMD) flow."""
client_id_url = "https://example.com/client-metadata"

client_metadata = {
"client_id": client_id_url,
"client_name": "CIMD Test Client",
"redirect_uris": ["https://client.example.com/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
}

# Mocking httpx.AsyncClient to intercept the metadata fetch
mock_response = unittest.mock.Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = client_metadata
mock_response.raise_for_status = unittest.mock.Mock()

mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
mock_client_instance.get.return_value = mock_response

# Setup context manager return
mock_client_instance.__aenter__.return_value = mock_client_instance
mock_client_instance.__aexit__.return_value = None

with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
# 3. Request authorization using the CIMD client_id
response = await test_client.get(
"/authorize",
params={
"response_type": "code",
"client_id": client_id_url, # Using URL as client_id
"redirect_uri": "https://client.example.com/callback",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "cimd_test_state",
},
)

assert response.status_code == 302

redirect_url = response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)

assert "code" in query_params
assert query_params["state"][0] == "cimd_test_state"

mock_client_instance.get.assert_called_with(client_id_url)

@pytest.mark.anyio
async def test_cimd_authorization_invalid_cimd_url(
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
):
"""Test authorization endpoint with invalid CIMD url."""
response = await test_client.get(
"/authorize",
params={
"response_type": "code",
"client_id": "http://example.com/client-metadata", # Invalid CIMD url
"redirect_uri": "https://client.example.com/callback",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "cimd_test_state",
},
)

assert response.status_code == 400
assert "client id" in response.text.lower()

@pytest.mark.anyio
async def test_cimd_authorization_invalid_client_id(
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
):
"""Test authorization endpoint with invalid client_id."""
client_id_url = "https://example.com/client-metadata"

client_metadata = {
"client_id": "https://invalid.com/client-metadata", # Invalid client id,
"client_name": "CIMD Test Client",
"redirect_uris": ["https://client.example.com/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
}

# Mocking httpx.AsyncClient to intercept the metadata fetch
mock_response = unittest.mock.Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = client_metadata
mock_response.raise_for_status = unittest.mock.Mock()

mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
mock_client_instance.get.return_value = mock_response

# Setup context manager return
mock_client_instance.__aenter__.return_value = mock_client_instance
mock_client_instance.__aexit__.return_value = None

with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
# 3. Request authorization using the CIMD client_id
response = await test_client.get(
"/authorize",
params={
"response_type": "code",
"client_id": client_id_url, # Using URL as client_id
"redirect_uri": "https://client.example.com/callback",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "cimd_test_state",
},
)

assert response.status_code == 400
assert "client id" in response.text.lower()

@pytest.mark.anyio
async def test_cimd_authorization_metadata_fetch_error(
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
):
"""Test authorization endpoint when fetching client metadata fails."""
client_id_url = "https://example.com/client-metadata"

# Mocking httpx.AsyncClient to raise an exception
mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
mock_client_instance.get.side_effect = httpx.RequestError("Network error")

# Setup context manager return
mock_client_instance.__aenter__.return_value = mock_client_instance
mock_client_instance.__aexit__.return_value = None

with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
response = await test_client.get(
"/authorize",
params={
"response_type": "code",
"client_id": client_id_url,
"redirect_uri": "https://client.example.com/callback",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "cimd_test_state",
},
)

# verify that we get a 400 error (because we can't fetch metadata to verify,
# and we can't redirect because we don't trust the client yet or don't know its redirect URIs)
assert response.status_code == 400
assert "invalid_request" in response.text
assert "Failed to fetch client metadata" in response.text


class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""
Expand Down