Skip to content

Commit fa13db7

Browse files
committed
feat(cimd): Add support for CIMD in server
1 parent a9cc822 commit fa13db7

File tree

4 files changed

+189
-1
lines changed

4 files changed

+189
-1
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33
from typing import Any, Literal
44

5+
import httpx
56
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
67
from starlette.datastructures import FormData, QueryParams
78
from starlette.requests import Request
@@ -16,7 +17,11 @@
1617
OAuthAuthorizationServerProvider,
1718
construct_redirect_uri,
1819
)
19-
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
20+
from mcp.shared.auth import (
21+
InvalidRedirectUriError,
22+
InvalidScopeError,
23+
OAuthClientInformationFull,
24+
)
2025

2126
logger = logging.getLogger(__name__)
2227

@@ -166,6 +171,29 @@ async def error_response(
166171
client = await self.provider.get_client(
167172
auth_request.client_id,
168173
)
174+
if not client:
175+
# Check if `client_id` is a valid URL for Metadata Document
176+
if auth_request.client_id.startswith("https://"):
177+
try:
178+
async with httpx.AsyncClient() as http_client:
179+
response = await http_client.get(auth_request.client_id)
180+
response.raise_for_status()
181+
metadata = response.json()
182+
183+
if metadata.get("client_id") != auth_request.client_id:
184+
return await error_response(
185+
error="invalid_request",
186+
error_description=f"Client ID '{auth_request.client_id}' \
187+
doesn't match with metadata document",
188+
)
189+
190+
client = OAuthClientInformationFull(**metadata)
191+
192+
except Exception as e:
193+
return await error_response(
194+
error="invalid_request",
195+
error_description=f"Failed to fetch client metadata from {auth_request.client_id}: {e}",
196+
)
169197
if not client:
170198
# For client_id validation errors, return direct error (no redirect)
171199
return await error_response(

src/mcp/server/auth/routes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def build_metadata(
173173
op_tos_uri=None,
174174
introspection_endpoint=None,
175175
code_challenge_methods_supported=["S256"],
176+
client_id_metadata_document_supported=True,
176177
)
177178

178179
# Add registration endpoint if supported

tests/client/test_auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,7 @@ def test_build_metadata(
13591359
"revocation_endpoint": Is(revocation_endpoint),
13601360
"revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
13611361
"code_challenge_methods_supported": ["S256"],
1362+
"client_id_metadata_document_supported": True,
13621363
}
13631364
)
13641365

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
325325
"authorization_code",
326326
"refresh_token",
327327
]
328+
assert metadata["client_id_metadata_document_supported"]
328329
assert metadata["service_documentation"] == "https://docs.example.com/"
329330

330331
@pytest.mark.anyio
@@ -1355,6 +1356,163 @@ async def test_none_auth_method_public_client(
13551356
token_response = response.json()
13561357
assert "access_token" in token_response
13571358

1359+
@pytest.mark.anyio
1360+
async def test_cimd_authorization_flow(
1361+
self,
1362+
test_client: httpx.AsyncClient,
1363+
mock_oauth_provider: MockOAuthProvider,
1364+
pkce_challenge: dict[str, str],
1365+
):
1366+
"""Test Authorization using Client Identity Metadata (CIMD) flow."""
1367+
client_id_url = "https://example.com/client-metadata"
1368+
1369+
client_metadata = {
1370+
"client_id": client_id_url,
1371+
"client_name": "CIMD Test Client",
1372+
"redirect_uris": ["https://client.example.com/callback"],
1373+
"grant_types": ["authorization_code", "refresh_token"],
1374+
"response_types": ["code"],
1375+
"token_endpoint_auth_method": "none",
1376+
}
1377+
1378+
# Mocking httpx.AsyncClient to intercept the metadata fetch
1379+
mock_response = unittest.mock.Mock(spec=httpx.Response)
1380+
mock_response.status_code = 200
1381+
mock_response.json.return_value = client_metadata
1382+
mock_response.raise_for_status = unittest.mock.Mock()
1383+
1384+
mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
1385+
mock_client_instance.get.return_value = mock_response
1386+
1387+
# Setup context manager return
1388+
mock_client_instance.__aenter__.return_value = mock_client_instance
1389+
mock_client_instance.__aexit__.return_value = None
1390+
1391+
with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
1392+
# 3. Request authorization using the CIMD client_id
1393+
response = await test_client.get(
1394+
"/authorize",
1395+
params={
1396+
"response_type": "code",
1397+
"client_id": client_id_url, # Using URL as client_id
1398+
"redirect_uri": "https://client.example.com/callback",
1399+
"code_challenge": pkce_challenge["code_challenge"],
1400+
"code_challenge_method": "S256",
1401+
"state": "cimd_test_state",
1402+
},
1403+
)
1404+
1405+
assert response.status_code == 302
1406+
1407+
redirect_url = response.headers["location"]
1408+
parsed_url = urlparse(redirect_url)
1409+
query_params = parse_qs(parsed_url.query)
1410+
1411+
assert "code" in query_params
1412+
assert query_params["state"][0] == "cimd_test_state"
1413+
1414+
mock_client_instance.get.assert_called_with(client_id_url)
1415+
1416+
@pytest.mark.anyio
1417+
async def test_cimd_authorization_invalid_cimd_url(
1418+
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
1419+
):
1420+
"""Test authorization endpoint with invalid CIMD url."""
1421+
response = await test_client.get(
1422+
"/authorize",
1423+
params={
1424+
"response_type": "code",
1425+
"client_id": "http://example.com/client-metadata", # Invalid CIMD url
1426+
"redirect_uri": "https://client.example.com/callback",
1427+
"code_challenge": pkce_challenge["code_challenge"],
1428+
"code_challenge_method": "S256",
1429+
"state": "cimd_test_state",
1430+
},
1431+
)
1432+
1433+
assert response.status_code == 400
1434+
assert "client id" in response.text.lower()
1435+
1436+
@pytest.mark.anyio
1437+
async def test_cimd_authorization_invalid_client_id(
1438+
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
1439+
):
1440+
"""Test authorization endpoint with invalid client_id."""
1441+
client_id_url = "https://example.com/client-metadata"
1442+
1443+
client_metadata = {
1444+
"client_id": "https://invalid.com/client-metadata", # Invalid client id,
1445+
"client_name": "CIMD Test Client",
1446+
"redirect_uris": ["https://client.example.com/callback"],
1447+
"grant_types": ["authorization_code", "refresh_token"],
1448+
"response_types": ["code"],
1449+
"token_endpoint_auth_method": "none",
1450+
}
1451+
1452+
# Mocking httpx.AsyncClient to intercept the metadata fetch
1453+
mock_response = unittest.mock.Mock(spec=httpx.Response)
1454+
mock_response.status_code = 200
1455+
mock_response.json.return_value = client_metadata
1456+
mock_response.raise_for_status = unittest.mock.Mock()
1457+
1458+
mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
1459+
mock_client_instance.get.return_value = mock_response
1460+
1461+
# Setup context manager return
1462+
mock_client_instance.__aenter__.return_value = mock_client_instance
1463+
mock_client_instance.__aexit__.return_value = None
1464+
1465+
with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
1466+
# 3. Request authorization using the CIMD client_id
1467+
response = await test_client.get(
1468+
"/authorize",
1469+
params={
1470+
"response_type": "code",
1471+
"client_id": client_id_url, # Using URL as client_id
1472+
"redirect_uri": "https://client.example.com/callback",
1473+
"code_challenge": pkce_challenge["code_challenge"],
1474+
"code_challenge_method": "S256",
1475+
"state": "cimd_test_state",
1476+
},
1477+
)
1478+
1479+
assert response.status_code == 400
1480+
assert "client id" in response.text.lower()
1481+
1482+
@pytest.mark.anyio
1483+
async def test_cimd_authorization_metadata_fetch_error(
1484+
self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]
1485+
):
1486+
"""Test authorization endpoint when fetching client metadata fails."""
1487+
client_id_url = "https://example.com/client-metadata"
1488+
1489+
# Mocking httpx.AsyncClient to raise an exception
1490+
mock_client_instance = unittest.mock.AsyncMock(spec=httpx.AsyncClient)
1491+
mock_client_instance.get.side_effect = httpx.RequestError("Network error")
1492+
1493+
# Setup context manager return
1494+
mock_client_instance.__aenter__.return_value = mock_client_instance
1495+
mock_client_instance.__aexit__.return_value = None
1496+
1497+
with unittest.mock.patch("httpx.AsyncClient", return_value=mock_client_instance):
1498+
response = await test_client.get(
1499+
"/authorize",
1500+
params={
1501+
"response_type": "code",
1502+
"client_id": client_id_url,
1503+
"redirect_uri": "https://client.example.com/callback",
1504+
"code_challenge": pkce_challenge["code_challenge"],
1505+
"code_challenge_method": "S256",
1506+
"state": "cimd_test_state",
1507+
},
1508+
)
1509+
1510+
# verify that we get a 400 error (because we can't fetch metadata to verify,
1511+
# and we can't redirect because we don't trust the client yet or don't know its redirect URIs)
1512+
assert response.status_code == 400
1513+
assert "invalid_request" in response.text
1514+
assert "Failed to fetch client metadata" in response.text
1515+
13581516

13591517
class TestAuthorizeEndpointErrors:
13601518
"""Test error handling in the OAuth authorization endpoint."""

0 commit comments

Comments
 (0)