diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index b2f94fb28..389a7cf22 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -173,12 +173,19 @@ public Task ProcessMessagesAsync(CancellationToken cancellationToken) private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) { + // Track in-flight message handlers so we can wait for them to complete before returning. + // Start at 1 to represent ProcessMessagesCoreAsync itself; it's decremented after the loop exits. + int inFlightCount = 1; + var allHandlersCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + try { await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { LogMessageRead(EndpointName, message.GetType().Name); + Interlocked.Increment(ref inFlightCount); + // Fire and forget the message handling to avoid blocking the transport. if (message.Context?.ExecutionContext is null) { @@ -286,6 +293,11 @@ ex is OperationCanceledException && _handlingRequests.TryRemove(messageWithId.Id, out _); combinedCts!.Dispose(); } + + if (Interlocked.Decrement(ref inFlightCount) == 0) + { + allHandlersCompleted.TrySetResult(true); + } } } } @@ -297,6 +309,12 @@ ex is OperationCanceledException && } finally { + // Decrement our own count. If all handlers have already completed, this will signal completion. + if (Interlocked.Decrement(ref inFlightCount) != 0) + { + await allHandlersCompleted.Task.ConfigureAwait(false); + } + // Fail any pending requests, as they'll never be satisfied. foreach (var entry in _pendingRequests) { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index ea680ecf0..045cbe435 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1003,6 +1003,60 @@ await transport.SendClientMessageAsync(new JsonRpcNotification await runTask; } + [Fact] + public async Task RunAsync_WaitsForInFlightHandlersBeforeReturning() + { + // Arrange: Create a tool handler that blocks until we release it. + var handlerStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + bool handlerCompleted = false; + + await using var transport = new TestServerTransport(); + var options = CreateOptions(new ServerCapabilities { Tools = new() }); + options.Handlers.CallToolHandler = async (request, ct) => + { + handlerStarted.SetResult(true); + await releaseHandler.Task; + handlerCompleted = true; + return new CallToolResult { Content = [new TextContentBlock { Text = "done" }] }; + }; + options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException(); + + await using var server = McpServer.Create(transport, options, LoggerFactory); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); + + // Send a tool call request. + await transport.SendClientMessageAsync( + new JsonRpcRequest + { + Method = RequestMethods.ToolsCall, + Id = new RequestId(1) + }, + TestContext.Current.CancellationToken); + + // Wait for the handler to start executing. + await handlerStarted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); + + // Dispose the transport to simulate client disconnect while the handler is still running. + await transport.DisposeAsync(); + + // Release the handler after a delay, giving ProcessMessagesCoreAsync time to notice the + // channel closed. Without the fix, RunAsync would return before the handler completes. + var ct = TestContext.Current.CancellationToken; + _ = Task.Run(async () => + { + await Task.Delay(200, ct); + releaseHandler.SetResult(true); + }, ct); + + // Wait for RunAsync to complete. + await runTask.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); + + // With the fix, RunAsync waits for in-flight handlers. Without it, it returns immediately + // after the transport closes (before the 500ms delay releases the handler). + Assert.True(handlerCompleted, "RunAsync should wait for in-flight handlers to complete before returning."); + } + private static async Task InitializeServerAsync(TestServerTransport transport, ClientCapabilities capabilities, CancellationToken cancellationToken = default) { var initializeRequest = new JsonRpcRequest