Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/ModelContextProtocol.Core/Server/McpServerImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,17 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure
}
}

return await handler(request, cancellationToken).ConfigureAwait(false);
try
{
var result = await handler(request, cancellationToken).ConfigureAwait(false);
ReadResourceCompleted(request.Params?.Uri ?? string.Empty);
return result;
}
catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException)
{
ReadResourceError(request.Params?.Uri ?? string.Empty, e);
throw;
}
});
subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters);
unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters);
Expand Down Expand Up @@ -487,7 +497,7 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals

listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters);
getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler =>
(request, cancellationToken) =>
async (request, cancellationToken) =>
{
// Initial handler that sets MatchedPrimitive
if (request.Params?.Name is { } promptName && prompts is not null &&
Expand All @@ -496,7 +506,17 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals
request.MatchedPrimitive = prompt;
}

return handler(request, cancellationToken);
try
{
var result = await handler(request, cancellationToken).ConfigureAwait(false);
GetPromptCompleted(request.Params?.Name ?? string.Empty);
return result;
}
catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException)
{
GetPromptError(request.Params?.Name ?? string.Empty, e);
throw;
}
});

ServerCapabilities.Prompts.ListChanged = listChanged;
Expand Down Expand Up @@ -610,7 +630,16 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)

try
{
return await handler(request, cancellationToken).ConfigureAwait(false);
var result = await handler(request, cancellationToken).ConfigureAwait(false);

// Don't log here for task-augmented calls; logging happens asynchronously
// in ExecuteToolAsTaskAsync when the tool actually completes.
if (result.Task is null)
{
ToolCallCompleted(request.Params?.Name ?? string.Empty, result.IsError is true);
}

return result;
}
catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException)
{
Expand Down Expand Up @@ -944,6 +973,21 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) =>
[LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")]
private partial void ToolCallError(string toolName, Exception exception);

[LoggerMessage(Level = LogLevel.Information, Message = "\"{ToolName}\" completed. IsError = {IsError}.")]
private partial void ToolCallCompleted(string toolName, bool isError);

[LoggerMessage(Level = LogLevel.Error, Message = "GetPrompt \"{PromptName}\" threw an unhandled exception.")]
private partial void GetPromptError(string promptName, Exception exception);

[LoggerMessage(Level = LogLevel.Information, Message = "GetPrompt \"{PromptName}\" completed.")]
private partial void GetPromptCompleted(string promptName);

[LoggerMessage(Level = LogLevel.Error, Message = "ReadResource \"{ResourceUri}\" threw an unhandled exception.")]
private partial void ReadResourceError(string resourceUri, Exception exception);

[LoggerMessage(Level = LogLevel.Information, Message = "ReadResource \"{ResourceUri}\" completed.")]
private partial void ReadResourceCompleted(string resourceUri);

/// <summary>
/// Executes a tool call as a task and returns a CallToolTaskResult immediately.
/// </summary>
Expand Down Expand Up @@ -1004,6 +1048,7 @@ private async ValueTask<CallToolResult> ExecuteToolAsTaskAsync(

// Invoke the tool with task-specific cancellation token
var result = await tool.InvokeAsync(request, taskCancellationToken).ConfigureAwait(false);
ToolCallCompleted(request.Params?.Name ?? string.Empty, result.IsError is true);

// Determine final status based on whether there was an error
var finalStatus = result.IsError is true ? McpTaskStatus.Failed : McpTaskStatus.Completed;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
Expand Down Expand Up @@ -195,6 +196,37 @@ await Assert.ThrowsAsync<McpProtocolException>(async () => await client.GetPromp
cancellationToken: TestContext.Current.CancellationToken));
}

[Fact]
public async Task Logs_Prompt_Name_On_Successful_Call()
{
await using McpClient client = await CreateMcpClientForServer();

var result = await client.GetPromptAsync(
"returns_chat_messages",
new Dictionary<string, object?> { ["message"] = "hello" },
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "GetPrompt \"returns_chat_messages\" completed.");
Assert.Equal(LogLevel.Information, infoLog.LogLevel);
}

[Fact]
public async Task Logs_Prompt_Name_When_Prompt_Throws()
{
await using McpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<McpProtocolException>(async () => await client.GetPromptAsync(
"throws_exception",
new Dictionary<string, object?> { ["message"] = "test" },
cancellationToken: TestContext.Current.CancellationToken));

var errorLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error);
Assert.Equal("GetPrompt \"throws_exception\" threw an unhandled exception.", errorLog.Message);
Assert.IsType<FormatException>(errorLog.Exception);
}

[Fact]
public async Task Throws_Exception_On_Unknown_Prompt()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
Expand Down Expand Up @@ -239,6 +240,35 @@ await Assert.ThrowsAsync<McpProtocolException>(async () => await client.ReadReso
cancellationToken: TestContext.Current.CancellationToken));
}

[Fact]
public async Task Logs_Resource_Uri_On_Successful_Read()
{
await using McpClient client = await CreateMcpClientForServer();

var result = await client.ReadResourceAsync(
"resource://mcp/some_neat_direct_resource",
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "ReadResource \"resource://mcp/some_neat_direct_resource\" completed.");
Assert.Equal(LogLevel.Information, infoLog.LogLevel);
}

[Fact]
public async Task Logs_Resource_Uri_When_Resource_Throws()
{
await using McpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<McpProtocolException>(async () => await client.ReadResourceAsync(
"resource://mcp/throws_exception",
cancellationToken: TestContext.Current.CancellationToken));

var errorLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error);
Assert.Equal("ReadResource \"resource://mcp/throws_exception\" threw an unhandled exception.", errorLog.Message);
Assert.IsType<InvalidOperationException>(errorLog.Exception);
}

[Fact]
public async Task Throws_Exception_On_Unknown_Resource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public async Task Can_List_Registered_Tools()
await using McpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(16, tools.Count);
Assert.Equal(17, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "echo");
Assert.Equal("Echoes the input back to the client.", echoTool.Description);
Expand Down Expand Up @@ -165,7 +165,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
cancellationToken: TestContext.Current.CancellationToken))
{
var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(16, tools.Count);
Assert.Equal(17, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "echo");
Assert.Equal("Echoes the input back to the client.", echoTool.Description);
Expand All @@ -191,7 +191,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await using McpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(16, tools.Count);
Assert.Equal(17, tools.Count);

Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
Expand All @@ -212,7 +212,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await notificationRead;

tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(17, tools.Count);
Assert.Equal(18, tools.Count);
Assert.Contains(tools, t => t.Name == "NewTool");

notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
Expand All @@ -222,7 +222,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
}

tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(16, tools.Count);
Assert.Equal(17, tools.Count);
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
}

Expand Down Expand Up @@ -380,6 +380,39 @@ public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails()
Assert.Equal("Test error", errorLog.Exception.Message);
}

[Fact]
public async Task Logs_Tool_Name_On_Successful_Call()
{
await using McpClient client = await CreateMcpClientForServer();

var result = await client.CallToolAsync(
"echo",
new Dictionary<string, object?> { ["message"] = "test" },
cancellationToken: TestContext.Current.CancellationToken);

Assert.True(result.IsError is not true);
Assert.Equal("hello test", (result.Content[0] as TextContentBlock)?.Text);

var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "\"echo\" completed. IsError = False.");
Assert.Equal(LogLevel.Information, infoLog.LogLevel);
}

[Fact]
public async Task Logs_Tool_Name_With_IsError_When_Tool_Returns_Error()
{
await using McpClient client = await CreateMcpClientForServer();

var result = await client.CallToolAsync(
"return_is_error",
cancellationToken: TestContext.Current.CancellationToken);

Assert.True(result.IsError);
Assert.Contains("Tool returned an error", (result.Content[0] as TextContentBlock)?.Text);

var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "\"return_is_error\" completed. IsError = True.");
Assert.Equal(LogLevel.Information, infoLog.LogLevel);
}

[Fact]
public async Task Throws_Exception_On_Unknown_Tool()
{
Expand Down Expand Up @@ -786,6 +819,16 @@ public static string ThrowException()
throw new InvalidOperationException("Test error");
}

[McpServerTool]
public static CallToolResult ReturnIsError()
{
return new CallToolResult
{
IsError = true,
Content = [new TextContentBlock { Text = "Tool returned an error" }],
};
}

[McpServerTool]
public static int ReturnCancellationToken(CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -868,5 +911,6 @@ public class ComplexObject
[JsonSerializable(typeof(ComplexObject))]
[JsonSerializable(typeof(string[]))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(CallToolResult))]
partial class BuilderToolsJsonContext : JsonSerializerContext;
}
Loading