46 changes: 33 additions & 13 deletions src/strands/multiagent/graph.py
@@ -622,6 +622,13 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()
if isinstance(node.executor, Agent):
Member Author

Agents are not allowed to use session managers in a graph execution. Consequently, we need to store some agent state in the graph interrupt state so that the interrupt workflow persists between shutdowns. Note that this is the same behavior we have in Swarm (src).
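To illustrate the idea of parking agent state in a graph-level context, here is a minimal sketch. It is not the strands implementation: `ToyAgent`, `save_agent_context`, and `restore_agent` are hypothetical names, and the JSON round trip is only an assumed stand-in for whatever persistence the graph uses across a shutdown.

```python
import json
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyAgent:
    """Hypothetical stand-in for an agent executor; not the strands Agent class."""

    messages: list[dict[str, Any]] = field(default_factory=list)
    state: dict[str, Any] = field(default_factory=dict)
    interrupt_activated: bool = False


def save_agent_context(context: dict[str, Any], node_id: str, agent: ToyAgent) -> None:
    """Record the fields needed to resume this node in the graph-level context."""
    context[node_id] = {
        "activated": agent.interrupt_activated,
        "state": dict(agent.state),
        "messages": list(agent.messages),
    }


def restore_agent(context: dict[str, Any], node_id: str) -> ToyAgent:
    """Rebuild an agent from the graph-level context after a restart."""
    saved = context[node_id]
    return ToyAgent(
        messages=saved["messages"],
        state=saved["state"],
        interrupt_activated=saved["activated"],
    )


# Simulate a shutdown: the graph context is serialized, then loaded again.
agent = ToyAgent(
    messages=[{"role": "user", "content": "What is the weather?"}],
    state={"city": "Seattle"},
    interrupt_activated=True,
)
graph_context: dict[str, Any] = {}
save_agent_context(graph_context, "weather", agent)
reloaded = json.loads(json.dumps(graph_context))
restored = restore_agent(reloaded, "weather")
```

The point of the sketch is only that the graph, not the agent, owns the persisted copy, keyed by node id.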

Member Author

For now, we check that the executor is an Agent instance because MultiAgentBase executors may not have the same context. We will figure out how to handle that case in a separate (and final) PR for graph interrupt support.

Member

Can this be AgentBase now?

Member Author

Not yet, as the Agent class does not currently derive from AgentBase (src).

Member

Gotcha. Is anything blocking that, @mkmeral?

self._interrupt_state.context[node.node_id] = {
Member

Can you explain how interrupts work in graph at a higher level?

I understand how this works for an interrupt that is raised in one Agent node by that Agent node.

But what happens if two nodes are executing and one of them raises an interrupt?

Member Author

@pgrayy Jan 21, 2026

If two nodes are executing in parallel and one interrupts, the other node is allowed to finish. Once it is done, the call stack returns to _execute_graph. There, before preparing the next batch of nodes to execute, we check whether an interrupt has been activated (whenever any node interrupts, we immediately set the graph state to INTERRUPTED). If there is an interrupt, we store the completed node in the interrupt state context. Context about the interrupted node has already been stored in state as part of the _execute_node call.

Upon resuming, we unpack the interrupted node and the one already completed from interrupt state (src). Within _execute_nodes_parallel, we filter the batch down to just the interrupted node (src). This means the completed node does not get executed again. However, we have the reference so that we can identify its dependent nodes to execute next after we finish resuming the interrupted node.

For the interrupted node, we pass the user interrupt responses into node.executor.stream_async (src). The actual task that the agent node is meant to execute is already stored in the agent node message history.

From here, things proceed as normal.
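The resume path described above can be sketched roughly as follows. This is an illustration under assumptions, not the code in graph.py: `select_resume_batch` and `responses_for_node` are hypothetical helper names, and the interrupt ids are made up. Only the shape of the `interruptResponse` dictionaries is taken from the tests in this PR.

```python
from typing import Any


def select_resume_batch(batch: list[str], interrupted: set[str], resuming: bool) -> list[str]:
    """On resume, keep only the nodes that interrupted; otherwise run the whole batch."""
    if not resuming:
        return list(batch)
    return [node_id for node_id in batch if node_id in interrupted]


def responses_for_node(responses: list[dict[str, Any]], interrupt_ids: set[str]) -> list[dict[str, Any]]:
    """Keep only the user responses that answer this node's own interrupts."""
    return [r for r in responses if r["interruptResponse"]["interruptId"] in interrupt_ids]


# Hypothetical ids; in the SDK they are generated per interrupt.
responses = [
    {"interruptResponse": {"interruptId": "id-day", "response": "monday"}},
    {"interruptResponse": {"interruptId": "id-weather", "response": "sunny"}},
]

# "time" already completed before the pause, so it is filtered out on resume.
batch = select_resume_batch(["day", "time", "weather"], {"day", "weather"}, resuming=True)
weather_input = responses_for_node(responses, {"id-weather"})
```

Each interrupted node then receives only its own responses as input, while the completed node is kept by reference so its dependents can be found afterwards.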

Member Author

For additional context, we explicitly test parallel node interrupts in the integ tests presented further down.

Member

@dbschmigelski Jan 22, 2026

OK, that makes sense, but I do think this will make #1530 more important. I'm sure in a scenario like

    A
  /   \
 B    C
   \ /
    D

if C interrupts, people will be similarly surprised that D still executed.

Member Author

D would not execute in this case. A would be the entry point and execute first. On the next cycle, B and C would be picked up as the ready nodes and would execute concurrently. If B completes and C interrupts, we pause here and wait for the user to respond. Once the user responds, we finish executing C (and only C). Afterwards, we find the next batch of ready nodes. It is at this time that D executes.
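A toy scheduler makes this ordering concrete. It is a simplified sketch, not the Graph class: `ToyGraph` is a hypothetical name, nodes in a batch run sequentially rather than concurrently, and an "interrupt" is just membership in a set passed to `run`.

```python
from collections import defaultdict


class ToyGraph:
    """Hypothetical batch scheduler that pauses when any node in a batch interrupts."""

    def __init__(self, edges: list[tuple[str, str]], entry: str) -> None:
        self.deps: dict[str, set[str]] = defaultdict(set)
        self.nodes: set[str] = {entry}
        for src, dst in edges:
            self.deps[dst].add(src)
            self.nodes.update((src, dst))
        self.completed: set[str] = set()
        self.interrupted: set[str] = set()
        self.order: list[str] = []  # nodes in the order they completed

    def _ready(self) -> list[str]:
        # A node is ready once all of its dependencies have completed.
        return sorted(
            n for n in self.nodes if n not in self.completed and self.deps[n] <= self.completed
        )

    def run(self, interrupting: frozenset[str] = frozenset()) -> str:
        # On resume, the first batch is only the previously interrupted nodes.
        batch = sorted(self.interrupted) if self.interrupted else self._ready()
        self.interrupted = set()
        while batch:
            for node in batch:
                if node in interrupting:
                    self.interrupted.add(node)
                else:
                    self.completed.add(node)
                    self.order.append(node)
            if self.interrupted:
                return "INTERRUPTED"  # pause before scheduling the next batch
            batch = self._ready()
        return "COMPLETED"


graph = ToyGraph([("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")], entry="A")
first_status = graph.run(interrupting=frozenset({"C"}))
order_at_pause = list(graph.order)
second_status = graph.run()
```

In this sketch D is never scheduled while C is pending, because D only becomes ready once both B and C are in the completed set.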

Member

Anything that prevents us from switching this to snapshots later?

"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}

return MultiAgentNodeInterruptEvent(node.node_id, interrupts)

@@ -920,16 +927,6 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if agent_response is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

-if agent_response.stop_reason == "interrupt":
-node.executor.messages.pop() # remove interrupted tool use message
-node.executor._interrupt_state.deactivate()
-
-raise NotImplementedError(
-f"node_id=<{node.node_id}>, "
-"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
-"| user raised interrupt from an agent node"
-)

# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
usage = getattr(
@@ -940,18 +937,24 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
node_result = NodeResult(
result=agent_response,
execution_time=round((time.time() - start_time) * 1000),
-status=Status.COMPLETED,
+status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=1,
interrupts=agent_response.interrupts or [],
)
else:
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")

-# Mark as completed
-node.execution_status = Status.COMPLETED
node.result = node_result
node.execution_time = node_result.execution_time

+if node_result.status == Status.INTERRUPTED:
+yield self._activate_interrupt(node, node_result.interrupts)
+return

+# Mark as completed
+node.execution_status = Status.COMPLETED
self.state.completed_nodes.add(node)
self.state.results[node.node_id] = node_result
self.state.execution_order.append(node)
@@ -1018,6 +1021,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
"""Build input text for a node based on dependency outputs.

If resuming from an interrupt, return user responses.

Example formatted output:
```
Original Task: Analyze the quarterly sales data and create a summary report
@@ -1032,6 +1037,21 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
- Agent: Data validation complete. All records verified, no anomalies detected.
```
"""
if self._interrupt_state.activated:
Member Author

Here is where we restore the agent node state upon resuming. We extract it from the graph interrupt state.

context = self._interrupt_state.context
if node.node_id in context and context[node.node_id]["activated"]:
agent_context = context[node.node_id]
agent = cast(Agent, node.executor)
agent.messages = agent_context["messages"]
agent.state = AgentState(agent_context["state"])
agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"])

responses = context["responses"]
interrupts = agent._interrupt_state.interrupts
return [
response for response in responses if response["interruptResponse"]["interruptId"] in interrupts
]

# Get satisfied dependencies
dependency_results = {}
for edge in self.edges:
98 changes: 98 additions & 0 deletions tests/strands/multiagent/test_graph.py
@@ -23,6 +23,9 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
agent.id = agent_id or f"{name}_id"
agent._session_manager = None
agent.hooks = HookRegistry()
agent.state = AgentState()
agent.messages = []
agent._interrupt_state = _InterruptState()

if metrics is None:
metrics = Mock(
@@ -2153,3 +2156,98 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
assert tru_message == exp_message

assert multiagent_result.execution_time >= first_execution_time


def test_graph_interrupt_on_agent(agenerator):
exp_interrupts = [
Interrupt(
id="test_id",
name="test_name",
reason="test_reason",
)
]

agent = create_mock_agent("test_agent", "Task completed")
agent.stream_async = Mock()
agent.stream_async.return_value = agenerator(
[
{
"result": AgentResult(
message={},
stop_reason="interrupt",
state={},
metrics=None,
interrupts=exp_interrupts,
),
},
],
)

builder = GraphBuilder()
builder.add_node(agent, "test_agent")
graph = builder.build()

multiagent_result = graph("Test task")

tru_result_status = multiagent_result.status
exp_result_status = Status.INTERRUPTED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.INTERRUPTED
assert tru_state_status == exp_state_status

tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes]
exp_node_ids = ["test_agent"]
assert tru_node_ids == exp_node_ids

tru_interrupts = multiagent_result.interrupts
assert tru_interrupts == exp_interrupts

interrupt = multiagent_result.interrupts[0]

agent.stream_async = Mock()
agent.stream_async.return_value = agenerator(
[
{
"result": AgentResult(
message={},
stop_reason="end_turn",
state={},
metrics=None,
),
},
],
)
graph._interrupt_state.context["test_agent"] = {
"activated": True,
"interrupt_state": {
"activated": True,
"context": {},
"interrupts": {interrupt.id: interrupt.to_dict()},
},
"messages": [],
"state": {},
}

responses = [
{
"interruptResponse": {
"interruptId": interrupt.id,
"response": "test_response",
},
},
]
multiagent_result = graph(responses)

tru_result_status = multiagent_result.status
exp_result_status = Status.COMPLETED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.COMPLETED
assert tru_state_status == exp_state_status

assert len(multiagent_result.results) == 1

agent.stream_async.assert_called_once_with(responses, invocation_state={})
129 changes: 124 additions & 5 deletions tests_integ/interrupts/multiagent/test_agent.py
@@ -5,28 +5,83 @@

from strands import Agent, tool
from strands.interrupt import Interrupt
-from strands.multiagent import Swarm
+from strands.multiagent import GraphBuilder, Swarm
from strands.multiagent.base import Status
from strands.types.tools import ToolContext


@pytest.fixture
def day_tool():
@tool(name="day_tool", context=True)
def func(tool_context: ToolContext) -> str:
response = tool_context.interrupt("day_interrupt", reason="need day")
return response

return func


@pytest.fixture
def time_tool():
@tool(name="time_tool")
def func():
return "12:01"

return func


@pytest.fixture
def weather_tool():
@tool(name="weather_tool", context=True)
def func(tool_context: ToolContext) -> str:
-response = tool_context.interrupt("test_interrupt", reason="need weather")
+response = tool_context.interrupt("weather_interrupt", reason="need weather")
return response

return func


@pytest.fixture
-def swarm(weather_tool):
-weather_agent = Agent(name="weather", tools=[weather_tool])
+def info_agent():
+return Agent(name="info")


@pytest.fixture
def day_agent(day_tool):
return Agent(name="day", tools=[day_tool])


@pytest.fixture
def time_agent(time_tool):
return Agent(name="time", tools=[time_tool])


@pytest.fixture
def weather_agent(weather_tool):
return Agent(name="weather", tools=[weather_tool])


@pytest.fixture
def swarm(weather_agent):
return Swarm([weather_agent])


@pytest.fixture
def graph(info_agent, day_agent, time_agent, weather_agent):
builder = GraphBuilder()

builder.add_node(info_agent, "info")
builder.add_node(day_agent, "day")
builder.add_node(time_agent, "time")
builder.add_node(weather_agent, "weather")

builder.add_edge("info", "day")
builder.add_edge("info", "time")
builder.add_edge("info", "weather")

builder.set_entry_point("info")

return builder.build()


def test_swarm_interrupt_agent(swarm):
multiagent_result = swarm("What is the weather?")

@@ -38,7 +93,7 @@ def test_swarm_interrupt_agent(swarm):
exp_interrupts = [
Interrupt(
id=ANY,
-name="test_interrupt",
+name="weather_interrupt",
reason="need weather",
),
]
@@ -65,3 +120,67 @@ def test_swarm_interrupt_agent(swarm):

weather_message = json.dumps(weather_result.result.message).lower()
assert "sunny" in weather_message


def test_graph_interrupt_agent(graph):
multiagent_result = graph("What is the day, time, and weather?")

tru_result_status = multiagent_result.status
exp_result_status = Status.INTERRUPTED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.INTERRUPTED
assert tru_state_status == exp_state_status

tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes])
exp_node_ids = ["day", "weather"]
assert tru_node_ids == exp_node_ids

tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name)
exp_interrupts = [
Interrupt(
id=ANY,
name="day_interrupt",
reason="need day",
),
Interrupt(
id=ANY,
name="weather_interrupt",
reason="need weather",
),
]
assert tru_interrupts == exp_interrupts

responses = [
{
"interruptResponse": {
"interruptId": tru_interrupts[0].id,
"response": "monday",
},
},
{
"interruptResponse": {
"interruptId": tru_interrupts[1].id,
"response": "sunny",
},
},
]
multiagent_result = graph(responses)

tru_result_status = multiagent_result.status
exp_result_status = Status.COMPLETED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.COMPLETED
assert tru_state_status == exp_state_status

assert len(multiagent_result.results) == 4

day_message = json.dumps(multiagent_result.results["day"].result.message).lower()
time_message = json.dumps(multiagent_result.results["time"].result.message).lower()
weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower()
assert "monday" in day_message
assert "12:01" in time_message
assert "sunny" in weather_message