diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index b2f94fb28..3968e4158 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -173,21 +173,42 @@ public Task ProcessMessagesAsync(CancellationToken cancellationToken) private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) { + // Track handler tasks so we can await them during shutdown. This ensures + // that service scopes (e.g., from ASP.NET Core request services in stateless mode) + // are not disposed while handlers are still executing. + List pendingHandlerTasks = []; try { await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { LogMessageRead(EndpointName, message.GetType().Name); - // Fire and forget the message handling to avoid blocking the transport. + // Launch the message handler without blocking the transport read loop. + Task handlerTask; if (message.Context?.ExecutionContext is null) { - _ = ProcessMessageAsync(); + handlerTask = ProcessMessageAsync(); } else { // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + Task? capturedTask = null; + ExecutionContext.Run(message.Context.ExecutionContext, _ => capturedTask = ProcessMessageAsync(), null); + handlerTask = capturedTask!; + } + + pendingHandlerTasks.Add(handlerTask); + + // Periodically prune completed tasks to avoid unbounded list growth. + if (pendingHandlerTasks.Count > 50) + { + for (int i = pendingHandlerTasks.Count - 1; i >= 0; i--) + { + if (pendingHandlerTasks[i].IsCompleted) + { + pendingHandlerTasks.RemoveAt(i); + } + } } async Task ProcessMessageAsync() @@ -297,6 +318,23 @@ ex is OperationCanceledException && } finally { + // Wait for all outstanding message handlers to complete before returning. + // This is critical in stateless HTTP mode where the service scope from the + // ASP.NET Core request is disposed after the message processing task completes. + // Without this, handlers could get ObjectDisposedException when trying to + // resolve scoped services. + if (pendingHandlerTasks.Count > 0) + { + try + { + await Task.WhenAll(pendingHandlerTasks).ConfigureAwait(false); + } + catch + { + // Exceptions from individual handlers are already logged within ProcessMessageAsync. + } + } + // Fail any pending requests, as they'll never be satisfied. foreach (var entry in _pendingRequests) { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerHandlerLifecycleTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerHandlerLifecycleTests.cs new file mode 100644 index 000000000..cf24cdf82 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerHandlerLifecycleTests.cs @@ -0,0 +1,94 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Server; + +public class McpServerHandlerLifecycleTests : ClientServerTestBase +{ + public McpServerHandlerLifecycleTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithTools(); + services.AddScoped(); + } + + [Fact] + public async Task ScopedServicesAreAccessibleThroughoutHandlerLifetime_EvenDuringShutdown() + { + // Arrange: create client and call the slow tool + await using McpClient client = await CreateMcpClientForServer(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + var tool = tools.First(t => t.Name == "slow_tool"); + + TrackedService.Reset(); + + // Act: invoke the tool which delays, then accesses the scoped service + CallToolResult result = await tool.CallAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert: the scoped service was successfully accessed after the delay. + // If the handler task were not awaited during shutdown, the service scope could + // be disposed before the handler finishes, causing ObjectDisposedException. + Assert.Equal(1, TrackedService.TotalConstructed); + var textContent = Assert.IsType(result.Content.First()); + Assert.Contains("service-ok", textContent.Text); + } + + [McpServerToolType] + public sealed class SlowTool + { + [McpServerTool] + public static async Task SlowToolAsync(TrackedService service, CancellationToken cancellationToken) + { + // Simulate a handler that takes some time, then accesses a scoped service. + await Task.Delay(100, cancellationToken); + + // Access the scoped service after the delay. If the scope were disposed + // prematurely, this would throw ObjectDisposedException. + service.DoWork(); + + return "service-ok"; + } + } + + public class TrackedService : IAsyncDisposable + { + private static int s_totalConstructed; + private static int s_totalDisposed; + private bool _disposed; + + public TrackedService() + { + Interlocked.Increment(ref s_totalConstructed); + } + + public static int TotalConstructed => Volatile.Read(ref s_totalConstructed); + public static int TotalDisposed => Volatile.Read(ref s_totalDisposed); + + public static void Reset() + { + Interlocked.Exchange(ref s_totalConstructed, 0); + Interlocked.Exchange(ref s_totalDisposed, 0); + } + + public void DoWork() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(TrackedService)); + } + } + + public ValueTask DisposeAsync() + { + _disposed = true; + Interlocked.Increment(ref s_totalDisposed); + return default; + } + } +}