From c3002e8d472ca99721a369c765750ae2cb72e72f Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 20 Mar 2026 11:22:41 +0000 Subject: [PATCH] feat(everyrow-mcp): reduce progress polling for widget-capable clients Replace the tight 12s polling loop with a single 25s long-poll for claude.ai and Claude Desktop. The MCP App widget already shows live progress via REST polling, so Claude doesn't need to keep checking. This reduces tool calls from unbounded (~20+ per task) to always 3, preventing the "tool-use limit reached" error on long-running tasks. - Add should_long_poll() to detect widget-capable clients - Add _progress_long_poll() with 25s timeout and 5s internal checks - Add widget_handoff message telling Claude to stop polling - Add "Ask Claude to show results" prompt in widget on completion - Increase GCPBackendPolicy timeoutSec to 120s for long-poll headroom - stdio (Claude Code) and internal clients keep the 12s loop unchanged Co-Authored-By: Claude Opus 4.6 --- .../chart/templates/gcpbackendpolicy.yaml | 1 + everyrow-mcp/deploy/chart/values.yaml | 1 + everyrow-mcp/src/everyrow_mcp/config.py | 10 + everyrow-mcp/src/everyrow_mcp/templates.py | 2 + everyrow-mcp/src/everyrow_mcp/tool_helpers.py | 29 +- everyrow-mcp/src/everyrow_mcp/tools.py | 81 +++- everyrow-mcp/tests/test_long_poll.py | 348 ++++++++++++++++++ 7 files changed, 470 insertions(+), 2 deletions(-) create mode 100644 everyrow-mcp/tests/test_long_poll.py diff --git a/everyrow-mcp/deploy/chart/templates/gcpbackendpolicy.yaml b/everyrow-mcp/deploy/chart/templates/gcpbackendpolicy.yaml index da3b835a..e1620707 100644 --- a/everyrow-mcp/deploy/chart/templates/gcpbackendpolicy.yaml +++ b/everyrow-mcp/deploy/chart/templates/gcpbackendpolicy.yaml @@ -8,6 +8,7 @@ spec: sessionAffinity: type: GENERATED_COOKIE cookieTtlSec: {{ .Values.sessionAffinity.gcpBackendPolicy.cookieTtlSec }} + timeoutSec: {{ .Values.sessionAffinity.gcpBackendPolicy.timeoutSec | default 30 }} targetRef: group: "" kind: Service diff --git 
a/everyrow-mcp/deploy/chart/values.yaml b/everyrow-mcp/deploy/chart/values.yaml index edc25a46..78895ed8 100644 --- a/everyrow-mcp/deploy/chart/values.yaml +++ b/everyrow-mcp/deploy/chart/values.yaml @@ -42,6 +42,7 @@ sessionAffinity: gcpBackendPolicy: enabled: true cookieTtlSec: 3600 + timeoutSec: 120 # Service-level ClientIP affinity: kube-proxy layer fallback. clientIP: enabled: true diff --git a/everyrow-mcp/src/everyrow_mcp/config.py b/everyrow-mcp/src/everyrow_mcp/config.py index 7d6dc0e0..1223b70e 100644 --- a/everyrow-mcp/src/everyrow_mcp/config.py +++ b/everyrow-mcp/src/everyrow_mcp/config.py @@ -92,6 +92,16 @@ class Settings(BaseSettings): default=100, description="If total rows <= this value, skip asking the user for page_size and load all rows directly.", ) + long_poll_timeout: int = Field( + default=25, + description="Max seconds to block in a single everyrow_progress call for widget-capable clients. " + "Must stay under both GKE backend timeout and MCP client tool timeout. " + "Set 0 to disable.", + ) + long_poll_interval: int = Field( + default=5, + description="Seconds between internal status checks during long-poll.", + ) # Upload settings (HTTP mode only) upload_secret: str = Field( diff --git a/everyrow-mcp/src/everyrow_mcp/templates.py b/everyrow-mcp/src/everyrow_mcp/templates.py index 50935e09..f26360b6 100644 --- a/everyrow-mcp/src/everyrow_mcp/templates.py +++ b/everyrow-mcp/src/everyrow_mcp/templates.py @@ -809,6 +809,7 @@ .l-fail::before{background:#e53935!important} .status-done{color:#4caf50;font-weight:600} .status-fail{color:#e53935;font-weight:600} +.prompt{margin-top:6px;font-size:12px;color:#666;font-style:italic} .eta{color:#888;font-size:11px} @keyframes flash{0%,100%{background:transparent}50%{background:rgba(76,175,80,.15)}} .flash{animation:flash 1s ease 3} @@ -852,6 +853,7 @@ h+=`${esc(d.status)}`; h+=`${comp}/${tot}${fail?` (${fail} failed)`:""}`; if(elapsed)h+=`${fmtTime(elapsed)}`; + if(d.status==="completed")h+=`
<div class="prompt">Ask Claude to show results</div>
`; }else{ h+=`${comp}/${tot}`; const eta=comp>0&&elapsed>0?Math.round((tot-comp)/(comp/elapsed)):0; diff --git a/everyrow-mcp/src/everyrow_mcp/tool_helpers.py b/everyrow-mcp/src/everyrow_mcp/tool_helpers.py index 0a163a8c..0d387c41 100644 --- a/everyrow-mcp/src/everyrow_mcp/tool_helpers.py +++ b/everyrow-mcp/src/everyrow_mcp/tool_helpers.py @@ -192,6 +192,23 @@ def is_internal_client() -> bool: return "everyrow" in get_user_agent().lower() +def should_long_poll(ctx: EveryRowContext) -> bool: + """Return True if this client should use a single long-poll instead of the 12s loop. + + Widget-capable clients (claude.ai, Claude Desktop) get a single long-poll + followed by widget handoff, conserving their per-turn tool call quota. + stdio clients (Claude Code) and internal clients (everyrow-cc) keep the + 12s short-poll loop. + """ + if settings.is_stdio: + return False + if settings.long_poll_timeout <= 0: + return False + if is_internal_client(): + return False + return client_supports_widgets(ctx) + + def _submission_text( label: str, session_url: str, task_id: str, session_id: str = "" ) -> str: @@ -430,7 +447,7 @@ def write_file(self, task_id: str) -> None: started_at=self.started_at, ) - def progress_message(self, task_id: str) -> str: + def progress_message(self, task_id: str, *, widget_handoff: bool = False) -> str: if self.is_terminal: if self.error: return f"Task {self.status.value}: {self.error}" @@ -456,6 +473,16 @@ def progress_message(self, task_id: str) -> str: return f"{completed_msg}\n{next_call}" return f"Task {self.status.value}. Report the error to the user." + # ── Non-terminal: still running ────────────────────────────── + if widget_handoff: + # Widget-capable client: tell Claude to stop polling. + # The widget already shows live progress via REST polling. 
+ fail_part = f", {self.failed} failed" if self.failed else "" + return dedent(f"""\ + Running: {self.completed}/{self.total} complete, {self.running} running{fail_part} ({self.elapsed_s}s elapsed). + Progress is live in the widget above. **Do NOT call everyrow_progress again.** + When the user is ready to review results, call everyrow_results(task_id='{task_id}').""") + if self.is_screen: return dedent(f"""\ Screen running ({self.elapsed_s}s elapsed). diff --git a/everyrow-mcp/src/everyrow_mcp/tools.py b/everyrow-mcp/src/everyrow_mcp/tools.py index f0dfb2b7..902f9e5e 100644 --- a/everyrow-mcp/src/everyrow_mcp/tools.py +++ b/everyrow-mcp/src/everyrow_mcp/tools.py @@ -69,6 +69,7 @@ create_tool_response, is_internal_client, log_client_info, + should_long_poll, write_initial_task_state, ) from everyrow_mcp.utils import fetch_csv_from_url, is_url, save_result_to_csv @@ -1006,7 +1007,6 @@ async def everyrow_progress( Do not add commentary between progress calls, just call again immediately. """ logger.debug("everyrow_progress: task_id=%s", params.task_id) - client = _get_client(ctx) task_id = params.task_id # ── Cross-user access check ────────────────────────────────── @@ -1022,6 +1022,15 @@ async def everyrow_progress( ) ] + if should_long_poll(ctx): + return await _progress_long_poll(ctx, task_id) + return await _progress_short_poll(ctx, task_id) + + +async def _progress_short_poll(ctx: EveryRowContext, task_id: str) -> list[TextContent]: + """Original 12s short-poll: sleep once, check status, return.""" + client = _get_client(ctx) + # Block server-side before polling — controls the cadence await asyncio.sleep(redis_store.PROGRESS_POLL_DELAY) @@ -1052,6 +1061,76 @@ async def everyrow_progress( return [TextContent(type="text", text=ts.progress_message(task_id))] +async def _progress_long_poll(ctx: EveryRowContext, task_id: str) -> list[TextContent]: + """Long-poll for widget-capable clients: block up to ``long_poll_timeout`` + seconds, checking status every 
``long_poll_interval`` seconds. + + If the task completes within the timeout, return the normal completion + message so Claude can immediately call everyrow_results. + + If the task is still running, return a widget-handoff message telling + Claude to stop polling — the widget already shows live progress. + """ + client = _get_client(ctx) + timeout = settings.long_poll_timeout + interval = settings.long_poll_interval + elapsed = 0 + ts: TaskState | None = None + + while elapsed < timeout: + await asyncio.sleep(interval) + elapsed += interval + + try: + status_response = handle_response( + await get_task_status_tasks_task_id_status_get.asyncio( + task_id=UUID(task_id), + client=client, + ) + ) + except Exception: + logger.debug("Long-poll status check failed for %s, retrying", task_id) + continue # retry next iteration + + ts = TaskState(status_response) + + # Emit MCP progress notification (no-op if no progressToken) + try: + await ctx.report_progress( + progress=float(ts.completed), + total=float(ts.total or 1), + ) + except Exception: + pass # progress notifications are best-effort + + if ts.is_terminal: + break + + if ts is None: + return [ + TextContent( + type="text", + text=f"Unable to check status for task {task_id}. 
" + f"Call everyrow_progress(task_id='{task_id}') to retry.", + ) + ] + + if ts.is_terminal: + logger.info( + "everyrow_progress: task_id=%s status=%s (long-poll)", + task_id, + ts.status.value, + ) + + # Terminal → normal completion message; non-terminal → widget handoff + return [ + TextContent( + type="text", + text=ts.progress_message(task_id, widget_handoff=not ts.is_terminal), + ) + ] + + async def everyrow_results_stdio( params: StdioResultsInput, ctx: EveryRowContext ) -> list[TextContent]: diff --git a/everyrow-mcp/tests/test_long_poll.py b/everyrow-mcp/tests/test_long_poll.py new file mode 100644 index 00000000..55f2eff4 --- /dev/null +++ b/everyrow-mcp/tests/test_long_poll.py @@ -0,0 +1,348 @@ +"""Tests for long-poll progress and widget handoff.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from everyrow.generated.models.public_task_type import PublicTaskType +from everyrow.generated.models.task_status import TaskStatus +from everyrow.generated.types import UNSET + +from everyrow_mcp.tool_helpers import TaskState, should_long_poll +from everyrow_mcp.tools import _progress_long_poll +from tests.conftest import override_settings + + +def _make_status_response( + *, + status: TaskStatus = TaskStatus.RUNNING, + task_type: PublicTaskType = PublicTaskType.AGENT, + completed: int = 2, + running: int = 3, + failed: int = 0, + total: int = 10, + artifact_id=UNSET, + error=UNSET, + created_at=None, + updated_at=None, +): + resp = MagicMock() + resp.status = status + resp.task_type = task_type + resp.artifact_id = artifact_id + resp.session_id = None + resp.error = error + resp.created_at = created_at + resp.updated_at = updated_at + + p = MagicMock() + p.completed = completed + p.running = running + p.failed = failed + p.total = total + resp.progress = p + + return resp + + +def _make_ctx(*, widget_capable: bool = True) -> MagicMock: + """Build a mock EveryRowContext.""" 
+ ctx = MagicMock() + ctx.request_context.lifespan_context.client_factory = MagicMock() + + # Set up client_params for widget detection + if widget_capable: + cp = MagicMock() + cp.clientInfo.name = "Anthropic/ClaudeAI" + cp.clientInfo.version = "1.0.0" + cp.capabilities = MagicMock() + cp.capabilities.experimental = {} + ctx.session.client_params = cp + else: + cp = MagicMock() + cp.clientInfo.name = "claude-code" + cp.clientInfo.version = "2.0" + cp.capabilities = MagicMock() + cp.capabilities.experimental = {} + ctx.session.client_params = cp + + return ctx + + +def _mock_status_factory(responses: list, call_counter: dict): + """Create an async mock status function that returns responses in order.""" + + async def _mock_status(*_args, **_kwargs): + resp = responses[min(call_counter["n"], len(responses) - 1)] + call_counter["n"] += 1 + return resp + + return _mock_status + + +class TestShouldLongPoll: + """Tests for should_long_poll() routing logic.""" + + def test_false_for_stdio(self): + ctx = _make_ctx(widget_capable=True) + with override_settings(transport="stdio"): + assert should_long_poll(ctx) is False + + def test_false_when_timeout_zero(self): + ctx = _make_ctx(widget_capable=True) + with override_settings(transport="streamable-http", long_poll_timeout=0): + assert should_long_poll(ctx) is False + + def test_false_for_internal_client(self): + ctx = _make_ctx(widget_capable=True) + with ( + override_settings(transport="streamable-http"), + patch("everyrow_mcp.tool_helpers.is_internal_client", return_value=True), + ): + assert should_long_poll(ctx) is False + + def test_true_for_widget_capable_http_client(self): + ctx = _make_ctx(widget_capable=True) + with ( + override_settings(transport="streamable-http"), + patch("everyrow_mcp.tool_helpers.is_internal_client", return_value=False), + ): + assert should_long_poll(ctx) is True + + def test_false_for_non_widget_client(self): + ctx = _make_ctx(widget_capable=False) + with ( + 
override_settings(transport="streamable-http"), + patch("everyrow_mcp.tool_helpers.is_internal_client", return_value=False), + ): + assert should_long_poll(ctx) is False + + +class TestProgressMessageWidgetHandoff: + """Tests for progress_message(widget_handoff=True).""" + + def test_widget_handoff_message_stops_polling(self): + resp = _make_status_response( + status=TaskStatus.RUNNING, + completed=3, + running=2, + total=10, + ) + ts = TaskState(resp) + msg = ts.progress_message("task-abc", widget_handoff=True) + assert "Do NOT call everyrow_progress again" in msg + assert "everyrow_results" in msg + assert "task-abc" in msg + + def test_widget_handoff_includes_progress_stats(self): + resp = _make_status_response( + status=TaskStatus.RUNNING, + completed=5, + running=3, + failed=1, + total=10, + ) + ts = TaskState(resp) + msg = ts.progress_message("task-xyz", widget_handoff=True) + assert "5/10 complete" in msg + assert "3 running" in msg + assert "1 failed" in msg + + def test_widget_handoff_false_keeps_normal_polling(self): + resp = _make_status_response( + status=TaskStatus.RUNNING, + completed=3, + running=2, + total=10, + ) + ts = TaskState(resp) + msg = ts.progress_message("task-normal", widget_handoff=False) + assert "Immediately call everyrow_progress" in msg + assert "Do NOT" not in msg + + def test_widget_handoff_ignored_when_terminal(self): + """widget_handoff should have no effect on terminal messages.""" + resp = _make_status_response( + status=TaskStatus.COMPLETED, + completed=10, + total=10, + ) + ts = TaskState(resp) + # Even with widget_handoff=True, terminal messages are unchanged + msg_with = ts.progress_message("task-done", widget_handoff=True) + msg_without = ts.progress_message("task-done", widget_handoff=False) + assert msg_with == msg_without + + +class TestProgressLongPoll: + """Tests for _progress_long_poll() in tools.py.""" + + @pytest.mark.asyncio + async def test_completes_within_timeout(self): + """Task completes on the 2nd poll -> 
returns completion message.""" + task_id = str(uuid4()) + counter: dict = {"n": 0} + mock_status = _mock_status_factory( + [ + _make_status_response(status=TaskStatus.RUNNING, completed=3, total=10), + _make_status_response( + status=TaskStatus.COMPLETED, completed=10, total=10 + ), + ], + counter, + ) + + ctx = _make_ctx(widget_capable=True) + mock_client = MagicMock() + ctx.request_context.lifespan_context.client_factory = lambda: mock_client + ctx.report_progress = AsyncMock() + + with ( + override_settings( + transport="streamable-http", long_poll_timeout=15, long_poll_interval=1 + ), + patch( + "everyrow_mcp.tools.get_task_status_tasks_task_id_status_get.asyncio", + side_effect=mock_status, + ), + patch("everyrow_mcp.tools.handle_response", side_effect=lambda x: x), + ): + result = await _progress_long_poll(ctx, task_id) + + text = result[0].text + assert "Completed" in text + assert "Do NOT" not in text + + @pytest.mark.asyncio + async def test_times_out_with_widget_handoff(self): + """Task still running at timeout -> returns widget handoff message.""" + task_id = str(uuid4()) + running_resp = _make_status_response( + status=TaskStatus.RUNNING, completed=3, running=2, total=10 + ) + + ctx = _make_ctx(widget_capable=True) + mock_client = MagicMock() + ctx.request_context.lifespan_context.client_factory = lambda: mock_client + ctx.report_progress = AsyncMock() + + async def _mock_status(*_args, **_kwargs): + return running_resp + + with ( + override_settings( + transport="streamable-http", long_poll_timeout=3, long_poll_interval=1 + ), + patch( + "everyrow_mcp.tools.get_task_status_tasks_task_id_status_get.asyncio", + side_effect=_mock_status, + ), + patch("everyrow_mcp.tools.handle_response", side_effect=lambda x: x), + ): + result = await _progress_long_poll(ctx, task_id) + + text = result[0].text + assert "Do NOT call everyrow_progress again" in text + assert "everyrow_results" in text + + @pytest.mark.asyncio + async def 
test_resilient_to_intermittent_errors(self): + """Intermittent API errors don't crash -- retries next iteration.""" + task_id = str(uuid4()) + call_count = 0 + + async def _mock_status(*_args, **_kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ConnectionError("transient") + return _make_status_response( + status=TaskStatus.COMPLETED, completed=5, total=5 + ) + + ctx = _make_ctx(widget_capable=True) + mock_client = MagicMock() + ctx.request_context.lifespan_context.client_factory = lambda: mock_client + ctx.report_progress = AsyncMock() + + with ( + override_settings( + transport="streamable-http", long_poll_timeout=10, long_poll_interval=1 + ), + patch( + "everyrow_mcp.tools.get_task_status_tasks_task_id_status_get.asyncio", + side_effect=_mock_status, + ), + patch("everyrow_mcp.tools.handle_response", side_effect=lambda x: x), + ): + result = await _progress_long_poll(ctx, task_id) + + text = result[0].text + assert "Completed" in text + assert call_count == 2 # 1 error + 1 success + + @pytest.mark.asyncio + async def test_all_polls_fail_returns_retry_message(self): + """All status checks fail -> returns retry message.""" + task_id = str(uuid4()) + + async def _mock_status(*_args, **_kwargs): + raise ConnectionError("down") + + ctx = _make_ctx(widget_capable=True) + mock_client = MagicMock() + ctx.request_context.lifespan_context.client_factory = lambda: mock_client + ctx.report_progress = AsyncMock() + + with ( + override_settings( + transport="streamable-http", long_poll_timeout=3, long_poll_interval=1 + ), + patch( + "everyrow_mcp.tools.get_task_status_tasks_task_id_status_get.asyncio", + side_effect=_mock_status, + ), + patch("everyrow_mcp.tools.handle_response", side_effect=lambda x: x), + ): + result = await _progress_long_poll(ctx, task_id) + + text = result[0].text + assert "Unable to check status" in text + assert task_id in text + + @pytest.mark.asyncio + async def test_reports_progress_on_each_successful_poll(self): + 
"""ctx.report_progress() called on each successful status check.""" + task_id = str(uuid4()) + counter: dict = {"n": 0} + mock_status = _mock_status_factory( + [ + _make_status_response(status=TaskStatus.RUNNING, completed=2, total=10), + _make_status_response(status=TaskStatus.RUNNING, completed=5, total=10), + _make_status_response( + status=TaskStatus.COMPLETED, completed=10, total=10 + ), + ], + counter, + ) + + ctx = _make_ctx(widget_capable=True) + mock_client = MagicMock() + ctx.request_context.lifespan_context.client_factory = lambda: mock_client + ctx.report_progress = AsyncMock() + + with ( + override_settings( + transport="streamable-http", long_poll_timeout=15, long_poll_interval=1 + ), + patch( + "everyrow_mcp.tools.get_task_status_tasks_task_id_status_get.asyncio", + side_effect=mock_status, + ), + patch("everyrow_mcp.tools.handle_response", side_effect=lambda x: x), + ): + await _progress_long_poll(ctx, task_id) + + assert ctx.report_progress.call_count == 3