diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 32eca00ff..fdb2c00ea 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -622,6 +622,13 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() + if isinstance(node.executor, Agent): + self._interrupt_state.context[node.node_id] = { + "activated": node.executor._interrupt_state.activated, + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } return MultiAgentNodeInterruptEvent(node.node_id, interrupts) @@ -920,16 +927,6 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - if agent_response.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from an agent node" - ) - # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) usage = getattr( @@ -940,18 +937,24 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node_result = NodeResult( result=agent_response, execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, + status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, + interrupts=agent_response.interrupts or [], ) else: raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - # Mark as completed - node.execution_status = Status.COMPLETED node.result = node_result node.execution_time = node_result.execution_time + + if node_result.status == Status.INTERRUPTED: + yield self._activate_interrupt(node, node_result.interrupts) + return + + # Mark as completed + node.execution_status = Status.COMPLETED self.state.completed_nodes.add(node) self.state.results[node.node_id] = node_result self.state.execution_order.append(node) @@ -1018,6 +1021,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """Build input text for a node based on dependency outputs. + If resuming from an interrupt, return user responses. + Example formatted output: ``` Original Task: Analyze the quarterly sales data and create a summary report @@ -1032,6 +1037,21 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: - Agent: Data validation complete. All records verified, no anomalies detected. ``` """ + if self._interrupt_state.activated: + context = self._interrupt_state.context + if node.node_id in context and context[node.node_id]["activated"]: + agent_context = context[node.node_id] + agent = cast(Agent, node.executor) + agent.messages = agent_context["messages"] + agent.state = AgentState(agent_context["state"]) + agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"]) + + responses = context["responses"] + interrupts = agent._interrupt_state.interrupts + return [ + response for response in responses if response["interruptResponse"]["interruptId"] in interrupts + ] + # Get satisfied dependencies dependency_results = {} for edge in self.edges: diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index cd750865e..9e319d10f 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -23,6 +23,9 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent.state = AgentState() + agent.messages = [] + agent._interrupt_state = _InterruptState() if metrics is None: metrics = Mock( @@ -2153,3 +2156,98 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): assert tru_message == exp_message assert multiagent_result.execution_time >= first_execution_time + + +def test_graph_interrupt_on_agent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + agent = create_mock_agent("test_agent", "Task completed") + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + graph._interrupt_state.context["test_agent"] = { + "activated": True, + "interrupt_state": { + "activated": True, + "context": {}, + "interrupts": {interrupt.id: interrupt.to_dict()}, + }, + "messages": [], + "state": {}, + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + agent.stream_async.assert_called_once_with(responses, invocation_state={}) diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py index 36fcfef27..1a6ad87c6 100644 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_agent.py @@ -5,28 +5,83 @@ from strands import Agent, tool from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types.tools import ToolContext +@pytest.fixture +def day_tool(): + @tool(name="day_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("day_interrupt", reason="need day") + return response + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) def func(tool_context: ToolContext) -> str: - response = tool_context.interrupt("test_interrupt", reason="need weather") + response = tool_context.interrupt("weather_interrupt", reason="need weather") return response return func @pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + +@pytest.fixture +def swarm(weather_agent): return Swarm([weather_agent]) +@pytest.fixture +def graph(info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + + return builder.build() + + def test_swarm_interrupt_agent(swarm): multiagent_result = swarm("What is the weather?") @@ -38,7 +93,7 @@ def test_swarm_interrupt_agent(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need weather", ), ] @@ -65,3 +120,67 @@ def test_swarm_interrupt_agent(swarm): weather_message = json.dumps(weather_result.result.message).lower() assert "sunny" in weather_message + + +def test_graph_interrupt_agent(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need day", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "monday", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "sunny", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index 2ccff2c12..96b9844bf 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,7 +4,6 @@ import pytest from strands import Agent, tool -from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status @@ -12,21 +11,6 @@ from strands.types.tools import ToolContext -@pytest.fixture -def interrupt_hook(): - class Hook(HookProvider): - def register_hooks(self, registry): - registry.add_callback(BeforeNodeCallEvent, self.interrupt) - - def interrupt(self, event): - if event.node_id == "time": - response = event.interrupt("test_interrupt", reason="need approval") - if response != "APPROVE": - event.cancel_node = "node rejected" - - return Hook() - - @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) @@ -37,15 +21,6 @@ def func(tool_context: ToolContext) -> str: return func -@pytest.fixture -def time_tool(): - @tool(name="time_tool") - def func(): - return "12:01" - - return func - - def test_swarm_interrupt_session(weather_tool, tmpdir): weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") @@ -96,20 +71,19 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): assert "sunny" in summarizer_message -def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): - time_agent = Agent(name="time", tools=[time_tool]) +def test_graph_interrupt_session(weather_tool, tmpdir): + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) builder = GraphBuilder() - builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") builder.add_node(summarizer_agent, "summarizer") - builder.add_edge("time", "summarizer") - builder.set_hook_providers([interrupt_hook]) + builder.add_edge("weather", "summarizer") builder.set_session_manager(session_manager) graph = builder.build() - multiagent_result = graph("Can you check the time and then summarize the results?") + multiagent_result = graph("Can you check the weather and then summarize the results?") tru_result_status = multiagent_result.status exp_result_status = Status.INTERRUPTED @@ -124,22 +98,21 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): Interrupt( id=ANY, name="test_interrupt", - reason="need approval", + reason="need weather", ), ] assert tru_interrupts == exp_interrupts interrupt = multiagent_result.interrupts[0] - time_agent = Agent(name="time", tools=[time_tool]) + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) builder = GraphBuilder() - builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") builder.add_node(summarizer_agent, "summarizer") - builder.add_edge("time", "summarizer") - builder.set_hook_providers([interrupt_hook]) + builder.add_edge("weather", "summarizer") builder.set_session_manager(session_manager) graph = builder.build() @@ -147,7 +120,7 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): { "interruptResponse": { "interruptId": interrupt.id, - "response": "APPROVE", + "response": "sunny", }, }, ] @@ -163,4 +136,4 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): assert len(multiagent_result.results) == 2 summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower() - assert "12:01" in summarizer_message + assert "sunny" in summarizer_message