diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index c6fd24ac6..48cf1f956 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -10,6 +10,7 @@ import ( "io" "iter" "log/slog" + "net" "net/url" "strings" "sync" @@ -454,12 +455,13 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool resp, err := ts.mcpClient.CallTool(ctx, request) - // If the server lost our session (e.g. it restarted), force a - // reconnection and retry the call once. - if errors.Is(err, mcp.ErrSessionMissing) { - slog.Warn("MCP session missing, forcing reconnect and retrying", "tool", toolCall.Function.Name, "server", ts.logID) + // If the call failed with a connection or session error (e.g. the + // server restarted), trigger or wait for a reconnection and retry + // the call once. + if err != nil && isConnectionError(err) && ctx.Err() == nil { + slog.Warn("MCP call failed, forcing reconnect and retrying", "tool", toolCall.Function.Name, "server", ts.logID, "error", err) if waitErr := ts.forceReconnectAndWait(ctx); waitErr != nil { - return nil, fmt.Errorf("failed to reconnect after session loss: %w", waitErr) + return nil, fmt.Errorf("failed to reconnect after call failure: %w", waitErr) } resp, err = ts.mcpClient.CallTool(ctx, request) } @@ -690,3 +692,27 @@ func (ts *Toolset) GetPrompt(ctx context.Context, name string, arguments map[str slog.Debug("Retrieved MCP prompt", "prompt", name, "messages_count", len(result.Messages)) return result, nil } + +// isConnectionError reports whether err is a connection or session error +// that warrants a reconnect-and-retry (as opposed to an application-level +// error that would fail again even after reconnecting). +func isConnectionError(err error) bool { + if errors.Is(err, mcp.ErrSessionMissing) || errors.Is(err, io.EOF) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + // The MCP SDK wraps transport failures (e.g. connection reset, EOF from + // client.Do) with its internal ErrRejected sentinel using %v, which + // drops the original error from the chain. Detect these by checking + // the error message for common transport-failure substrings. + if msg := err.Error(); strings.Contains(msg, "connection reset") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "broken pipe") || + strings.Contains(msg, "EOF") { + return true + } + return false +}