From 425581e4cc185471f073e8b8826d8061c1866043 Mon Sep 17 00:00:00 2001 From: Y-Rookie Date: Thu, 23 Oct 2025 17:33:35 +0800 Subject: [PATCH 1/3] redis mcp server supports more APIs --- server/mcp_server_redis/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/mcp_server_redis/pyproject.toml b/server/mcp_server_redis/pyproject.toml index e5c7eccb..9f8ef569 100644 --- a/server/mcp_server_redis/pyproject.toml +++ b/server/mcp_server_redis/pyproject.toml @@ -14,4 +14,4 @@ mcp-server-redis = "mcp_server_redis.server:main" [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" From fa161e9faabf721ee5b421634b024b64580fa946 Mon Sep 17 00:00:00 2001 From: Y-Rookie Date: Mon, 1 Jun 2026 16:55:23 +0800 Subject: [PATCH 2/3] redis mcp server supports sts token credentials --- server/mcp_server_redis/README.md | 97 +++++- server/mcp_server_redis/README_zh.md | 103 ++++++- .../resource/redis_resource.py | 7 +- .../mcp_server_redis/resource/vpc_resource.py | 7 +- .../src/mcp_server_redis/server.py | 194 +++++++++++- .../mcp_server_redis/tests/test_sts_auth.py | 166 ++++++++++ .../mcp_server_redis/tests/verify_sts_flow.py | 283 ++++++++++++++++++ server/mcp_server_redis/uv.lock | 11 +- 8 files changed, 844 insertions(+), 24 deletions(-) create mode 100644 server/mcp_server_redis/tests/test_sts_auth.py create mode 100644 server/mcp_server_redis/tests/verify_sts_flow.py diff --git a/server/mcp_server_redis/README.md b/server/mcp_server_redis/README.md index 038bebb8..a873ae40 100644 --- a/server/mcp_server_redis/README.md +++ b/server/mcp_server_redis/README.md @@ -182,13 +182,59 @@ --- ## Authentication Method -Obtain the access key ID, secret access key, and region from the Volcengine Management Console, and use API Key authentication. -You need to set `VOLCENGINE_ACCESS_KEY` and `VOLCENGINE_SECRET_KEY` in the configuration file. +Redis MCP now supports the following Volcengine credential modes: + +### 1. Static AK/SK + +Obtain the access key ID, secret access key, and region from the Volcengine Management Console, then set: + +- `VOLCENGINE_REGION` +- `VOLCENGINE_ACCESS_KEY` +- `VOLCENGINE_SECRET_KEY` + +### 2. AK/SK + SessionToken + +If you are using temporary credentials, additionally set this environment variable: + +- `VOLCENGINE_SESSION_TOKEN` + +This mode is suitable for local `stdio` runs or any client that injects credentials through environment variables. + +### 3. STS temporary credentials via `Authorization` header + +For `streamable-http` calls, Redis MCP also supports passing temporary credentials in the request header: + +```http +Authorization: Bearer +``` + +The decoded JSON payload should contain: + +```json +{ + "AccessKeyId": "", + "SecretAccessKey": "", + "SessionToken": "", + "CurrentTime": "2026-05-28T10:00:00+08:00", + "ExpiredTime": "2026-05-28T11:00:00+08:00", + "Region": "cn-beijing" +} +``` + +Notes: + +- `SessionToken` is required when using STS credentials. +- `Region` can be provided either in the payload or through `VOLCENGINE_REGION` / request parameters. +- If both header credentials and environment credentials are provided, request header credentials take precedence. +- If `CurrentTime` and `ExpiredTime` are present, the server will validate whether the STS token is expired. --- ## Deployment Volcengine Redis service access address: https://www.volcengine.com/docs/6293/65743 + +### Example 1: static AK/SK (stdio) + ```json { "mcpServers": { @@ -208,7 +254,52 @@ Volcengine Redis service access address: https://www.volcengine.com/docs/6293/65 } } ``` + +### Example 2: temporary credentials through environment variables (stdio) + +```json +{ + "mcpServers": { + "redis": { + "command": "uvx", + "args": [ + "--from", + "git+https://github.com/volcengine/mcp-server.git#subdirectory=server/mcp_server_redis", + "mcp-server-redis" + ], + "env": { + "VOLCENGINE_REGION": "cn-beijing", + "VOLCENGINE_ACCESS_KEY": "", + "VOLCENGINE_SECRET_KEY": "", + "VOLCENGINE_SESSION_TOKEN": "" + } + } + } +} +``` + +### Example 3: STS credentials through `Authorization` header (streamable-http) + +If your MCP client calls Redis MCP through HTTP, you can pass a Bearer token whose content is the Base64-encoded JSON shown above. The server will extract: + +- `AccessKeyId` +- `SecretAccessKey` +- `SessionToken` +- optional `Region` + +and use them to initialize the underlying Redis and VPC SDK clients dynamically for the current request. + +## Verification + +An end-to-end verification script is provided at `server/mcp_server_redis/tests/verify_sts_flow.py`. + +Run it with: + +```bash +uv run --project server/mcp_server_redis python server/mcp_server_redis/tests/verify_sts_flow.py --region cn-beijing +``` + Currently, the supported regions: ["cn-beijing", "cn-guangzhou", "cn-shanghai", "cn-hongkong", "ap-southeast-1", "ap-southeast-3"] ## License -volcengine/mcp-server is licensed under the [MIT License](https://github.com/volcengine/mcp-server/blob/main/LICENSE). \ No newline at end of file +volcengine/mcp-server is licensed under the [MIT License](https://github.com/volcengine/mcp-server/blob/main/LICENSE). diff --git a/server/mcp_server_redis/README_zh.md b/server/mcp_server_redis/README_zh.md index c10d63e3..74f9960c 100644 --- a/server/mcp_server_redis/README_zh.md +++ b/server/mcp_server_redis/README_zh.md @@ -183,13 +183,59 @@ --- ## 鉴权方式 -在火山引擎管理控制台获取访问密钥 ID、秘密访问密钥和区域,采用 API Key 鉴权。 -需要在配置文件中设置 `VOLCENGINE_ACCESS_KEY` 和 `VOLCENGINE_SECRET_KEY`。 +Redis MCP 现已支持以下几种火山引擎凭证方式: + +### 1. 静态 AK/SK + +在火山引擎管理控制台获取访问密钥 ID、秘密访问密钥和区域后,配置: + +- `VOLCENGINE_REGION` +- `VOLCENGINE_ACCESS_KEY` +- `VOLCENGINE_SECRET_KEY` + +### 2. AK/SK + SessionToken + +如果你使用的是临时凭证,还需要额外设置以下环境变量: + +- `VOLCENGINE_SESSION_TOKEN` + +该方式适用于本地 `stdio` 模式,或通过环境变量注入临时凭证的场景。 + +### 3. 通过 `Authorization` Header 传递 STS 临时凭证 + +对于 `streamable-http` 调用方式,Redis MCP 支持通过请求头传递临时凭证: + +```http +Authorization: Bearer +``` + +解码后的 JSON 内容应包含: + +```json +{ + "AccessKeyId": "", + "SecretAccessKey": "", + "SessionToken": "", + "CurrentTime": "2026-05-28T10:00:00+08:00", + "ExpiredTime": "2026-05-28T11:00:00+08:00", + "Region": "cn-beijing" +} +``` + +说明: + +- 使用 STS 时需要提供 `SessionToken`。 +- `Region` 可以放在 Header 对应的 JSON 中,也可以继续通过 `VOLCENGINE_REGION` 或请求参数传入。 +- 如果同时提供 Header 凭证和环境变量凭证,请求头中的凭证优先级更高。 +- 如果 JSON 中带有 `CurrentTime` 和 `ExpiredTime`,服务端会校验 STS 是否已过期。 --- ## 部署 火山引擎Redis 服务接入地址:https://www.volcengine.com/docs/6293/65743 + +### 示例 1:静态 AK/SK(stdio) + ```json { "mcpServers": { @@ -209,10 +255,59 @@ } } ``` + +### 示例 2:通过环境变量传递临时凭证(stdio) + +```json +{ + "mcpServers": { + "redis": { + "command": "uvx", + "args": [ + "--from", + "git+https://github.com/volcengine/mcp-server.git#subdirectory=server/mcp_server_redis", + "mcp-server-redis" + ], + "env": { + "VOLCENGINE_REGION": "cn-beijing", + "VOLCENGINE_ACCESS_KEY": "", + "VOLCENGINE_SECRET_KEY": "", + "VOLCENGINE_SESSION_TOKEN": "" + } + } + } +} +``` + +### 示例 3:通过 `Authorization` Header 传递 STS(streamable-http) + +如果你的 MCP Client 是通过 HTTP 调用 Redis MCP,可以把上面的 JSON 先做 Base64 编码,再按以下格式放入请求头: + +```http +Authorization: Bearer +``` + +服务端会在当前请求内动态提取并使用: + +- `AccessKeyId` +- `SecretAccessKey` +- `SessionToken` +- 可选的 `Region` + +然后基于这些临时凭证初始化底层 Redis 与 VPC SDK Client。 + +## 验证方式 + +仓库中提供了端到端验证脚本:`server/mcp_server_redis/tests/verify_sts_flow.py`。 + +运行方式: + +```bash +uv run --project server/mcp_server_redis python server/mcp_server_redis/tests/verify_sts_flow.py --region cn-beijing +``` + 当前支持的Region: ["cn-beijing", "cn-guangzhou", "cn-shanghai", "cn-hongkong", "ap-southeast-1", "ap-southeast-3"] ## License volcengine/mcp-server is licensed under the [MIT License](https://github.com/volcengine/mcp-server/blob/main/LICENSE). - - diff --git a/server/mcp_server_redis/src/mcp_server_redis/resource/redis_resource.py b/server/mcp_server_redis/src/mcp_server_redis/resource/redis_resource.py index 31035755..a24db014 100644 --- a/server/mcp_server_redis/src/mcp_server_redis/resource/redis_resource.py +++ b/server/mcp_server_redis/src/mcp_server_redis/resource/redis_resource.py @@ -42,10 +42,13 @@ class RedisSDK: """初始化 Volcano Redis SDK Client""" - def __init__(self, region: str = None, ak: str = None, sk: str = None, host: str = None): + def __init__(self, region: str = None, ak: str = None, sk: str = None, host: str = None, + session_token: str = None): configuration = volcenginesdkcore.Configuration() configuration.ak = ak configuration.sk = sk + if session_token: + configuration.session_token = session_token configuration.region = region if region not in redis_supported_regions: raise Exception(f"Redis is not supported in region {region}.") @@ -164,4 +167,4 @@ def describe_planned_events(self, args: dict) -> DescribePlannedEventsResponse: return self.client.describe_planned_events(DescribePlannedEventsRequest(**args)) def describe_key_scan_jobs(self, args: dict) -> DescribeKeyScanJobsResponse: - return self.client.describe_key_scan_jobs(DescribeKeyScanJobsRequest(**args)) \ No newline at end of file + return self.client.describe_key_scan_jobs(DescribeKeyScanJobsRequest(**args)) diff --git a/server/mcp_server_redis/src/mcp_server_redis/resource/vpc_resource.py b/server/mcp_server_redis/src/mcp_server_redis/resource/vpc_resource.py index 7c07405c..16d43fe5 100644 --- a/server/mcp_server_redis/src/mcp_server_redis/resource/vpc_resource.py +++ b/server/mcp_server_redis/src/mcp_server_redis/resource/vpc_resource.py @@ -8,10 +8,13 @@ class VpcSDK: """初始化 Volcano VPC SDK Client""" - def __init__(self, region: str = None, ak: str = None, sk: str = None, host: str = None): + def __init__(self, region: str = None, ak: str = None, sk: str = None, host: str = None, + session_token: str = None): configuration = volcenginesdkcore.Configuration() configuration.ak = ak configuration.sk = sk + if session_token: + configuration.session_token = session_token configuration.region = region if region not in vpc_supported_regions: raise Exception(f"Vpc is not supported in region {region}.") @@ -28,4 +31,4 @@ def describe_subnets(self, args:dict) -> DescribeSubnetsResponse: return self.client.describe_subnets(DescribeSubnetsRequest(**args)) def describe_eip_addresses(self, args: dict) -> DescribeEipAddressesResponse: - return self.client.describe_eip_addresses(DescribeEipAddressesRequest(**args)) \ No newline at end of file + return self.client.describe_eip_addresses(DescribeEipAddressesRequest(**args)) diff --git a/server/mcp_server_redis/src/mcp_server_redis/server.py b/server/mcp_server_redis/src/mcp_server_redis/server.py index 7bab42e4..8a6910d5 100644 --- a/server/mcp_server_redis/src/mcp_server_redis/server.py +++ b/server/mcp_server_redis/src/mcp_server_redis/server.py @@ -1,6 +1,9 @@ import os +import json +import base64 import logging import argparse +from datetime import datetime from typing import Any from mcp.server.fastmcp import FastMCP @@ -10,16 +13,189 @@ # Initialize the MCP service mcp_server = FastMCP("redis_mcp_server", port=int(os.getenv("MCP_SERVER_PORT", "8000"))) -redis_resource = RedisSDK( - region=os.getenv('VOLCENGINE_REGION'), host=os.getenv('VOLCENGINE_ENDPOINT'), - ak=os.getenv('VOLCENGINE_ACCESS_KEY'), sk=os.getenv('VOLCENGINE_SECRET_KEY') -) -vpc_resource = VpcSDK( - region=os.getenv('VOLCENGINE_REGION'), host=None, - ak=os.getenv('VOLCENGINE_ACCESS_KEY'), sk=os.getenv('VOLCENGINE_SECRET_KEY') -) logger = logging.getLogger("mcp_server_redis") +VOLCENGINE_ACCESS_KEY_ENV_NAMES = ("VOLCENGINE_ACCESS_KEY",) +VOLCENGINE_SECRET_KEY_ENV_NAMES = ("VOLCENGINE_SECRET_KEY",) +VOLCENGINE_SESSION_TOKEN_ENV_NAMES = ("VOLCENGINE_SESSION_TOKEN",) +AUTHORIZATION_ENV_NAMES = ("authorization", "AUTHORIZATION") + +_REDIS_CLIENT_CACHE: dict[tuple[str, str, str, str, str], RedisSDK] = {} +_VPC_CLIENT_CACHE: dict[tuple[str, str, str, str], VpcSDK] = {} + + +def _get_env_value(*names: str) -> str: + for name in names: + value = os.getenv(name) + if value: + return value + return "" + + +def _normalize_iso8601(value: str) -> str: + return value.replace("Z", "+00:00") if value.endswith("Z") else value + + +def _validate_sts_time_window(payload: dict[str, Any]) -> None: + current_time = payload.get("CurrentTime") + expired_time = payload.get("ExpiredTime") + if not current_time or not expired_time: + return + current_dt = datetime.fromisoformat(_normalize_iso8601(str(current_time))) + expired_dt = datetime.fromisoformat(_normalize_iso8601(str(expired_time))) + if current_dt > expired_dt: + raise ValueError("STS token is expired") + + +def _parse_authorization_payload(raw_value: str) -> dict[str, str]: + token = raw_value.split(" ", 1)[1] if " " in raw_value else raw_value + payload = json.loads(base64.b64decode(token).decode("utf-8")) + _validate_sts_time_window(payload) + access_key = str(payload.get("AccessKeyId") or "").strip() + secret_key = str(payload.get("SecretAccessKey") or "").strip() + session_token = str(payload.get("SessionToken") or "").strip() + region = str(payload.get("Region") or "").strip() + if not access_key or not secret_key: + raise ValueError("AccessKeyId or SecretAccessKey missing in authorization payload") + return { + "access_key": access_key, + "secret_key": secret_key, + "session_token": session_token, + "region": region, + } + + +def _get_request_authorization() -> str: + try: + ctx = mcp_server.get_context() + except Exception: + return "" + request_context = getattr(ctx, "request_context", None) + if request_context is None: + request_context = getattr(ctx, "_request_context", None) + request = getattr(request_context, "request", None) + if request is None: + return "" + return str(request.headers.get("authorization") or "").strip() + + +def _resolve_volcengine_credentials(region: str | None = None) -> dict[str, str]: + request_authorization = _get_request_authorization() + if request_authorization: + credentials = _parse_authorization_payload(request_authorization) + credentials["region"] = region or credentials["region"] or os.getenv("VOLCENGINE_REGION", "") + return credentials + + env_authorization = _get_env_value(*AUTHORIZATION_ENV_NAMES) + if env_authorization: + credentials = _parse_authorization_payload(env_authorization) + credentials["region"] = region or credentials["region"] or os.getenv("VOLCENGINE_REGION", "") + return credentials + + access_key = _get_env_value(*VOLCENGINE_ACCESS_KEY_ENV_NAMES) + secret_key = _get_env_value(*VOLCENGINE_SECRET_KEY_ENV_NAMES) + session_token = _get_env_value(*VOLCENGINE_SESSION_TOKEN_ENV_NAMES) + resolved_region = region or os.getenv("VOLCENGINE_REGION", "") + if access_key and secret_key: + return { + "access_key": access_key, + "secret_key": secret_key, + "session_token": session_token, + "region": resolved_region, + } + + missing = [] + if not access_key: + missing.append("VOLCENGINE_ACCESS_KEY") + if not secret_key: + missing.append("VOLCENGINE_SECRET_KEY") + if not resolved_region: + missing.append("VOLCENGINE_REGION or request region_id") + raise ValueError( + "Redis MCP credentials are not configured. Missing: " + ", ".join(missing) + ) + + +def _extract_region_from_args(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str | None: + if args and isinstance(args[0], dict): + request = args[0] + return request.get("region_id") or request.get("RegionId") or request.get("region") + request = kwargs.get("args") + if isinstance(request, dict): + return request.get("region_id") or request.get("RegionId") or request.get("region") + return kwargs.get("region") + + +def _get_redis_client(region: str | None = None) -> RedisSDK: + credentials = _resolve_volcengine_credentials(region) + resolved_region = credentials["region"] + if not resolved_region: + raise ValueError("VOLCENGINE_REGION or request region_id is required for Redis client") + host = os.getenv("VOLCENGINE_ENDPOINT", "") + cache_key = ( + resolved_region, + host, + credentials["access_key"], + credentials["secret_key"], + credentials["session_token"], + ) + client = _REDIS_CLIENT_CACHE.get(cache_key) + if client is None: + client = RedisSDK( + region=resolved_region, + host=host or None, + ak=credentials["access_key"], + sk=credentials["secret_key"], + session_token=credentials["session_token"] or None, + ) + _REDIS_CLIENT_CACHE[cache_key] = client + return client + + +def _get_vpc_client(region: str | None = None) -> VpcSDK: + credentials = _resolve_volcengine_credentials(region) + resolved_region = credentials["region"] + if not resolved_region: + raise ValueError("VOLCENGINE_REGION or request region_id is required for VPC client") + cache_key = ( + resolved_region, + credentials["access_key"], + credentials["secret_key"], + credentials["session_token"], + ) + client = _VPC_CLIENT_CACHE.get(cache_key) + if client is None: + client = VpcSDK( + region=resolved_region, + host=None, + ak=credentials["access_key"], + sk=credentials["secret_key"], + session_token=credentials["session_token"] or None, + ) + _VPC_CLIENT_CACHE[cache_key] = client + return client + + +class _SDKProxy: + def __init__(self, service_type: str): + self.service_type = service_type + + def __getattr__(self, method_name: str): + def _call(*args, **kwargs): + region = _extract_region_from_args(args, kwargs) + if self.service_type == "redis": + client = _get_redis_client(region) + else: + client = _get_vpc_client(region) + method = getattr(client, method_name) + return method(*args, **kwargs) + + return _call + + +redis_resource = _SDKProxy("redis") +vpc_resource = _SDKProxy("vpc") + @mcp_server.tool( name="get_available_params", @@ -1670,4 +1846,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/server/mcp_server_redis/tests/test_sts_auth.py b/server/mcp_server_redis/tests/test_sts_auth.py new file mode 100644 index 00000000..b91707dd --- /dev/null +++ b/server/mcp_server_redis/tests/test_sts_auth.py @@ -0,0 +1,166 @@ +import base64 +import json +import os +import unittest +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from mcp_server_redis import server as redis_server + + +def _make_auth_header(payload: dict) -> str: + return "Bearer " + base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8") + + +class RedisServerSTSAuthTests(unittest.TestCase): + def setUp(self) -> None: + redis_server._REDIS_CLIENT_CACHE.clear() + redis_server._VPC_CLIENT_CACHE.clear() + + def test_parse_authorization_payload_returns_sts_credentials(self): + header = _make_auth_header( + { + "AccessKeyId": "sts-ak", + "SecretAccessKey": "sts-sk", + "SessionToken": "sts-token", + "CurrentTime": "2026-05-28T10:00:00+08:00", + "ExpiredTime": "2026-05-28T11:00:00+08:00", + "Region": "cn-beijing", + } + ) + + credentials = redis_server._parse_authorization_payload(header) + + self.assertEqual( + credentials, + { + "access_key": "sts-ak", + "secret_key": "sts-sk", + "session_token": "sts-token", + "region": "cn-beijing", + }, + ) + + def test_parse_authorization_payload_rejects_expired_sts(self): + header = _make_auth_header( + { + "AccessKeyId": "sts-ak", + "SecretAccessKey": "sts-sk", + "SessionToken": "sts-token", + "CurrentTime": "2026-05-28T12:00:00+08:00", + "ExpiredTime": "2026-05-28T11:00:00+08:00", + "Region": "cn-beijing", + } + ) + + with self.assertRaisesRegex(ValueError, "STS token is expired"): + redis_server._parse_authorization_payload(header) + + def test_resolve_credentials_prefers_request_authorization(self): + header = _make_auth_header( + { + "AccessKeyId": "header-ak", + "SecretAccessKey": "header-sk", + "SessionToken": "header-token", + "Region": "cn-guangzhou", + } + ) + ctx = SimpleNamespace( + request_context=SimpleNamespace( + request=SimpleNamespace(headers={"authorization": header}) + ) + ) + + with patch.object(redis_server.mcp_server, "get_context", return_value=ctx), patch.dict( + os.environ, + { + "VOLCENGINE_ACCESS_KEY": "env-ak", + "VOLCENGINE_SECRET_KEY": "env-sk", + "VOLCENGINE_SESSION_TOKEN": "env-token", + "VOLCENGINE_REGION": "cn-beijing", + }, + clear=True, + ): + credentials = redis_server._resolve_volcengine_credentials() + + self.assertEqual(credentials["access_key"], "header-ak") + self.assertEqual(credentials["secret_key"], "header-sk") + self.assertEqual(credentials["session_token"], "header-token") + self.assertEqual(credentials["region"], "cn-guangzhou") + + def test_resolve_credentials_supports_env_session_token(self): + with patch.object(redis_server.mcp_server, "get_context", side_effect=RuntimeError("no context")), patch.dict( + os.environ, + { + "VOLCENGINE_ACCESS_KEY": "env-ak", + "VOLCENGINE_SECRET_KEY": "env-sk", + "VOLCENGINE_SESSION_TOKEN": "env-token", + "VOLCENGINE_REGION": "cn-beijing", + }, + clear=True, + ): + credentials = redis_server._resolve_volcengine_credentials() + + self.assertEqual(credentials["access_key"], "env-ak") + self.assertEqual(credentials["secret_key"], "env-sk") + self.assertEqual(credentials["session_token"], "env-token") + self.assertEqual(credentials["region"], "cn-beijing") + + def test_resolve_credentials_ignores_legacy_env_session_token_name(self): + with patch.object(redis_server.mcp_server, "get_context", side_effect=RuntimeError("no context")), patch.dict( + os.environ, + { + "VOLCENGINE_ACCESS_KEY": "env-ak", + "VOLCENGINE_SECRET_KEY": "env-sk", + "VOLCENGINE_ACCESS_SESSION_TOKEN": "legacy-token", + "VOLCENGINE_REGION": "cn-beijing", + }, + clear=True, + ): + credentials = redis_server._resolve_volcengine_credentials() + + self.assertEqual(credentials["access_key"], "env-ak") + self.assertEqual(credentials["secret_key"], "env-sk") + self.assertEqual(credentials["session_token"], "") + self.assertEqual(credentials["region"], "cn-beijing") + + def test_get_redis_client_injects_session_token_into_sdk(self): + with patch.object(redis_server.mcp_server, "get_context", side_effect=RuntimeError("no context")), patch.dict( + os.environ, + { + "VOLCENGINE_ACCESS_KEY": "env-ak", + "VOLCENGINE_SECRET_KEY": "env-sk", + "VOLCENGINE_SESSION_TOKEN": "env-token", + "VOLCENGINE_REGION": "cn-beijing", + "VOLCENGINE_ENDPOINT": "redis.custom.endpoint", + }, + clear=True, + ), patch.object(redis_server, "RedisSDK") as redis_sdk_cls: + instance = object() + redis_sdk_cls.return_value = instance + + client = redis_server._get_redis_client() + + self.assertIs(client, instance) + redis_sdk_cls.assert_called_once_with( + region="cn-beijing", + host="redis.custom.endpoint", + ak="env-ak", + sk="env-sk", + session_token="env-token", + ) + + def test_sdk_proxy_uses_request_region_when_dispatching(self): + fake_client = Mock() + fake_client.describe_regions.return_value = {"ok": True} + + with patch.object(redis_server, "_get_redis_client", return_value=fake_client) as get_client: + result = redis_server.redis_resource.describe_regions({"region_id": "cn-shanghai"}) + + self.assertEqual(result, {"ok": True}) + get_client.assert_called_once_with("cn-shanghai") + fake_client.describe_regions.assert_called_once_with({"region_id": "cn-shanghai"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/server/mcp_server_redis/tests/verify_sts_flow.py b/server/mcp_server_redis/tests/verify_sts_flow.py new file mode 100644 index 00000000..21ea934c --- /dev/null +++ b/server/mcp_server_redis/tests/verify_sts_flow.py @@ -0,0 +1,283 @@ +import argparse +import asyncio +import base64 +import json +import os +import socket +import sys +from dataclasses import dataclass +from pathlib import Path + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.streamable_http import streamablehttp_client + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_HTTP_PORT = 8765 +REQUIRED_ENV_VARS = ( + "VOLCENGINE_ACCESS_KEY", + "VOLCENGINE_SECRET_KEY", + "VOLCENGINE_SESSION_TOKEN", +) + + +@dataclass +class CheckResult: + name: str + passed: bool + detail: str + required: bool = True + + +def _ensure_required_env_vars() -> None: + missing = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)] + if missing: + raise RuntimeError("Missing required environment variables: " + ", ".join(missing)) + + +def _resolve_region(cli_region: str | None) -> str: + return cli_region or os.getenv("VOLCENGINE_REGION") or "cn-beijing" + + +def _build_sts_header(region: str, expired: bool = False) -> str: + payload = { + "AccessKeyId": os.environ["VOLCENGINE_ACCESS_KEY"], + "SecretAccessKey": os.environ["VOLCENGINE_SECRET_KEY"], + "SessionToken": os.environ["VOLCENGINE_SESSION_TOKEN"], + "CurrentTime": "2026-06-01T16:00:00+08:00", + "ExpiredTime": "2026-06-01T17:00:00+08:00", + "Region": region, + } + if expired: + payload["CurrentTime"] = "2026-06-01T18:00:00+08:00" + payload["ExpiredTime"] = "2026-06-01T17:00:00+08:00" + encoded = base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8") + return f"Bearer {encoded}" + + +def _extract_text(result) -> str: + content = getattr(result, "content", []) or [] + if not content: + return "" + return "\n".join(getattr(item, "text", str(item)) for item in content) + + +def _extract_structured_content(result): + return getattr(result, "structuredContent", None) + + +async def _run_stdio_positive(region: str) -> CheckResult: + env = { + "VOLCENGINE_ACCESS_KEY": os.environ["VOLCENGINE_ACCESS_KEY"], + "VOLCENGINE_SECRET_KEY": os.environ["VOLCENGINE_SECRET_KEY"], + "VOLCENGINE_SESSION_TOKEN": os.environ["VOLCENGINE_SESSION_TOKEN"], + } + if os.getenv("VOLCENGINE_REGION"): + env["VOLCENGINE_REGION"] = os.environ["VOLCENGINE_REGION"] + + server = StdioServerParameters( + command=sys.executable, + args=["-m", "mcp_server_redis.server", "--transport", "stdio"], + env=env, + cwd=PROJECT_ROOT, + ) + async with stdio_client(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tools = await session.list_tools() + result = await session.call_tool("describe_db_instances", {"region_id": region, "page_size": 1}) + + tool_names = [tool.name for tool in tools.tools] + data = _extract_structured_content(result) or {} + if result.isError: + return CheckResult("stdio + STS env", False, _extract_text(result)) + if "describe_db_instances" not in tool_names: + return CheckResult("stdio + STS env", False, "Tool list does not include describe_db_instances") + instance_count = data.get("total_instances_num", "unknown") + return CheckResult("stdio + STS env", True, f"Tool call succeeded, total_instances_num={instance_count}") + + +async def _run_stdio_missing_token_control(region: str) -> CheckResult: + env = { + "VOLCENGINE_ACCESS_KEY": os.environ["VOLCENGINE_ACCESS_KEY"], + "VOLCENGINE_SECRET_KEY": os.environ["VOLCENGINE_SECRET_KEY"], + } + if os.getenv("VOLCENGINE_REGION"): + env["VOLCENGINE_REGION"] = os.environ["VOLCENGINE_REGION"] + + server = StdioServerParameters( + command=sys.executable, + args=["-m", "mcp_server_redis.server", "--transport", "stdio"], + env=env, + cwd=PROJECT_ROOT, + ) + async with stdio_client(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("describe_db_instances", {"region_id": region, "page_size": 1}) + + if result.isError: + return CheckResult( + "stdio missing session token control", + True, + "Call failed as expected: " + _extract_text(result).strip(), + required=False, + ) + return CheckResult( + "stdio missing session token control", + False, + "Call still succeeded without VOLCENGINE_SESSION_TOKEN; current AK/SK may not be STS-only credentials.", + required=False, + ) + + +async def _wait_for_port(host: str, port: int, timeout_seconds: float = 15) -> None: + deadline = asyncio.get_running_loop().time() + timeout_seconds + while True: + try: + with socket.create_connection((host, port), timeout=1): + return + except OSError: + if asyncio.get_running_loop().time() >= deadline: + raise TimeoutError(f"Timed out waiting for server at {host}:{port}") + await asyncio.sleep(0.2) + + +async def _start_http_server(port: int): + env = os.environ.copy() + for name in ( + "VOLCENGINE_ACCESS_KEY", + "VOLCENGINE_SECRET_KEY", + "VOLCENGINE_SESSION_TOKEN", + "VOLCENGINE_REGION", + "authorization", + "AUTHORIZATION", + ): + env.pop(name, None) + env["MCP_SERVER_PORT"] = str(port) + env["PYTHONUNBUFFERED"] = "1" + + process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "mcp_server_redis.server", + "--transport", + "streamable-http", + cwd=str(PROJECT_ROOT), + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + await _wait_for_port("127.0.0.1", port) + return process + + +async def _stop_http_server(process) -> None: + if process.returncode is not None: + return + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + +async def _run_http_positive(region: str, port: int) -> CheckResult: + process = await _start_http_server(port) + try: + async with streamablehttp_client( + f"http://127.0.0.1:{port}/mcp", + headers={"Authorization": _build_sts_header(region)}, + ) as streams: + read_stream, write_stream, _ = streams + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("describe_db_instances", {"region_id": region, "page_size": 1}) + data = _extract_structured_content(result) or {} + if result.isError: + return CheckResult("streamable-http + Authorization STS", False, _extract_text(result)) + instance_count = data.get("total_instances_num", "unknown") + return CheckResult( + "streamable-http + Authorization STS", + True, + f"Tool call succeeded, total_instances_num={instance_count}", + ) + finally: + await _stop_http_server(process) + + +async def _run_http_expired_token_control(region: str, port: int) -> CheckResult: + process = await _start_http_server(port) + try: + async with streamablehttp_client( + f"http://127.0.0.1:{port}/mcp", + headers={"Authorization": _build_sts_header(region, expired=True)}, + ) as streams: + read_stream, write_stream, _ = streams + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("describe_db_instances", {"region_id": region, "page_size": 1}) + if result.isError and "STS token is expired" in _extract_text(result): + return CheckResult( + "streamable-http expired token control", + True, + "Expired header was rejected as expected.", + ) + if result.isError: + return CheckResult( + "streamable-http expired token control", + False, + "Call failed, but not with the expected expiration error: " + _extract_text(result).strip(), + ) + return CheckResult( + "streamable-http expired token control", + False, + "Expired STS header unexpectedly succeeded.", + ) + finally: + await _stop_http_server(process) + + +def _print_result(result: CheckResult) -> None: + status = "PASS" if result.passed else ("WARN" if not result.required else "FAIL") + requirement = "required" if result.required else "optional" + print(f"[{status}] {result.name} ({requirement})") + print(f" {result.detail}") + + +async def _main_async(region: str, port: int) -> int: + results = [ + await _run_stdio_positive(region), + await _run_stdio_missing_token_control(region), + await _run_http_positive(region, port), + await _run_http_expired_token_control(region, port + 1), + ] + + print(f"Redis MCP STS verification started, region={region}, http_port={port}") + for result in results: + _print_result(result) + + required_failures = [result for result in results if result.required and not result.passed] + if required_failures: + print("\nOverall result: FAIL") + return 1 + + print("\nOverall result: PASS") + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser(description="Verify Redis MCP STS support end-to-end") + parser.add_argument("--region", help="Region used for Redis API calls. Defaults to VOLCENGINE_REGION or cn-beijing") + parser.add_argument("--http-port", type=int, default=DEFAULT_HTTP_PORT, help="Base port used for streamable-http verification") + args = parser.parse_args() + + _ensure_required_env_vars() + region = _resolve_region(args.region) + return asyncio.run(_main_async(region, args.http_port)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/server/mcp_server_redis/uv.lock b/server/mcp_server_redis/uv.lock index 5dbcf491..cd82a1e1 100644 --- a/server/mcp_server_redis/uv.lock +++ b/server/mcp_server_redis/uv.lock @@ -210,9 +210,9 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "mcp", specifier = ">=1.9.4" }, + { name = "mcp", specifier = ">=1.12.0" }, { name = "mcp", extras = ["cli"] }, - { name = "volcengine-python-sdk", specifier = ">=1.0.130" }, + { name = "volcengine-python-sdk", specifier = ">=4.0.24" }, ] [[package]] @@ -668,7 +668,7 @@ wheels = [ [[package]] name = "volcengine-python-sdk" -version = "4.0.6" +version = "5.0.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -676,4 +676,7 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9b/8c/8b424f33dcd50faacce4a13c57e65b33aab875c3cff48bd97c4679ff2254/volcengine-python-sdk-4.0.6.tar.gz", hash = "sha256:6367a892f10759c96133a31508aa32fd761b09f877d2efce00bf74b4eabb832f", size = 6161665, upload-time = "2025-07-17T12:48:13.888Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/19/5090a6a75627f36631e2f5809a841a5fa9b1a4db6eb235cdaab072bb91ef/volcengine_python_sdk-5.0.31.tar.gz", hash = "sha256:59a16c2c613fa6661032c6b868ac27ef484a2164b9189ccd25a1aba590c7712b", size = 9238438, upload-time = "2026-05-29T09:56:58.62Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/ce/568b19cf2a1b37a9efa096abf113cba2bc61b36eac7d08e7f841006068b5/volcengine_python_sdk-5.0.31-py2.py3-none-any.whl", hash = "sha256:10a6a19c82165f0e64f56551dbdaa76310e8f53764af24e73bd3997e2e14c8f4", size = 36811058, upload-time = "2026-05-29T09:56:53.659Z" }, +] From c39f28f884dcb707b0f9919f85f77254ddb01524 Mon Sep 17 00:00:00 2001 From: Y-Rookie Date: Mon, 1 Jun 2026 17:31:50 +0800 Subject: [PATCH 3/3] redis mcp server supports sts token credentials --- server/mcp_server_redis/README.md | 6 ++++-- server/mcp_server_redis/README_zh.md | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/server/mcp_server_redis/README.md b/server/mcp_server_redis/README.md index a873ae40..25ffe094 100644 --- a/server/mcp_server_redis/README.md +++ b/server/mcp_server_redis/README.md @@ -202,7 +202,7 @@ This mode is suitable for local `stdio` runs or any client that injects credenti ### 3. STS temporary credentials via `Authorization` header -For `streamable-http` calls, Redis MCP also supports passing temporary credentials in the request header: +For HTTP-based MCP calls such as `streamable-http`, Redis MCP supports passing temporary credentials in the request header: ```http Authorization: Bearer @@ -227,6 +227,8 @@ Notes: - `Region` can be provided either in the payload or through `VOLCENGINE_REGION` / request parameters. - If both header credentials and environment credentials are provided, request header credentials take precedence. - If `CurrentTime` and `ExpiredTime` are present, the server will validate whether the STS token is expired. +- The `Authorization` header is mainly intended for HTTP transports such as `streamable-http`. +- For non-HTTP transports such as `stdio`, prefer setting `VOLCENGINE_ACCESS_KEY`, `VOLCENGINE_SECRET_KEY`, and `VOLCENGINE_SESSION_TOKEN` through environment variables. --- @@ -278,7 +280,7 @@ Volcengine Redis service access address: https://www.volcengine.com/docs/6293/65 } ``` -### Example 3: STS credentials through `Authorization` header (streamable-http) +### Example 3: STS credentials through `Authorization` header (HTTP transports such as `streamable-http`) If your MCP client calls Redis MCP through HTTP, you can pass a Bearer token whose content is the Base64-encoded JSON shown above. The server will extract: diff --git a/server/mcp_server_redis/README_zh.md b/server/mcp_server_redis/README_zh.md index 74f9960c..5988a1ff 100644 --- a/server/mcp_server_redis/README_zh.md +++ b/server/mcp_server_redis/README_zh.md @@ -203,7 +203,7 @@ Redis MCP 现已支持以下几种火山引擎凭证方式: ### 3. 通过 `Authorization` Header 传递 STS 临时凭证 -对于 `streamable-http` 调用方式,Redis MCP 支持通过请求头传递临时凭证: +对于基于 HTTP 的 MCP 调用方式(例如 `streamable-http`),Redis MCP 支持通过请求头传递临时凭证: ```http Authorization: Bearer @@ -228,6 +228,8 @@ Authorization: Bearer - `Region` 可以放在 Header 对应的 JSON 中,也可以继续通过 `VOLCENGINE_REGION` 或请求参数传入。 - 如果同时提供 Header 凭证和环境变量凭证,请求头中的凭证优先级更高。 - 如果 JSON 中带有 `CurrentTime` 和 `ExpiredTime`,服务端会校验 STS 是否已过期。 +- `Authorization` Header 主要面向 `streamable-http` 这类 HTTP 传输方式。 +- 对于 `stdio` 这类非 HTTP 传输方式,更推荐通过环境变量传递 `VOLCENGINE_ACCESS_KEY`、`VOLCENGINE_SECRET_KEY` 和 `VOLCENGINE_SESSION_TOKEN`。 --- @@ -279,7 +281,7 @@ Authorization: Bearer } ``` -### 示例 3:通过 `Authorization` Header 传递 STS(streamable-http) +### 示例 3:通过 `Authorization` Header 传递 STS(适用于 `streamable-http` 等 HTTP 传输方式) 如果你的 MCP Client 是通过 HTTP 调用 Redis MCP,可以把上面的 JSON 先做 Base64 编码,再按以下格式放入请求头: