From f636feeab6d1b1315ce89a8c78fd65663e1f1b44 Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Fri, 20 Feb 2026 14:54:58 +0200 Subject: [PATCH 1/4] feat: replace session-based state with checkpoints, add breakpoint system - Replace SqliteSessionStore with checkpoint-based persistence using the Agent Framework's CheckpointStorage protocol (SqliteCheckpointStorage, ScopedCheckpointStorage) - Add executor-level breakpoints that work with all agent types (RawAgent, Agent) by wrapping executor.execute() instead of using middleware - Fix concurrent breakpoint infinite loop by accumulating skip_nodes set across resumes instead of tracking a single skip_node - Fix wildcard breakpoint resolution for ["*"] list format from debug bridge - Fix spurious start executor STARTED events on breakpoint resume - Add HITL interrupt support, resumable storage with KV store, and session persistence via runtime_kv table - Update samples: replace quickstart-agent/multi-agent with quickstart-workflow/hitl-workflow/concurrent - Rewrite tests for checkpoint storage, streaming, and breakpoints (80 tests) Co-Authored-By: Claude Opus 4.6 --- .../uipath-agent-framework/pyproject.toml | 2 +- .../uipath-agent-framework/samples/README.md | 4 +- .../samples/concurrent/.vscode/settings.json | 3 + .../samples/hitl-workflow/README.md | 43 + .../agent_framework.json | 0 .../samples/hitl-workflow/main.py | 83 ++ .../pyproject.toml | 6 +- .../samples/multi-agent/README.md | 46 - .../samples/multi-agent/agent.mermaid | 18 - .../samples/multi-agent/main.py | 79 -- .../samples/quickstart-agent/uipath.json | 14 - .../quickstart-workflow/.vscode/settings.json | 3 + .../README.md | 0 .../agent_framework.json | 0 .../main.py | 6 +- .../pyproject.toml | 6 +- .../uipath.json | 0 .../uipath_agent_framework/chat/__init__.py | 5 + .../src/uipath_agent_framework/chat/hitl.py | 70 ++ .../runtime/__init__.py | 11 + .../runtime/breakpoints.py | 168 ++++ .../uipath_agent_framework/runtime/factory.py | 76 +- 
.../runtime/interrupt.py | 131 +++ .../uipath_agent_framework/runtime/loader.py | 11 +- .../runtime/resumable_storage.py | 506 ++++++++++ .../uipath_agent_framework/runtime/runtime.py | 814 +++++++-------- .../uipath_agent_framework/runtime/schema.py | 230 +---- .../uipath_agent_framework/runtime/storage.py | 116 --- .../tests/test_breakpoints.py | 926 ++++++++++++++++++ .../tests/test_graph.py | 97 -- .../tests/test_schema.py | 93 -- .../tests/test_storage.py | 427 +++++--- .../tests/test_streaming.py | 658 +++++++++---- packages/uipath-agent-framework/uv.lock | 2 +- 34 files changed, 3223 insertions(+), 1431 deletions(-) create mode 100644 packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json create mode 100644 packages/uipath-agent-framework/samples/hitl-workflow/README.md rename packages/uipath-agent-framework/samples/{multi-agent => hitl-workflow}/agent_framework.json (100%) create mode 100644 packages/uipath-agent-framework/samples/hitl-workflow/main.py rename packages/uipath-agent-framework/samples/{quickstart-agent => hitl-workflow}/pyproject.toml (74%) delete mode 100644 packages/uipath-agent-framework/samples/multi-agent/README.md delete mode 100644 packages/uipath-agent-framework/samples/multi-agent/agent.mermaid delete mode 100644 packages/uipath-agent-framework/samples/multi-agent/main.py delete mode 100644 packages/uipath-agent-framework/samples/quickstart-agent/uipath.json create mode 100644 packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json rename packages/uipath-agent-framework/samples/{quickstart-agent => quickstart-workflow}/README.md (100%) rename packages/uipath-agent-framework/samples/{quickstart-agent => quickstart-workflow}/agent_framework.json (100%) rename packages/uipath-agent-framework/samples/{quickstart-agent => quickstart-workflow}/main.py (91%) rename packages/uipath-agent-framework/samples/{multi-agent => quickstart-workflow}/pyproject.toml (82%) rename 
packages/uipath-agent-framework/samples/{multi-agent => quickstart-workflow}/uipath.json (100%) create mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/chat/hitl.py create mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py create mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py create mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py delete mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/runtime/storage.py create mode 100644 packages/uipath-agent-framework/tests/test_breakpoints.py diff --git a/packages/uipath-agent-framework/pyproject.toml b/packages/uipath-agent-framework/pyproject.toml index 2e8b4cb5..03b85298 100644 --- a/packages/uipath-agent-framework/pyproject.toml +++ b/packages/uipath-agent-framework/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-agent-framework" -version = "0.0.3" +version = "0.0.4" description = "Python SDK that enables developers to build and deploy Microsoft Agent Framework agents to the UiPath Cloud Platform" readme = "README.md" requires-python = ">=3.11" diff --git a/packages/uipath-agent-framework/samples/README.md b/packages/uipath-agent-framework/samples/README.md index 0fde9c07..4c8f8b38 100644 --- a/packages/uipath-agent-framework/samples/README.md +++ b/packages/uipath-agent-framework/samples/README.md @@ -6,8 +6,8 @@ Sample agents built with [Agent Framework](https://github.com/microsoft/agent-fr | Sample | Description | |--------|-------------| -| [quickstart-agent](./quickstart-agent/) | Single agent with tool calling: fetches live weather data for any location | -| [multi-agent](./multi-agent/) | Multi-agent coordinator: delegates research and code execution to specialist sub-agents via `as_tool()` | +| [quickstart-workflow](./quickstart-workflow/) | Single workflow agent with tool calling: fetches live weather data for any location | | 
[group-chat](./group-chat/) | Group chat orchestration: researcher, critic, and writer discuss a topic with an orchestrator picking speakers | | [concurrent](./concurrent/) | Concurrent orchestration: sentiment, topic extraction, and summarization agents analyze text in parallel | | [handoff](./handoff/) | Handoff orchestration: customer support agents transfer control to specialists with explicit routing rules | +| [hitl-workflow](./hitl-workflow/) | Human-in-the-loop workflow: customer support with approval-gated billing and refund operations | diff --git a/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json b/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json new file mode 100644 index 00000000..af690fcd --- /dev/null +++ b/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python-envs.pythonProjects": [] +} diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/README.md b/packages/uipath-agent-framework/samples/hitl-workflow/README.md new file mode 100644 index 00000000..67f7e322 --- /dev/null +++ b/packages/uipath-agent-framework/samples/hitl-workflow/README.md @@ -0,0 +1,43 @@ +# HITL Workflow + +A customer support workflow with human-in-the-loop approval for sensitive operations. A triage agent routes requests to billing or returns specialists. Both `transfer_funds` and `issue_refund` tools require human approval before executing. 
+ +## Agent Graph + +```mermaid +flowchart TB + __start__(__start__) + triage(triage) + billing_agent(billing_agent) + returns_agent(returns_agent) + __end__(__end__) + __start__ --> |input|triage + triage --> billing_agent + triage --> returns_agent + billing_agent --> returns_agent + billing_agent --> triage + returns_agent --> billing_agent + returns_agent --> triage + billing_agent --> |output|__end__ + returns_agent --> |output|__end__ +``` + +## Prerequisites + +Authenticate with UiPath to configure your `.env` file: + +```bash +uipath auth +``` + +## Run + +``` +uipath run agent '{"messages": [{"contentParts": [{"data": {"inline": "I need a refund for order #12345"}}], "role": "user"}]}' +``` + +## Debug + +``` +uipath dev web +``` diff --git a/packages/uipath-agent-framework/samples/multi-agent/agent_framework.json b/packages/uipath-agent-framework/samples/hitl-workflow/agent_framework.json similarity index 100% rename from packages/uipath-agent-framework/samples/multi-agent/agent_framework.json rename to packages/uipath-agent-framework/samples/hitl-workflow/agent_framework.json diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/main.py b/packages/uipath-agent-framework/samples/hitl-workflow/main.py new file mode 100644 index 00000000..cd22c4fb --- /dev/null +++ b/packages/uipath-agent-framework/samples/hitl-workflow/main.py @@ -0,0 +1,83 @@ +from agent_framework.orchestrations import HandoffBuilder + +from uipath_agent_framework.chat import UiPathOpenAIChatClient, requires_approval + + +@requires_approval +def transfer_funds(from_account: str, to_account: str, amount: float) -> str: + """Transfer funds between accounts. Requires human approval. 
+ + Args: + from_account: Source account ID + to_account: Destination account ID + amount: Amount to transfer + + Returns: + Confirmation message + """ + return f"Transferred ${amount:.2f} from {from_account} to {to_account}" + + +@requires_approval +def issue_refund(order_id: str, amount: float, reason: str) -> str: + """Issue a refund for an order. Requires human approval. + + Args: + order_id: The order ID to refund + amount: Refund amount + reason: Reason for the refund + + Returns: + Confirmation message + """ + return f"Refund of ${amount:.2f} issued for order {order_id}: {reason}" + + +client = UiPathOpenAIChatClient(model="gpt-5-mini-2025-08-07") + +triage = client.as_agent( + name="triage", + description="Routes customer requests to the right specialist.", + instructions=( + "You are a customer support triage agent. Determine what the " + "customer needs help with and hand off to the right agent:\n" + "- Billing issues (payments, transfers) -> billing_agent\n" + "- Returns and refunds -> returns_agent\n" + ), +) + +billing = client.as_agent( + name="billing_agent", + description="Handles billing, payments, and fund transfers.", + instructions=( + "You are a billing specialist. Help customers with payments " + "and transfers. Use the transfer_funds tool when needed — " + "it will require human approval before executing." + ), + tools=[transfer_funds], +) + +returns = client.as_agent( + name="returns_agent", + description="Handles product returns and refund requests.", + instructions=( + "You are a returns specialist. Help customers process returns " + "and issue refunds. Use the issue_refund tool — it will " + "require human approval before executing." 
+ ), + tools=[issue_refund], +) + +workflow = ( + HandoffBuilder( + name="customer_support", + participants=[triage, billing, returns], + ) + .with_start_agent(triage) + .add_handoff(triage, [billing, returns]) + .add_handoff(billing, [returns, triage]) + .add_handoff(returns, [billing, triage]) + .build() +) + +agent = workflow.as_agent(name="customer_support") diff --git a/packages/uipath-agent-framework/samples/quickstart-agent/pyproject.toml b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml similarity index 74% rename from packages/uipath-agent-framework/samples/quickstart-agent/pyproject.toml rename to packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml index 52dbf51d..667ad564 100644 --- a/packages/uipath-agent-framework/samples/quickstart-agent/pyproject.toml +++ b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "quickstart-agent" +name = "hitl-workflow" version = "0.0.1" -description = "Quickstart Agent Framework agent example" +description = "Agent Framework workflow with human-in-the-loop tool approval" authors = [{ name = "John Doe" }] readme = "README.md" requires-python = ">=3.11" @@ -18,5 +18,3 @@ dev = [ [tool.uv] prerelease = "allow" - - diff --git a/packages/uipath-agent-framework/samples/multi-agent/README.md b/packages/uipath-agent-framework/samples/multi-agent/README.md deleted file mode 100644 index cb67e3a5..00000000 --- a/packages/uipath-agent-framework/samples/multi-agent/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# Multi-Agent - -A coordinator agent that delegates to specialist agents (research + code execution) using the Agent Framework's `as_tool()` pattern. 
- -## Agent Graph - -```mermaid -flowchart TB - __start__(__start__) - coordinator(coordinator) - research_agent(research_agent) - research_agent_tools(tools) - code_agent(code_agent) - code_agent_tools(tools) - __end__(__end__) - research_agent --> research_agent_tools - research_agent_tools --> research_agent - coordinator --> research_agent - research_agent --> coordinator - code_agent --> code_agent_tools - code_agent_tools --> code_agent - coordinator --> code_agent - code_agent --> coordinator - __start__ --> |input|coordinator - coordinator --> |output|__end__ -``` - -## Prerequisites - -Authenticate with UiPath to configure your `.env` file: - -```bash -uipath auth -``` - -## Run - -``` -uipath run agent '{"messages": [{"contentParts": [{"data": {"inline": "What is the population of France and calculate its square root?"}}], "role": "user"}]}' -``` - -## Debug - -``` -uipath dev web -``` diff --git a/packages/uipath-agent-framework/samples/multi-agent/agent.mermaid b/packages/uipath-agent-framework/samples/multi-agent/agent.mermaid deleted file mode 100644 index ea16e8f5..00000000 --- a/packages/uipath-agent-framework/samples/multi-agent/agent.mermaid +++ /dev/null @@ -1,18 +0,0 @@ -flowchart TB - __start__(__start__) - coordinator(coordinator) - research_agent(research_agent) - research_agent_tools(tools) - code_agent(code_agent) - code_agent_tools(tools) - __end__(__end__) - research_agent --> research_agent_tools - research_agent_tools --> research_agent - coordinator --> |research_agent|research_agent - research_agent --> coordinator - code_agent --> code_agent_tools - code_agent_tools --> code_agent - coordinator --> |code_agent|code_agent - code_agent --> coordinator - __start__ --> |input|coordinator - coordinator --> |output|__end__ diff --git a/packages/uipath-agent-framework/samples/multi-agent/main.py b/packages/uipath-agent-framework/samples/multi-agent/main.py deleted file mode 100644 index 94587a0c..00000000 --- 
a/packages/uipath-agent-framework/samples/multi-agent/main.py +++ /dev/null @@ -1,79 +0,0 @@ -import io -import sys - -import httpx - -from uipath_agent_framework.chat import UiPathOpenAIChatClient - - -def search_wikipedia(query: str) -> str: - """Search Wikipedia for a topic and return a summary. - - Args: - query: The search query, e.g. "Python programming language" - - Returns: - A summary of the Wikipedia article, or an error message. - """ - try: - resp = httpx.get( - "https://en.wikipedia.org/api/rest_v1/page/summary/" - + query.replace(" ", "_"), - headers={"User-Agent": "UiPathMultiAgent/1.0"}, - timeout=10, - follow_redirects=True, - ) - resp.raise_for_status() - data = resp.json() - return data.get("extract", "No summary available.") - except Exception as e: - return f"Wikipedia search failed for '{query}': {e}" - - -def run_python(code: str) -> str: - """Execute a Python code snippet and return its output. - - Args: - code: The Python code to execute. - - Returns: - The captured stdout output, or an error message. - """ - old_stdout = sys.stdout - sys.stdout = captured = io.StringIO() - try: - exec(code, {"__builtins__": __builtins__}) # noqa: B023 - return captured.getvalue() or "(no output)" - except Exception as e: - return f"Execution error: {e}" - finally: - sys.stdout = old_stdout - - -client = UiPathOpenAIChatClient(model="gpt-5-mini-2025-08-07") - -research_agent = client.as_agent( - name="research_agent", - description="Searches Wikipedia for information on any topic.", - instructions="You are a research assistant. Use the search_wikipedia tool to find information. Provide concise, factual summaries.", - tools=[search_wikipedia], -) - -code_agent = client.as_agent( - name="code_agent", - description="Executes Python code snippets and returns the output.", - instructions="You are a coding assistant. Use the run_python tool to execute Python code. 
Always validate code before running.", - tools=[run_python], -) - -# Coordinator delegates to specialists via as_tool() -agent = client.as_agent( - name="coordinator", - instructions=( - "You are a coordinator that delegates tasks to specialist agents.\n" - "- Use 'research_agent' for any research or factual questions.\n" - "- Use 'code_agent' for any coding, calculation, or data processing tasks.\n" - "Combine their results to give comprehensive answers." - ), - tools=[research_agent.as_tool(), code_agent.as_tool()], -) diff --git a/packages/uipath-agent-framework/samples/quickstart-agent/uipath.json b/packages/uipath-agent-framework/samples/quickstart-agent/uipath.json deleted file mode 100644 index 7969b8f0..00000000 --- a/packages/uipath-agent-framework/samples/quickstart-agent/uipath.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "https://cloud.uipath.com/draft/2024-12/uipath", - "runtimeOptions": { - "isConversational": true - }, - "packOptions": { - "fileExtensionsIncluded": [], - "filesIncluded": [], - "filesExcluded": [], - "directoriesExcluded": [], - "includeUvLock": true - }, - "functions": {} -} diff --git a/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json b/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json new file mode 100644 index 00000000..af690fcd --- /dev/null +++ b/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python-envs.pythonProjects": [] +} diff --git a/packages/uipath-agent-framework/samples/quickstart-agent/README.md b/packages/uipath-agent-framework/samples/quickstart-workflow/README.md similarity index 100% rename from packages/uipath-agent-framework/samples/quickstart-agent/README.md rename to packages/uipath-agent-framework/samples/quickstart-workflow/README.md diff --git a/packages/uipath-agent-framework/samples/quickstart-agent/agent_framework.json 
b/packages/uipath-agent-framework/samples/quickstart-workflow/agent_framework.json similarity index 100% rename from packages/uipath-agent-framework/samples/quickstart-agent/agent_framework.json rename to packages/uipath-agent-framework/samples/quickstart-workflow/agent_framework.json diff --git a/packages/uipath-agent-framework/samples/quickstart-agent/main.py b/packages/uipath-agent-framework/samples/quickstart-workflow/main.py similarity index 91% rename from packages/uipath-agent-framework/samples/quickstart-agent/main.py rename to packages/uipath-agent-framework/samples/quickstart-workflow/main.py index b9df826f..f857088c 100644 --- a/packages/uipath-agent-framework/samples/quickstart-agent/main.py +++ b/packages/uipath-agent-framework/samples/quickstart-workflow/main.py @@ -1,4 +1,5 @@ import httpx +from agent_framework import WorkflowBuilder from uipath_agent_framework.chat import UiPathOpenAIChatClient @@ -55,8 +56,11 @@ def get_weather(location: str) -> str: client = UiPathOpenAIChatClient(model="gpt-5-mini-2025-08-07") -agent = client.as_agent( +weather_agent = client.as_agent( name="weather_agent", instructions="You are a helpful weather assistant. 
 Use the get_weather tool to provide weather information.", tools=[get_weather], ) + +workflow = WorkflowBuilder(start_executor=weather_agent).build() +agent = workflow.as_agent(name="weather_workflow") diff --git a/packages/uipath-agent-framework/samples/multi-agent/pyproject.toml b/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml similarity index 82% rename from packages/uipath-agent-framework/samples/multi-agent/pyproject.toml rename to packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml index d7e3ff9f..a7559704 100644 --- a/packages/uipath-agent-framework/samples/multi-agent/pyproject.toml +++ b/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "multi-agent" +name = "quickstart-workflow" version = "0.0.1" -description = "Multi-agent Agent Framework example with coordinator pattern" +description = "Quickstart Agent Framework workflow example" authors = [{ name = "John Doe" }] readme = "README.md" requires-python = ">=3.11" @@ -20,5 +20,5 @@ dev = [ prerelease = "allow" [tool.uv.sources] -uipath-agent-framework = { path = "../../", editable = true } uipath-dev = { path = "../../../../../uipath-dev-python", editable = true } +uipath-agent-framework = { path = "../../", editable = true } diff --git a/packages/uipath-agent-framework/samples/multi-agent/uipath.json b/packages/uipath-agent-framework/samples/quickstart-workflow/uipath.json similarity index 100% rename from packages/uipath-agent-framework/samples/multi-agent/uipath.json rename to packages/uipath-agent-framework/samples/quickstart-workflow/uipath.json diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/__init__.py b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/__init__.py index ae2cdbd6..d088450c 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/__init__.py +++ 
b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/__init__.py @@ -18,10 +18,15 @@ def __getattr__(name): from .anthropic import UiPathAnthropicClient return UiPathAnthropicClient + if name == "requires_approval": + from .hitl import requires_approval + + return requires_approval raise AttributeError(f"module {__name__!r} has no attribute {name!r}") __all__ = [ "UiPathOpenAIChatClient", "UiPathAnthropicClient", + "requires_approval", ] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/hitl.py b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/hitl.py new file mode 100644 index 00000000..944907f8 --- /dev/null +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/hitl.py @@ -0,0 +1,70 @@ +"""Human-in-the-loop (HITL) support for Agent Framework. + +Provides the ``requires_approval`` decorator for marking tool functions +that need human approval before execution. + +Example:: + + from uipath_agent_framework.chat import requires_approval + + @requires_approval + def transfer_funds(from_account: str, to_account: str, amount: float) -> str: + \"\"\"Transfer funds between accounts.\"\"\" + return f"Transferred ${amount:.2f} from {from_account} to {to_account}" +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, overload + +from agent_framework import FunctionTool, tool + + +@overload +def requires_approval(func: Callable[..., Any]) -> FunctionTool[Any]: ... + + +@overload +def requires_approval( + func: None = None, +) -> Callable[[Callable[..., Any]], FunctionTool[Any]]: ... + + +def requires_approval( + func: Callable[..., Any] | None = None, +) -> FunctionTool[Any] | Callable[[Callable[..., Any]], FunctionTool[Any]]: + """Decorator that marks a tool function as requiring human approval. 
+ + When the agent calls a tool decorated with ``@requires_approval``, + execution suspends and waits for a human to approve or reject + the tool call before proceeding. + + Can be used with or without parentheses:: + + @requires_approval + def my_tool(arg: str) -> str: ... + + @requires_approval() + def my_tool(arg: str) -> str: ... + + Under the hood, this sets ``approval_mode="always_require"`` on the + resulting ``FunctionTool``. The workflow runtime then intercepts the + call via ``request_info`` and suspends execution for human approval. + + Args: + func: The tool function to wrap. If None, returns a decorator. + + Returns: + A FunctionTool with approval_mode set to "always_require". + """ + if func is not None: + return tool(func, approval_mode="always_require") + + def decorator(fn: Callable[..., Any]) -> FunctionTool[Any]: + return tool(fn, approval_mode="always_require") + + return decorator + + +__all__ = ["requires_approval"] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py index 98c40fb1..667479ee 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py @@ -7,6 +7,12 @@ ) from .factory import UiPathAgentFrameworkRuntimeFactory +from .interrupt import AgentInterruptException, BreakpointMiddleware +from .resumable_storage import ( + ScopedCheckpointStorage, + SqliteCheckpointStorage, + SqliteResumableStorage, +) from .runtime import UiPathAgentFrameworkRuntime from .schema import get_agent_graph, get_entrypoints_schema @@ -32,4 +38,9 @@ def create_factory( "get_agent_graph", "UiPathAgentFrameworkRuntimeFactory", "UiPathAgentFrameworkRuntime", + "AgentInterruptException", + "BreakpointMiddleware", + "SqliteResumableStorage", + "SqliteCheckpointStorage", + "ScopedCheckpointStorage", ] diff --git 
a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py new file mode 100644 index 00000000..8fc85ebd --- /dev/null +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py @@ -0,0 +1,168 @@ +"""Breakpoint management for the Agent Framework runtime. + +Implements breakpoints by wrapping executor.execute() methods so that +execution pauses BEFORE the executor runs. This works regardless of +the inner agent type (RawAgent, Agent, etc.) because interception +happens at the executor level, not via agent middleware. + +The debug UI sends graph node IDs which are resolved to executor IDs: +- ``"*"`` → all executors +- Executor IDs (e.g. ``"triage"``) → that executor +- Tools container IDs (e.g. ``"triage_tools"``) → the parent executor +- Tool names (e.g. ``"calculator"``) → the executor that owns that tool +""" + +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from agent_framework import AgentExecutor, WorkflowAgent +from uipath.runtime.debug import UiPathBreakpointResult + +from .interrupt import AgentInterruptException +from .schema import get_agent_tools, get_tool_name + +_ORIGINAL_EXECUTE_ATTR = "_bp_original_execute" + + +def _build_executor_tool_map(agent: WorkflowAgent) -> dict[str, set[str]]: + """Build a mapping of executor_id -> set of tool names.""" + tool_map: dict[str, set[str]] = {} + for exec_id, executor in agent.workflow.executors.items(): + if isinstance(executor, AgentExecutor): + inner = getattr(executor, "_agent", None) + if inner is not None: + tools = get_agent_tools(inner) + names = {get_tool_name(t) for t in tools if get_tool_name(t)} + tool_map[exec_id] = names + return tool_map + + +def _resolve_to_executor_ids( + agent: WorkflowAgent, + breakpoints: list[str] | str, +) -> set[str]: + """Resolve graph node IDs to executor IDs. 
+ + Maps breakpoint node IDs from the debug UI to the actual executor IDs + in the workflow so we know which executors to wrap. + """ + if breakpoints == "*" or (isinstance(breakpoints, list) and "*" in breakpoints): + return set(agent.workflow.executors.keys()) + + all_executors = set(agent.workflow.executors.keys()) + tool_map = _build_executor_tool_map(agent) + + # Reverse map: tool_name -> executor_id + tool_to_executor: dict[str, str] = {} + for exec_id, tool_names in tool_map.items(): + for name in tool_names: + tool_to_executor[name] = exec_id + + resolved: set[str] = set() + + for bp in breakpoints: + if bp in all_executors: + # Direct executor ID + resolved.add(bp) + elif bp.endswith("_tools"): + # Tools container node → parent executor + exec_id = bp[: -len("_tools")] + if exec_id in all_executors: + resolved.add(exec_id) + elif bp in tool_to_executor: + # Tool name → owning executor + resolved.add(tool_to_executor[bp]) + + return resolved + + +def inject_breakpoint_middleware( + agent: WorkflowAgent, + breakpoints: list[str] | str, + skip_nodes: set[str] | None = None, +) -> None: + """Wrap executor.execute() to pause before breakpointed executors run. + + Replaces each matching executor's execute() with a wrapper that raises + AgentInterruptException(is_breakpoint=True) before the executor runs. + + Args: + agent: The workflow agent whose executors to wrap. + breakpoints: ``"*"`` or a list of node IDs from the debug UI. + skip_nodes: Executor IDs to skip (for resume after breakpoint). + In concurrent workflows multiple executors may have been + breakpointed across sequential resumes within the same + superstep, so all of them must be skipped. 
+ """ + executor_ids = _resolve_to_executor_ids(agent, breakpoints) + + for exec_id in executor_ids: + executor = agent.workflow.executors.get(exec_id) + if executor is None: + continue + + # Don't double-wrap + if hasattr(executor, _ORIGINAL_EXECUTE_ATTR): + continue + + # Skip executors already resumed past + if skip_nodes and exec_id in skip_nodes: + continue + + original = executor.execute + + async def wrapped_execute( + message: Any, + source_executor_ids: list[str], + state: Any, + runner_context: Any, + trace_contexts: list[dict[str, str]] | None = None, + source_span_ids: list[str] | None = None, + *, + _exec_id: str = exec_id, + ) -> None: + raise AgentInterruptException( + interrupt_id=str(uuid4()), + suspend_value={ + "type": "breakpoint", + "node_id": _exec_id, + }, + is_breakpoint=True, + ) + + setattr(executor, _ORIGINAL_EXECUTE_ATTR, original) + executor.execute = wrapped_execute # type: ignore[assignment] + + +def remove_breakpoint_middleware(agent: WorkflowAgent) -> None: + """Restore original execute methods on all wrapped executors.""" + for executor in agent.workflow.executors.values(): + original = getattr(executor, _ORIGINAL_EXECUTE_ATTR, None) + if original is not None: + executor.execute = original # type: ignore[assignment] + delattr(executor, _ORIGINAL_EXECUTE_ATTR) + + +def create_breakpoint_result( + exc: AgentInterruptException, +) -> UiPathBreakpointResult: + """Create a UiPathBreakpointResult from a breakpoint interrupt.""" + node_id = "" + if isinstance(exc.suspend_value, dict): + node_id = exc.suspend_value.get("node_id", "") + + return UiPathBreakpointResult( + breakpoint_node=node_id, + breakpoint_type="before", + current_state=exc.suspend_value, + next_nodes=[node_id] if node_id else [], + ) + + +__all__ = [ + "create_breakpoint_result", + "inject_breakpoint_middleware", + "remove_breakpoint_middleware", +] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py 
b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py index ef4b72c3..611f8bae 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py @@ -4,14 +4,16 @@ import os from typing import Any -from agent_framework import BaseAgent +from agent_framework import WorkflowAgent from agent_framework.observability import enable_instrumentation from openinference.instrumentation.agent_framework import ( AgentFrameworkToOpenInferenceProcessor, ) from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider +from uipath.platform.resume_triggers import UiPathResumeTriggerHandler from uipath.runtime import ( + UiPathResumableRuntime, UiPathRuntimeContext, UiPathRuntimeFactorySettings, UiPathRuntimeProtocol, @@ -25,8 +27,11 @@ UiPathAgentFrameworkRuntimeError, ) from uipath_agent_framework.runtime.loader import AgentFrameworkAgentLoader +from uipath_agent_framework.runtime.resumable_storage import ( + ScopedCheckpointStorage, + SqliteResumableStorage, +) from uipath_agent_framework.runtime.runtime import UiPathAgentFrameworkRuntime -from uipath_agent_framework.runtime.storage import SqliteSessionStore class UiPathAgentFrameworkRuntimeFactory: @@ -47,8 +52,8 @@ def __init__( self._agent_loaders: dict[str, AgentFrameworkAgentLoader] = {} - self._session_store: SqliteSessionStore | None = None - self._session_store_lock = asyncio.Lock() + self._storage: SqliteResumableStorage | None = None + self._storage_lock = asyncio.Lock() self._setup_instrumentation() @@ -56,7 +61,6 @@ def _setup_instrumentation(self) -> None: """Setup tracing and instrumentation.""" enable_instrumentation() - # Add OpenInference span processor for Arize Phoenix compatibility tracer_provider = trace.get_tracer_provider() if isinstance(tracer_provider, TracerProvider): tracer_provider.add_span_processor(AgentFrameworkToOpenInferenceProcessor()) @@ 
-84,16 +88,16 @@ def _get_db_path(self) -> str: os.remove(path) return path - async def _get_session_store(self) -> SqliteSessionStore: - """Get or create the shared session store instance.""" - async with self._session_store_lock: - if self._session_store is None: + async def _get_storage(self) -> SqliteResumableStorage: + """Get or create the shared resumable storage instance.""" + async with self._storage_lock: + if self._storage is None: db_path = self._get_db_path() - self._session_store = SqliteSessionStore(db_path) - await self._session_store.setup() - return self._session_store + self._storage = SqliteResumableStorage(db_path) + await self._storage.setup() + return self._storage - async def _load_agent(self, entrypoint: str) -> BaseAgent: + async def _load_agent(self, entrypoint: str) -> WorkflowAgent: """ Load an agent for the given entrypoint. @@ -101,7 +105,7 @@ async def _load_agent(self, entrypoint: str) -> BaseAgent: entrypoint: Name of the agent to load Returns: - The loaded BaseAgent + The loaded WorkflowAgent Raises: UiPathAgentFrameworkRuntimeError: If agent cannot be loaded @@ -162,15 +166,14 @@ async def _load_agent(self, entrypoint: str) -> BaseAgent: UiPathErrorCategory.USER, ) from e - async def _resolve_agent(self, entrypoint: str) -> BaseAgent: + async def _resolve_agent(self, entrypoint: str) -> WorkflowAgent: """Load a fresh agent instance for the given entrypoint. Agents are NOT cached — each runtime gets its own instance. - Agent Framework agents (especially WorkflowAgents) hold internal - mutable state (e.g. Workflow._is_running) that prevents concurrent - executions on the same instance. Since the factory creates multiple - runtimes in parallel (one per request), sharing an agent instance - would cause "Workflow is already running" errors. + WorkflowAgents hold internal mutable state (e.g. Workflow._is_running) + that prevents concurrent executions on the same instance. 
Since the + factory creates multiple runtimes in parallel (one per request), + sharing an agent instance would cause "Workflow is already running" errors. """ return await self._load_agent(entrypoint) @@ -188,7 +191,7 @@ def discover_entrypoints(self) -> list[str]: async def get_storage(self) -> UiPathRuntimeStorageProtocol | None: """Get the shared storage instance.""" - return None + return await self._get_storage() async def get_settings(self) -> UiPathRuntimeFactorySettings | None: """Get the factory settings.""" @@ -196,23 +199,34 @@ async def get_settings(self) -> UiPathRuntimeFactorySettings | None: async def _create_runtime_instance( self, - agent: BaseAgent, + agent: WorkflowAgent, runtime_id: str, entrypoint: str, ) -> UiPathRuntimeProtocol: """Create a runtime instance from an agent. - Creates the runtime with a shared SqliteSessionStore for persistent - conversation history. Sessions are isolated by runtime_id — each - runtime instance gets its own conversation state. + Creates the runtime with a shared SqliteResumableStorage for persistent + conversation history and HITL trigger management. Wraps with + UiPathResumableRuntime for resume trigger lifecycle handling. 
""" - session_store = await self._get_session_store() + storage = await self._get_storage() + checkpoint_storage = ScopedCheckpointStorage( + storage.checkpoint_storage, runtime_id + ) - return UiPathAgentFrameworkRuntime( + base_runtime = UiPathAgentFrameworkRuntime( agent=agent, runtime_id=runtime_id, entrypoint=entrypoint, - session_store=session_store, + checkpoint_storage=checkpoint_storage, + resumable_storage=storage, + ) + + return UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, ) async def new_runtime( @@ -244,6 +258,6 @@ async def dispose(self) -> None: self._agent_loaders.clear() - if self._session_store: - await self._session_store.dispose() - self._session_store = None + if self._storage: + await self._storage.dispose() + self._storage = None diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py new file mode 100644 index 00000000..c004eba0 --- /dev/null +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py @@ -0,0 +1,131 @@ +"""Interrupt infrastructure for human-in-the-loop (HITL) support. + +Provides: +- AgentInterruptException: raised by middleware to suspend agent execution +- BreakpointMiddleware: intercepts tools matching breakpoint configuration +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any +from uuid import uuid4 + +from agent_framework._middleware import ( + FunctionInvocationContext, + FunctionMiddleware, +) + + +class AgentInterruptException(Exception): + """Raised by middleware to suspend agent execution for HITL. + + Carries an interrupt_id and suspend_value that the runtime uses + to create a UiPathRuntimeResult with SUSPENDED status. 
+ When is_breakpoint is True, the runtime returns UiPathBreakpointResult + instead, which bypasses trigger management and is handled by the + debug runtime layer. + """ + + def __init__( + self, + interrupt_id: str, + suspend_value: Any, + *, + is_breakpoint: bool = False, + ) -> None: + self.interrupt_id = interrupt_id + self.suspend_value = suspend_value + self.is_breakpoint = is_breakpoint + super().__init__(f"Agent interrupted: {interrupt_id}") + + +class BreakpointMiddleware(FunctionMiddleware): + """Intercepts tools matching breakpoint configuration. + + Breakpoint flow (orchestrated by UiPathDebugRuntime): + + 1. UiPathDebugRuntime gets breakpoints from debug bridge and passes + them via ``options.breakpoints`` to the integration runtime. + 2. The integration runtime injects this middleware into the agent's + middleware chain with the breakpoint list. + 3. When the agent calls a matching tool, this middleware raises + ``AgentInterruptException(is_breakpoint=True)`` BEFORE the tool runs. + 4. The runtime catches the exception and returns + ``UiPathBreakpointResult`` (a SUSPENDED result subclass). + 5. ``UiPathResumableRuntime`` passes the breakpoint result through + (no trigger management — breakpoints bypass the trigger system). + 6. ``UiPathDebugRuntime`` sees ``UiPathBreakpointResult``, notifies + the debug bridge, and waits for a resume command. + 7. On resume, ``UiPathDebugRuntime`` re-invokes the runtime with + ``options.resume=True, input=None``. The runtime re-injects this + middleware with ``skip_tool`` set to the previously-interrupted + tool name so the first matching call is let through (one-shot). + 8. After the skipped call completes, subsequent breakpoint-matching + tool calls will pause again. 
+ """ + + def __init__( + self, + breakpoints: list[str] | str, + skip_tool: str | None = None, + ) -> None: + self.breakpoints = breakpoints + self._skip_tool = skip_tool + + def _matches(self, tool_name: str) -> bool: + if self.breakpoints == "*": + return True + if isinstance(self.breakpoints, list): + return tool_name in self.breakpoints + return False + + async def process( + self, + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + tool = context.function + tool_name = getattr(tool, "name", "") + + if not self._matches(tool_name): + await call_next() + return + + # One-shot skip for the tool we just resumed from + if self._skip_tool and tool_name == self._skip_tool: + self._skip_tool = None + await call_next() + return + + # Legacy metadata-based resume (kept for backward compatibility) + if context.metadata.get("_breakpoint_continue"): + await call_next() + return + + interrupt_id = str(uuid4()) + + input_value = None + if context.arguments is not None: + try: + input_value = context.arguments.model_dump() + except Exception: + input_value = str(context.arguments) + + suspend_value = { + "type": "breakpoint", + "tool_name": tool_name, + "input_value": input_value, + } + + raise AgentInterruptException( + interrupt_id=interrupt_id, + suspend_value=suspend_value, + is_breakpoint=True, + ) + + +__all__ = [ + "AgentInterruptException", + "BreakpointMiddleware", +] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/loader.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/loader.py index e9908629..c52125df 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/loader.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/loader.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Self -from agent_framework import BaseAgent +from agent_framework import BaseAgent, WorkflowAgent from uipath.runtime.errors import 
UiPathErrorCategory from .errors import ( @@ -51,12 +51,12 @@ def from_path_string(cls, name: str, file_path: str) -> Self: file, variable = file_path.split(":", 1) return cls(name=name, file_path=file, variable_name=variable) - async def load(self) -> BaseAgent: + async def load(self) -> WorkflowAgent: """ Load and return the agent. Returns: - An instance of the loaded BaseAgent. + An instance of the loaded WorkflowAgent. Raises: UiPathAgentFrameworkRuntimeError: If loading fails @@ -97,11 +97,12 @@ async def load(self) -> BaseAgent: ) agent = await self._resolve_agent(agent_object) - if not isinstance(agent, BaseAgent): + if not isinstance(agent, WorkflowAgent): raise UiPathAgentFrameworkRuntimeError( code=UiPathAgentFrameworkErrorCode.AGENT_TYPE_ERROR, title="Invalid agent type", - detail=f"Expected BaseAgent, got '{type(agent).__name__}'.", + detail=f"Expected WorkflowAgent, got '{type(agent).__name__}'. " + "Use workflow.as_agent() to create a WorkflowAgent.", category=UiPathErrorCategory.USER, ) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py new file mode 100644 index 00000000..ddcb2dbd --- /dev/null +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py @@ -0,0 +1,506 @@ +"""SQLite resumable storage for Agent Framework. + +Provides SqliteResumableStorage with resume trigger, key-value, and checkpoint +tables, plus SqliteCheckpointStorage and ScopedCheckpointStorage for the +Agent Framework checkpoint protocol. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +from typing import Any + +import aiosqlite +from agent_framework import WorkflowCheckpoint +from agent_framework._workflows._checkpoint_encoding import ( + decode_checkpoint_value, + encode_checkpoint_value, +) +from pydantic import BaseModel +from uipath.runtime import ( + UiPathApiTrigger, + UiPathResumeTrigger, + UiPathResumeTriggerName, + UiPathResumeTriggerType, +) + +logger = logging.getLogger(__name__) + + +class SqliteResumableStorage: + """SQLite storage with resume triggers, KV, and checkpoint tables. + + Tables: + - resume_triggers: interrupt trigger persistence for HITL + - runtime_kv: arbitrary key-value storage scoped by runtime_id + namespace + - checkpoints: workflow checkpoint persistence + """ + + def __init__(self, db_path: str) -> None: + self.db_path = db_path + self._conn: aiosqlite.Connection | None = None + self._lock = asyncio.Lock() + self.checkpoint_storage: SqliteCheckpointStorage | None = None + + async def setup(self) -> None: + """Create all tables and initialize checkpoint storage.""" + dir_name = os.path.dirname(self.db_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + + conn = await self._get_conn() + async with self._lock: + await conn.execute( + """ + CREATE TABLE IF NOT EXISTS resume_triggers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + runtime_id TEXT NOT NULL, + interrupt_id TEXT NOT NULL, + trigger_data TEXT NOT NULL, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + await conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_resume_triggers_runtime_id + ON resume_triggers(runtime_id) + """ + ) + await conn.execute( + """ + CREATE TABLE IF NOT EXISTS runtime_kv ( + runtime_id TEXT NOT NULL, + namespace TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')), + PRIMARY KEY (runtime_id, namespace, key) + ) + """ + ) + await 
conn.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + checkpoint_id TEXT PRIMARY KEY, + workflow_name TEXT NOT NULL, + checkpoint_data TEXT NOT NULL, + timestamp TEXT NOT NULL + ) + """ + ) + await conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_checkpoints_workflow + ON checkpoints(workflow_name) + """ + ) + await conn.commit() + + self.checkpoint_storage = SqliteCheckpointStorage(self) + logger.debug("Resumable storage tables initialized at %s", self.db_path) + + async def _get_conn(self) -> aiosqlite.Connection: + """Get or create the database connection.""" + if self._conn is None: + self._conn = await aiosqlite.connect(self.db_path, timeout=30.0) + await self._conn.execute("PRAGMA journal_mode=WAL") + await self._conn.execute("PRAGMA busy_timeout=30000") + await self._conn.execute("PRAGMA synchronous=NORMAL") + await self._conn.commit() + return self._conn + + async def dispose(self) -> None: + """Close the database connection.""" + if self._conn: + await self._conn.close() + self._conn = None + + # ------------------------------------------------------------------ + # Resume trigger persistence + # ------------------------------------------------------------------ + + async def save_triggers( + self, runtime_id: str, triggers: list[UiPathResumeTrigger] + ) -> None: + """Save resume triggers, replacing any existing ones for this runtime_id.""" + conn = await self._get_conn() + async with self._lock: + await conn.execute( + "DELETE FROM resume_triggers WHERE runtime_id = ?", + (runtime_id,), + ) + for trigger in triggers: + trigger_dict = self._serialize_trigger(trigger) + trigger_json = json.dumps(trigger_dict) + await conn.execute( + "INSERT INTO resume_triggers (runtime_id, interrupt_id, trigger_data) VALUES (?, ?, ?)", + (runtime_id, trigger.interrupt_id, trigger_json), + ) + await conn.commit() + + logger.debug( + "Saved %d triggers for runtime_id=%s", len(triggers), runtime_id + ) + + async def get_triggers( + self, runtime_id: str + ) -> 
list[UiPathResumeTrigger] | None: + """Retrieve all resume triggers for this runtime_id.""" + conn = await self._get_conn() + async with self._lock: + cursor = await conn.execute( + "SELECT trigger_data FROM resume_triggers WHERE runtime_id = ? ORDER BY id ASC", + (runtime_id,), + ) + rows = await cursor.fetchall() + + if not rows: + return None + + triggers = [] + for row in rows: + trigger_dict = json.loads(row[0]) + triggers.append(self._deserialize_trigger(trigger_dict)) + return triggers + + async def delete_trigger( + self, runtime_id: str, trigger: UiPathResumeTrigger + ) -> None: + """Delete a specific resume trigger by runtime_id and interrupt_id.""" + conn = await self._get_conn() + async with self._lock: + await conn.execute( + "DELETE FROM resume_triggers WHERE runtime_id = ? AND interrupt_id = ?", + (runtime_id, trigger.interrupt_id), + ) + await conn.commit() + + logger.debug( + "Deleted trigger %s for runtime_id=%s", + trigger.interrupt_id, + runtime_id, + ) + + # ------------------------------------------------------------------ + # Key-value storage + # ------------------------------------------------------------------ + + async def set_value( + self, + runtime_id: str, + namespace: str, + key: str, + value: Any, + ) -> None: + """Save arbitrary key-value pair scoped by runtime_id + namespace.""" + value_text = self._dump_value(value) + + conn = await self._get_conn() + async with self._lock: + await conn.execute( + """ + INSERT INTO runtime_kv (runtime_id, namespace, key, value) + VALUES (?, ?, ?, ?) 
+ ON CONFLICT(runtime_id, namespace, key) + DO UPDATE SET + value = excluded.value, + timestamp = (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')) + """, + (runtime_id, namespace, key, value_text), + ) + await conn.commit() + + async def get_value( + self, runtime_id: str, namespace: str, key: str + ) -> Any: + """Get arbitrary key-value pair scoped by runtime_id + namespace.""" + conn = await self._get_conn() + async with self._lock: + cursor = await conn.execute( + """ + SELECT value FROM runtime_kv + WHERE runtime_id = ? AND namespace = ? AND key = ? + LIMIT 1 + """, + (runtime_id, namespace, key), + ) + row = await cursor.fetchone() + + if not row: + return None + + return self._load_value(row[0]) + + # ------------------------------------------------------------------ + # Serialization helpers + # ------------------------------------------------------------------ + + @staticmethod + def _serialize_trigger(trigger: UiPathResumeTrigger) -> dict[str, Any]: + """Serialize a resume trigger to a dictionary.""" + trigger_key = ( + trigger.api_resume.inbox_id if trigger.api_resume else trigger.item_key + ) + payload = ( + json.dumps(trigger.payload) + if isinstance(trigger.payload, dict) + else str(trigger.payload) + if trigger.payload + else None + ) + + return { + "type": trigger.trigger_type.value, + "key": trigger_key, + "name": trigger.trigger_name.value, + "payload": payload, + "interrupt_id": trigger.interrupt_id, + "folder_path": trigger.folder_path, + "folder_key": trigger.folder_key, + } + + @staticmethod + def _deserialize_trigger( + trigger_data: dict[str, Any], + ) -> UiPathResumeTrigger: + """Deserialize a resume trigger from a dictionary.""" + trigger_type = trigger_data["type"] + key = trigger_data["key"] + name = trigger_data["name"] + folder_path = trigger_data.get("folder_path") + folder_key = trigger_data.get("folder_key") + payload = trigger_data.get("payload") + interrupt_id = trigger_data.get("interrupt_id") + + resume_trigger = 
UiPathResumeTrigger( + trigger_type=UiPathResumeTriggerType(trigger_type), + trigger_name=UiPathResumeTriggerName(name), + item_key=key, + folder_path=folder_path, + folder_key=folder_key, + payload=payload, + interrupt_id=interrupt_id, + ) + + if resume_trigger.trigger_type == UiPathResumeTriggerType.API: + resume_trigger.api_resume = UiPathApiTrigger( + inbox_id=resume_trigger.item_key, + request=resume_trigger.payload, + ) + + return resume_trigger + + @staticmethod + def _dump_value( + value: str | dict[str, Any] | BaseModel | None, + ) -> str | None: + """Serialize a value for KV storage.""" + if value is None: + return None + if isinstance(value, BaseModel): + return "j:" + json.dumps(value.model_dump()) + if isinstance(value, dict): + return "j:" + json.dumps(value) + if isinstance(value, str): + return "s:" + value + raise TypeError("Value must be str, dict, BaseModel or None.") + + @staticmethod + def _load_value(raw: str | None) -> Any: + """Deserialize a value from KV storage.""" + if raw is None: + return None + if raw.startswith("s:"): + return raw[2:] + if raw.startswith("j:"): + return json.loads(raw[2:]) + return raw + + +class SqliteCheckpointStorage: + """SQLite-backed CheckpointStorage implementation. + + Implements the agent_framework CheckpointStorage protocol using the + checkpoints table managed by SqliteResumableStorage. + """ + + def __init__(self, storage: SqliteResumableStorage) -> None: + self._storage = storage + + async def save(self, checkpoint: WorkflowCheckpoint) -> str: + """Save a checkpoint and return its ID.""" + checkpoint_dict = checkpoint.to_dict() + encoded = encode_checkpoint_value(checkpoint_dict) + checkpoint_data = json.dumps(encoded, ensure_ascii=False) + + conn = await self._storage._get_conn() + async with self._storage._lock: + await conn.execute( + """ + INSERT OR REPLACE INTO checkpoints + (checkpoint_id, workflow_name, checkpoint_data, timestamp) + VALUES (?, ?, ?, ?) 
+ """, + ( + checkpoint.checkpoint_id, + checkpoint.workflow_name, + checkpoint_data, + checkpoint.timestamp, + ), + ) + await conn.commit() + + logger.debug("Saved checkpoint %s", checkpoint.checkpoint_id) + return checkpoint.checkpoint_id + + async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: + """Load a checkpoint by ID.""" + conn = await self._storage._get_conn() + async with self._storage._lock: + cursor = await conn.execute( + "SELECT checkpoint_data FROM checkpoints WHERE checkpoint_id = ?", + (checkpoint_id,), + ) + row = await cursor.fetchone() + + if not row: + from agent_framework._workflows._checkpoint import ( + WorkflowCheckpointException, + ) + + raise WorkflowCheckpointException( + f"No checkpoint found with ID {checkpoint_id}" + ) + + encoded = json.loads(row[0]) + decoded = decode_checkpoint_value(encoded) + return WorkflowCheckpoint.from_dict(decoded) + + async def list_checkpoints( + self, *, workflow_name: str + ) -> list[WorkflowCheckpoint]: + """List checkpoint objects for a given workflow name.""" + conn = await self._storage._get_conn() + async with self._storage._lock: + cursor = await conn.execute( + "SELECT checkpoint_data FROM checkpoints WHERE workflow_name = ? 
ORDER BY timestamp ASC", + (workflow_name,), + ) + rows = await cursor.fetchall() + + checkpoints = [] + for row in rows: + encoded = json.loads(row[0]) + decoded = decode_checkpoint_value(encoded) + checkpoints.append(WorkflowCheckpoint.from_dict(decoded)) + return checkpoints + + async def delete(self, checkpoint_id: str) -> bool: + """Delete a checkpoint by ID.""" + conn = await self._storage._get_conn() + async with self._storage._lock: + cursor = await conn.execute( + "DELETE FROM checkpoints WHERE checkpoint_id = ?", + (checkpoint_id,), + ) + await conn.commit() + return cursor.rowcount > 0 + + async def get_latest( + self, *, workflow_name: str + ) -> WorkflowCheckpoint | None: + """Get the latest checkpoint for a given workflow name.""" + conn = await self._storage._get_conn() + async with self._storage._lock: + cursor = await conn.execute( + "SELECT checkpoint_data FROM checkpoints WHERE workflow_name = ? ORDER BY timestamp DESC LIMIT 1", + (workflow_name,), + ) + row = await cursor.fetchone() + + if not row: + return None + + encoded = json.loads(row[0]) + decoded = decode_checkpoint_value(encoded) + return WorkflowCheckpoint.from_dict(decoded) + + async def list_checkpoint_ids( + self, *, workflow_name: str + ) -> list[str]: + """List checkpoint IDs for a given workflow name.""" + conn = await self._storage._get_conn() + async with self._storage._lock: + cursor = await conn.execute( + "SELECT checkpoint_id FROM checkpoints WHERE workflow_name = ? ORDER BY timestamp ASC", + (workflow_name,), + ) + rows = await cursor.fetchall() + + return [row[0] for row in rows] + + +class ScopedCheckpointStorage: + """Thin wrapper that prefixes workflow_name with a runtime scope. + + When multiple runtimes share the same SqliteResumableStorage, this + ensures checkpoint isolation by prefixing workflow_name queries with + ``{runtime_id}::``. 
+ """ + + def __init__( + self, delegate: SqliteCheckpointStorage, runtime_id: str + ) -> None: + self._delegate = delegate + self._scope = f"{runtime_id}::" + + def _scoped_name(self, workflow_name: str) -> str: + return self._scope + workflow_name + + async def save(self, checkpoint: WorkflowCheckpoint) -> str: + """Save with scoped workflow_name.""" + checkpoint.workflow_name = self._scoped_name(checkpoint.workflow_name) + return await self._delegate.save(checkpoint) + + async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: + """Load by checkpoint_id (globally unique).""" + return await self._delegate.load(checkpoint_id) + + async def list_checkpoints( + self, *, workflow_name: str + ) -> list[WorkflowCheckpoint]: + """List checkpoints with scoped workflow_name.""" + return await self._delegate.list_checkpoints( + workflow_name=self._scoped_name(workflow_name) + ) + + async def delete(self, checkpoint_id: str) -> bool: + """Delete by checkpoint_id (globally unique).""" + return await self._delegate.delete(checkpoint_id) + + async def get_latest( + self, *, workflow_name: str + ) -> WorkflowCheckpoint | None: + """Get latest checkpoint with scoped workflow_name.""" + return await self._delegate.get_latest( + workflow_name=self._scoped_name(workflow_name) + ) + + async def list_checkpoint_ids( + self, *, workflow_name: str + ) -> list[str]: + """List checkpoint IDs with scoped workflow_name.""" + return await self._delegate.list_checkpoint_ids( + workflow_name=self._scoped_name(workflow_name) + ) + + +__all__ = [ + "SqliteResumableStorage", + "SqliteCheckpointStorage", + "ScopedCheckpointStorage", +] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index 41979302..0a960e10 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ 
b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -5,15 +5,16 @@ from uuid import uuid4 from agent_framework import ( + AgentExecutor, AgentResponse, AgentResponseUpdate, AgentSession, - BaseAgent, Content, - FunctionTool, Message, WorkflowAgent, + WorkflowRunResult, ) +from pydantic import BaseModel from uipath.core.serialization import serialize_json from uipath.runtime import ( UiPathExecuteOptions, @@ -30,50 +31,16 @@ ) from uipath.runtime.schema import UiPathRuntimeSchema +from .breakpoints import ( + create_breakpoint_result, + inject_breakpoint_middleware, + remove_breakpoint_middleware, +) from .errors import UiPathAgentFrameworkErrorCode, UiPathAgentFrameworkRuntimeError +from .interrupt import AgentInterruptException from .messages import AgentFrameworkChatMessagesMapper -from .schema import ( - extract_agent_from_tool, - get_agent_graph, - get_agent_tools, - get_entrypoints_schema, -) -from .storage import SqliteSessionStore - - -class _StreamState: - """Mutable state tracker for agent streaming. - - Holds the sub-agent metadata (computed once) and the active node - state that changes as function_call / function_result events arrive. - """ - - __slots__ = ( - "root_agent", - "active_agent", - "active_tools", - "call_ids", - "agent_tool_names", - "tool_name_to_agent", - "sub_agents_with_tools", - ) - - def __init__( - self, - root_agent: str, - agent_tool_names: set[str], - tool_name_to_agent: dict[str, str], - sub_agents_with_tools: set[str], - ) -> None: - self.root_agent = root_agent - self.active_agent: str = root_agent - self.active_tools: str | None = None - # call_id → sub-agent name (content.name on function_result - # may be empty for as_tool() wrappers, so we match by call_id). 
- self.call_ids: dict[str, str] = {} - self.agent_tool_names = agent_tool_names - self.tool_name_to_agent = tool_name_to_agent - self.sub_agents_with_tools = sub_agents_with_tools +from .resumable_storage import ScopedCheckpointStorage, SqliteResumableStorage +from .schema import get_agent_graph, get_entrypoints_schema class UiPathAgentFrameworkRuntime: @@ -81,72 +48,106 @@ class UiPathAgentFrameworkRuntime: def __init__( self, - agent: BaseAgent, + agent: WorkflowAgent, runtime_id: str | None = None, entrypoint: str | None = None, - session_store: SqliteSessionStore | None = None, + checkpoint_storage: ScopedCheckpointStorage | None = None, + resumable_storage: SqliteResumableStorage | None = None, ): - self.agent: BaseAgent = agent + self.agent: WorkflowAgent = agent self.runtime_id: str = runtime_id or "default" self.entrypoint: str | None = entrypoint self.chat = AgentFrameworkChatMessagesMapper() - self._session_store = session_store + self._checkpoint_storage = checkpoint_storage + self._resumable_storage = resumable_storage + self._resume_responses: dict[str, Any] | None = None + self._breakpoint_skip_nodes: set[str] = set() + self._last_checkpoint_id: str | None = None # ------------------------------------------------------------------ - # Sub-agent introspection + # Checkpoint helpers # ------------------------------------------------------------------ - @staticmethod - def _build_sub_agent_info( - agent: BaseAgent, - ) -> tuple[set[str], dict[str, str], set[str]]: - """Inspect the agent's tools once to extract all sub-agent metadata. 
- - Returns: - agent_tool_names: tool names that wrap sub-agents - tool_name_to_agent: mapping from tool name → sub-agent node name - sub_agents_with_tools: sub-agent names that own tools - """ - agent_tool_names: set[str] = set() - tool_name_to_agent: dict[str, str] = {} - sub_agents_with_tools: set[str] = set() - - for tool in get_agent_tools(agent): - inner_agent = extract_agent_from_tool(tool) - if inner_agent is None or not isinstance(tool, FunctionTool): - continue + async def _get_latest_checkpoint_id(self) -> str | None: + """Get the latest checkpoint ID for this workflow.""" + if not self._checkpoint_storage: + return None + workflow_name = self.agent.workflow.name + checkpoint = await self._checkpoint_storage.get_latest( + workflow_name=workflow_name + ) + return checkpoint.checkpoint_id if checkpoint else None - agent_tool_names.add(tool.name) - inner_name = inner_agent.name or "agent" - tool_name_to_agent[tool.name] = inner_name + async def _save_breakpoint_state(self, original_input: str) -> None: + """Persist breakpoint state to KV storage for resume. - if get_agent_tools(inner_agent): - sub_agents_with_tools.add(inner_name) + The skip_nodes set accumulates across resumes so that concurrent + executors breakpointed in the same superstep are all skipped on + subsequent resumes (prevents the infinite-cycle bug). 
+ """ + if not self._resumable_storage: + return + checkpoint_id = await self._get_latest_checkpoint_id() + state = { + "skip_nodes": list(self._breakpoint_skip_nodes), + "checkpoint_id": checkpoint_id, + "original_input": original_input, + } + await self._resumable_storage.set_value( + self.runtime_id, "breakpoint", "state", state + ) - return agent_tool_names, tool_name_to_agent, sub_agents_with_tools + async def _load_breakpoint_state(self) -> dict[str, Any] | None: + """Load breakpoint state from KV storage.""" + if not self._resumable_storage: + return None + state = await self._resumable_storage.get_value( + self.runtime_id, "breakpoint", "state" + ) + if state and isinstance(state, dict): + self._breakpoint_skip_nodes = set(state.get("skip_nodes", [])) + self._last_checkpoint_id = state.get("checkpoint_id") + return state + return None # ------------------------------------------------------------------ - # Session helpers + # Session helpers (multi-turn conversation history) # ------------------------------------------------------------------ async def _load_session(self) -> AgentSession: - """Load or create an AgentSession for this runtime_id.""" - if self._session_store: - session_data = await self._session_store.load_session(self.runtime_id) - if session_data is not None: - return AgentSession.from_dict(session_data) # type: ignore[attr-defined] + """Load or create an AgentSession for this runtime_id. - return self.agent.create_session(session_id=self.runtime_id) # type: ignore[attr-defined] + Sessions maintain conversation history across turns. This is separate + from checkpoints which handle workflow interruption/resume. 
+ """ + if self._resumable_storage: + session_data = await self._resumable_storage.get_value( + self.runtime_id, "session", "data" + ) + if session_data is not None and isinstance(session_data, dict): + return AgentSession.from_dict(session_data) + + return self.agent.create_session(session_id=self.runtime_id) async def _save_session(self, session: AgentSession) -> None: """Persist the session state after execution.""" - if self._session_store: - session_data = session.to_dict() # type: ignore[attr-defined] - await self._session_store.save_session(self.runtime_id, session_data) + if self._resumable_storage: + session_data = session.to_dict() + await self._resumable_storage.set_value( + self.runtime_id, "session", "data", session_data + ) - # ------------------------------------------------------------------ - # Execute (non-streaming) - # ------------------------------------------------------------------ + def _apply_session_to_executors(self, session: AgentSession) -> None: + """Propagate the loaded session to all AgentExecutors in the workflow. + + Each AgentExecutor uses a unique source_id key inside session.state, + so sharing one session across all executors is safe and ensures + conversation history is preserved across turns. 
+ """ + workflow = self.agent.workflow + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor): + executor._session = session async def execute( self, @@ -154,57 +155,159 @@ async def execute( options: UiPathExecuteOptions | None = None, ) -> UiPathRuntimeResult: """Execute the agent with the provided input and return the result.""" + session = None try: - user_input = self._prepare_input(input) - session = await self._load_session() - response = await self.agent.run(user_input, session=session) # type: ignore[attr-defined] - await self._save_session(session) - output = self._extract_output(response) + is_resuming = bool(options and options.resume) + + workflow = self.agent.workflow + + if is_resuming and input is not None: + # HITL resume: checkpoint restores executor state (including session) + self._resume_responses = input + + # Inject breakpoints (no skip needed for HITL resume) + if options and options.breakpoints: + inject_breakpoint_middleware(self.agent, options.breakpoints) + + if self._resume_responses: + checkpoint_id = await self._get_latest_checkpoint_id() + result = await workflow.run( + responses=self._resume_responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=self._checkpoint_storage, + ) + self._resume_responses = None + else: + result = await workflow.run( + message="", + checkpoint_storage=self._checkpoint_storage, + ) + elif is_resuming: + # Breakpoint resume: restore from checkpoint + bp_state = await self._load_breakpoint_state() + checkpoint_id = self._last_checkpoint_id + original_input = bp_state.get("original_input", "") if bp_state else "" + + # Inject breakpoints, skipping all previously-resumed executors + if options and options.breakpoints: + inject_breakpoint_middleware( + self.agent, options.breakpoints, self._breakpoint_skip_nodes + ) + + if checkpoint_id: + result = await workflow.run( + checkpoint_id=checkpoint_id, + checkpoint_storage=self._checkpoint_storage, + ) + else: + result = 
await workflow.run( + message=original_input, + checkpoint_storage=self._checkpoint_storage, + ) + else: + # Fresh run: load session for multi-turn conversation history + session = await self._load_session() + self._apply_session_to_executors(session) + + # Inject breakpoints for fresh runs + if options and options.breakpoints: + inject_breakpoint_middleware(self.agent, options.breakpoints) + + user_input = self._prepare_input(input) + result = await workflow.run( + message=user_input, + checkpoint_storage=self._checkpoint_storage, + ) + + if session is not None: + await self._save_session(session) + output = self._extract_workflow_output(result) return self._create_success_result(output) + except AgentInterruptException as e: + if session is not None: + await self._save_session(session) + if e.is_breakpoint: + node_id = ( + e.suspend_value.get("node_id", "") + if isinstance(e.suspend_value, dict) + else "" + ) + self._breakpoint_skip_nodes.add(node_id) + original_input = self._prepare_input(input) if not is_resuming else "" + await self._save_breakpoint_state(original_input) + return create_breakpoint_result(e) + return self._create_suspended_result(e) except Exception as e: raise self._create_runtime_error(e) from e - - # ------------------------------------------------------------------ - # Stream (main entry) - # ------------------------------------------------------------------ + finally: + remove_breakpoint_middleware(self.agent) async def stream( self, input: dict[str, Any] | None = None, options: UiPathStreamOptions | None = None, ) -> AsyncGenerator[UiPathRuntimeEvent, None]: - """Stream agent execution events in real-time. - - Two streaming paths: - - WorkflowAgent: raw workflow events (executor_invoked/completed). - - Regular BaseAgent: function_call/function_result content tracking. 
- """ + """Stream workflow execution events in real-time.""" try: - user_input = self._prepare_input(input) - session = await self._load_session() - agent_name = self.agent.name or "agent" + is_resuming = bool(options and options.resume) + session = None + + if is_resuming and input is not None: + # HITL resume: input contains response data + self._resume_responses = input + user_input = self._prepare_input(None) + + # Inject breakpoints (no skip needed for HITL resume) + if options and options.breakpoints: + inject_breakpoint_middleware(self.agent, options.breakpoints) + + elif is_resuming: + # Breakpoint resume: restore original_input and session + self._resume_responses = None + bp_state = await self._load_breakpoint_state() + user_input = bp_state.get("original_input", "") if bp_state else "" + + # Load session for context preservation across the breakpoint + session = await self._load_session() + self._apply_session_to_executors(session) + + # Inject breakpoints, skipping all previously-resumed executors + if options and options.breakpoints: + inject_breakpoint_middleware( + self.agent, options.breakpoints, self._breakpoint_skip_nodes + ) - if isinstance(self.agent, WorkflowAgent): - async for event in self._stream_workflow( - user_input, session, agent_name - ): - yield event else: - async for event in self._stream_agent(user_input, session, agent_name): - yield event + # Fresh run + self._resume_responses = None + user_input = self._prepare_input(input) + + # Load session for multi-turn conversation history + session = await self._load_session() + self._apply_session_to_executors(session) + + # Inject breakpoints for fresh runs + if options and options.breakpoints: + inject_breakpoint_middleware(self.agent, options.breakpoints) + + agent_name = self.agent.name or "agent" + + async for event in self._stream_workflow( + user_input, agent_name, is_resuming, session + ): + yield event except Exception as e: raise self._create_runtime_error(e) from e - - # 
------------------------------------------------------------------ - # Workflow streaming - # ------------------------------------------------------------------ + finally: + remove_breakpoint_middleware(self.agent) async def _stream_workflow( self, user_input: str, - session: AgentSession, agent_name: str, + is_resuming: bool = False, + session: AgentSession | None = None, ) -> AsyncGenerator[UiPathRuntimeEvent, None]: """Stream workflow execution with real-time executor lifecycle events.""" assert isinstance(self.agent, WorkflowAgent) @@ -216,289 +319,148 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) - response_stream = workflow.run(message=user_input, stream=True) - - async for event in response_stream: - if event.type == "executor_invoked": - yield UiPathRuntimeStateEvent( - payload=self._serialize_event_data(event.data), - node_name=event.executor_id, - phase=UiPathRuntimeStatePhase.STARTED, - ) - elif event.type == "executor_completed": - yield UiPathRuntimeStateEvent( - payload=self._serialize_event_data(event.data), - node_name=event.executor_id, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - elif event.type == "output": - for msg_event in self._extract_workflow_messages(event.data): - yield UiPathRuntimeMessageEvent(payload=msg_event) - - yield UiPathRuntimeStateEvent( - payload={}, - node_name=agent_name, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - - for msg_event in self.chat.close_message(): - yield UiPathRuntimeMessageEvent(payload=msg_event) - - await self._save_session(session) - - final_result = await response_stream.get_final_response() - output = self._extract_workflow_output(final_result) - yield self._create_success_result(output) - - # ------------------------------------------------------------------ - # Agent streaming - # ------------------------------------------------------------------ - - async def _stream_agent( - self, - user_input: str, - session: AgentSession, - agent_name: str, - ) -> 
AsyncGenerator[UiPathRuntimeEvent, None]: - """Stream regular BaseAgent execution with tool/sub-agent tracking.""" - state = _StreamState(agent_name, *self._build_sub_agent_info(self.agent)) - - yield UiPathRuntimeStateEvent( - payload={}, - node_name=agent_name, - phase=UiPathRuntimeStatePhase.STARTED, - ) - - response_stream = self.agent.run(user_input, stream=True, session=session) # type: ignore[attr-defined] - async for update in response_stream: - if not isinstance(update, AgentResponseUpdate): - continue - - for content in update.contents or []: - if not isinstance(content, Content): - continue - - for event in self._process_agent_content(state, content): - yield event + # Choose workflow.run() mode based on resume type + if self._resume_responses: + # HITL resume: pass responses to workflow with checkpoint + checkpoint_id = await self._get_latest_checkpoint_id() + response_stream = workflow.run( + responses=self._resume_responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=self._checkpoint_storage, + stream=True, + ) + self._resume_responses = None + elif self._last_checkpoint_id: + # Breakpoint resume with checkpoint: restore and continue + checkpoint_id = self._last_checkpoint_id + self._last_checkpoint_id = None + response_stream = workflow.run( + checkpoint_id=checkpoint_id, + checkpoint_storage=self._checkpoint_storage, + stream=True, + ) + else: + # Fresh run (or breakpoint resume without checkpoint — uses original_input) + response_stream = workflow.run( + message=user_input, + checkpoint_storage=self._checkpoint_storage, + stream=True, + ) - for msg in self.chat.map_streaming_content(content): - yield UiPathRuntimeMessageEvent(payload=msg) + request_info_map: dict[str, Any] = {} + is_suspended = False + + # Emit an early STARTED event for the start executor so the graph + # visualization shows it immediately rather than after it finishes. 
+ # The framework's _run_workflow_with_tracing awaits the entire start + # executor before yielding any executor events, which means the real + # executor_invoked arrives only after execution completes. + pre_emitted_executor: str | None = None + if not is_resuming: + start_id = workflow.start_executor_id + yield UiPathRuntimeStateEvent( + payload={}, + node_name=start_id, + phase=UiPathRuntimeStatePhase.STARTED, + ) + pre_emitted_executor = start_id - # Teardown: close remaining nodes - if state.active_tools: + try: + async for event in response_stream: + if event.type == "request_info": + request_info_map[event.request_id] = event.data + elif event.type == "executor_invoked": + # Skip the duplicate for the start executor we already emitted + if pre_emitted_executor and event.executor_id == pre_emitted_executor: + pre_emitted_executor = None + continue + yield UiPathRuntimeStateEvent( + payload=self._serialize_event_data(event.data), + node_name=event.executor_id, + phase=UiPathRuntimeStatePhase.STARTED, + ) + elif event.type == "executor_completed": + yield UiPathRuntimeStateEvent( + payload=self._serialize_event_data( + self._filter_completed_data(event.data) + ), + node_name=event.executor_id, + phase=UiPathRuntimeStatePhase.COMPLETED, + ) + elif event.type == "output": + executor_id = getattr(event, "executor_id", None) or "" + for tool_event in self._extract_tool_state_events( + event.data, executor_id + ): + yield tool_event + for msg_event in self._extract_workflow_messages(event.data): + yield UiPathRuntimeMessageEvent(payload=msg_event) + + # Detect workflow suspension via state + if event.type == "status" and str(event.state) == "IDLE_WITH_PENDING_REQUESTS": + is_suspended = True + except AgentInterruptException as e: + # Breakpoint or HITL interrupt fired inside an inner agent yield UiPathRuntimeStateEvent( payload={}, - node_name=state.active_tools, + node_name=agent_name, phase=UiPathRuntimeStatePhase.COMPLETED, ) + for msg_event in 
self.chat.close_message(): + yield UiPathRuntimeMessageEvent(payload=msg_event) + + if session is not None: + await self._save_session(session) + + if e.is_breakpoint: + node_id = ( + e.suspend_value.get("node_id", "") + if isinstance(e.suspend_value, dict) + else "" + ) + self._breakpoint_skip_nodes.add(node_id) + await self._save_breakpoint_state(user_input) + yield create_breakpoint_result(e) + else: + yield self._create_suspended_result(e) + return + yield UiPathRuntimeStateEvent( payload={}, node_name=agent_name, phase=UiPathRuntimeStatePhase.COMPLETED, ) - for msg in self.chat.close_message(): - yield UiPathRuntimeMessageEvent(payload=msg) - - await self._save_session(session) - - final_response = await response_stream.get_final_response() - yield self._create_success_result(self._extract_output(final_response)) - - # ------------------------------------------------------------------ - # Agent content event handlers - # ------------------------------------------------------------------ - - def _process_agent_content( - self, s: _StreamState, content: Content - ) -> list[UiPathRuntimeStateEvent]: - """Dispatch a streaming Content to the appropriate handler.""" - if content.type == "function_call": - if not content.name: - return [] - if content.name in s.agent_tool_names: - return self._on_sub_agent_call(s, content) - return self._on_tool_call(s, content) - - if content.type == "function_result": - return self._on_function_result(s, content) - - return [] - - def _on_sub_agent_call( - self, s: _StreamState, content: Content - ) -> list[UiPathRuntimeStateEvent]: - """Handle a function_call that invokes a sub-agent via as_tool().""" - call_name = content.name or "" - sub_agent = s.tool_name_to_agent.get(call_name, call_name) - events: list[UiPathRuntimeStateEvent] = [] - - if content.call_id: - s.call_ids[content.call_id] = sub_agent - - # Close any active tools node - if s.active_tools: - events.append( - UiPathRuntimeStateEvent( - payload={}, - 
node_name=s.active_tools, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - ) - s.active_tools = None + for msg_event in self.chat.close_message(): + yield UiPathRuntimeMessageEvent(payload=msg_event) - payload = {"function_name": call_name} + if session is not None: + await self._save_session(session) - # Start sub-agent node - events.append( - UiPathRuntimeStateEvent( - payload=payload, - node_name=sub_agent, - phase=UiPathRuntimeStatePhase.STARTED, - ) - ) - s.active_agent = sub_agent - - # Sub-agent's internal tool calls are opaque in the as_tool() - # stream — emit a synthetic STARTED on its tools node. - if sub_agent in s.sub_agents_with_tools: - tools_node = f"{sub_agent}_tools" - events.append( - UiPathRuntimeStateEvent( - payload=payload, - node_name=tools_node, - phase=UiPathRuntimeStatePhase.STARTED, - ) + if is_suspended and request_info_map: + yield UiPathRuntimeResult( + output=request_info_map, + status=UiPathRuntimeStatus.SUSPENDED, ) - s.active_tools = tools_node else: - events.append( - UiPathRuntimeStateEvent( - payload=payload, - node_name=sub_agent, - metadata={"event_type": "function_call"}, - ) - ) - - return events - - def _on_tool_call( - self, s: _StreamState, content: Content - ) -> list[UiPathRuntimeStateEvent]: - """Handle a regular (non-agent) function_call.""" - call_name = content.name or "" - tools_node = f"{s.active_agent}_tools" - events: list[UiPathRuntimeStateEvent] = [] - - if s.active_tools != tools_node: - if s.active_tools: - events.append( - UiPathRuntimeStateEvent( - payload={}, - node_name=s.active_tools, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - ) - s.active_tools = tools_node - events.append( - UiPathRuntimeStateEvent( - payload={}, - node_name=tools_node, - phase=UiPathRuntimeStatePhase.STARTED, - ) - ) - - events.append( - UiPathRuntimeStateEvent( - payload={"function_name": call_name}, - node_name=tools_node, - metadata={"event_type": "function_call"}, - ) - ) - return events - - def _on_function_result( - 
self, s: _StreamState, content: Content - ) -> list[UiPathRuntimeStateEvent]: - """Handle a function_result for either a sub-agent or regular tool.""" - call_id = content.call_id or "" - result_name = content.name or "" - events: list[UiPathRuntimeStateEvent] = [] - - # Match sub-agent by call_id first (reliable), fall back to name - matched = s.call_ids.pop(call_id, None) - if matched is None and result_name in s.agent_tool_names: - matched = s.tool_name_to_agent.get(result_name, result_name) - - result_payload = self._build_result_payload(content) - - if matched: - # Sub-agent completed — close tools, then agent, re-start root - if s.active_tools and s.active_tools == f"{matched}_tools": - events.append( - UiPathRuntimeStateEvent( - payload=result_payload, - node_name=s.active_tools, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - ) - s.active_tools = None - - events.append( - UiPathRuntimeStateEvent( - payload=result_payload, - node_name=matched, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - ) - s.active_agent = s.root_agent - events.append( - UiPathRuntimeStateEvent( - payload={}, - node_name=s.root_agent, - phase=UiPathRuntimeStatePhase.STARTED, - ) - ) - elif s.active_tools: - # Regular tool completed - events.append( - UiPathRuntimeStateEvent( - payload=result_payload, - node_name=s.active_tools, - phase=UiPathRuntimeStatePhase.COMPLETED, - ) - ) - s.active_tools = None - if s.active_agent: - events.append( - UiPathRuntimeStateEvent( - payload={}, - node_name=s.active_agent, - phase=UiPathRuntimeStatePhase.STARTED, - ) - ) - - return events - - # ------------------------------------------------------------------ - # Payload / serialization helpers - # ------------------------------------------------------------------ + final_result = await response_stream.get_final_response() + output = self._extract_workflow_output(final_result) + yield self._create_success_result(output) @staticmethod - def _build_result_payload(content: Content) -> dict[str, Any]: - 
"""Build a payload dict from a function_result Content.""" - payload: dict[str, Any] = {} - if content.name: - payload["function_name"] = content.name - if content.result is not None: - try: - payload["function_response"] = json.loads( - serialize_json(content.result) - ) - except Exception: - payload["function_response"] = str(content.result) - return payload + def _filter_completed_data(data: Any) -> Any: + """Strip streaming AgentResponseUpdate chunks from executor_completed data. + + The framework packs sent_messages + yielded_outputs into the + executor_completed event. In streaming mode the yielded_outputs are + individual AgentResponseUpdate token chunks which bloat the payload. + Keep only the non-update items (e.g. AgentExecutorResponse). + """ + if not isinstance(data, list): + return data + filtered = [item for item in data if not isinstance(item, AgentResponseUpdate)] + return filtered if filtered else None @staticmethod def _serialize_event_data(data: Any) -> dict[str, Any]: @@ -513,9 +475,56 @@ def _serialize_event_data(data: Any) -> dict[str, Any]: except Exception: return {"data": str(data)} - # ------------------------------------------------------------------ - # Workflow message / output extraction - # ------------------------------------------------------------------ + @staticmethod + def _extract_tool_state_events( + data: Any, executor_id: str + ) -> list[UiPathRuntimeStateEvent]: + """Extract tool-node state events from output data containing function calls/results. + + Looks for Content objects with type 'function_call' (tool start) and + 'function_result' (tool end) and emits STARTED/COMPLETED StateEvents + for the '{executor_id}_tools' node. 
+ """ + contents: list[Any] = [] + + if isinstance(data, AgentResponseUpdate): + contents = list(data.contents or []) + elif isinstance(data, AgentResponse): + for message in data.messages or []: + contents.extend(message.contents or []) + elif isinstance(data, Message): + contents = list(data.contents or []) + elif isinstance(data, list): + events: list[UiPathRuntimeStateEvent] = [] + for item in data: + events.extend( + UiPathAgentFrameworkRuntime._extract_tool_state_events( + item, executor_id + ) + ) + return events + + tool_node = f"{executor_id}_tools" + tool_events: list[UiPathRuntimeStateEvent] = [] + for content in contents: + if isinstance(content, Content): + if content.type == "function_call" and content.name: + tool_events.append( + UiPathRuntimeStateEvent( + payload={"tool_name": content.name}, + node_name=tool_node, + phase=UiPathRuntimeStatePhase.STARTED, + ) + ) + elif content.type == "function_result": + tool_events.append( + UiPathRuntimeStateEvent( + payload={}, + node_name=tool_node, + phase=UiPathRuntimeStatePhase.COMPLETED, + ) + ) + return tool_events def _extract_workflow_messages(self, data: Any) -> list[Any]: """Extract UiPath conversation message events from workflow output data.""" @@ -540,11 +549,9 @@ def _extract_workflow_messages(self, data: Any) -> list[Any]: return events - def _extract_workflow_output(self, result: Any) -> Any: + def _extract_workflow_output(self, result: WorkflowRunResult) -> Any: """Extract output from WorkflowRunResult.""" - outputs: list[Any] = [] - if hasattr(result, "get_outputs"): - outputs = result.get_outputs() + outputs = result.get_outputs() if not outputs: return "" @@ -600,10 +607,6 @@ def _extract_text_from_data(data: Any) -> str: return "\n\n".join(parts) return "" - # ------------------------------------------------------------------ - # Input / output / result helpers - # ------------------------------------------------------------------ - def _prepare_input(self, input: dict[str, Any] | None) -> 
str: """Prepare input string from UiPath input dictionary.""" if not input: @@ -614,12 +617,6 @@ def _prepare_input(self, input: dict[str, Any] | None) -> str: return json.dumps(input) - def _extract_output(self, response: AgentResponse) -> Any: - """Extract output from agent response.""" - if response.text: - return response.text - return str(response) if response else "" - def _create_success_result(self, output: Any) -> UiPathRuntimeResult: """Create result for successful completion.""" serialized_output = json.loads(serialize_json(output)) @@ -639,11 +636,28 @@ def _create_success_result(self, output: Any) -> UiPathRuntimeResult: status=UiPathRuntimeStatus.SUCCESSFUL, ) + def _create_suspended_result( + self, exc: AgentInterruptException + ) -> UiPathRuntimeResult: + """Create a SUSPENDED result from an AgentInterruptException.""" + interrupt_value = exc.suspend_value + if isinstance(interrupt_value, BaseModel): + interrupt_value = interrupt_value.model_dump(by_alias=True) + + return UiPathRuntimeResult( + output={exc.interrupt_id: interrupt_value}, + status=UiPathRuntimeStatus.SUSPENDED, + ) + def _create_runtime_error(self, e: Exception) -> UiPathAgentFrameworkRuntimeError: """Handle execution errors and create appropriate runtime error.""" if isinstance(e, UiPathAgentFrameworkRuntimeError): return e + # Let AgentInterruptException propagate (handled by caller) + if isinstance(e, AgentInterruptException): + raise e + detail = f"Error: {str(e)}" if isinstance(e, json.JSONDecodeError): @@ -669,10 +683,6 @@ def _create_runtime_error(self, e: Exception) -> UiPathAgentFrameworkRuntimeErro UiPathErrorCategory.USER, ) - # ------------------------------------------------------------------ - # Schema - # ------------------------------------------------------------------ - async def get_schema(self) -> UiPathRuntimeSchema: """Get schema for this Agent Framework runtime.""" entrypoints_schema = get_entrypoints_schema(self.agent) diff --git 
a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py index f46f02b9..03324976 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py @@ -85,22 +85,24 @@ def _default_messages_schema() -> dict[str, Any]: def get_agent_graph(agent: BaseAgent) -> UiPathRuntimeGraph: """Extract graph structure from an Agent Framework agent. - Handles two cases: - 1. WorkflowAgent (from orchestrations): extracts the underlying Workflow's - executors and edge_groups to build a proper multi-agent graph. - 2. Regular BaseAgent: traverses the agent tree, inspecting tools for - agent-as-tool instances (created via BaseAgent.as_tool()). + Only WorkflowAgent is supported. Extracts the underlying Workflow's + executors and edge_groups to build a proper multi-agent graph. Args: - agent: An Agent Framework BaseAgent instance + agent: An Agent Framework BaseAgent instance (must be WorkflowAgent) Returns: UiPathRuntimeGraph with nodes and edges representing the agent structure + + Raises: + TypeError: If agent is not a WorkflowAgent """ if isinstance(agent, WorkflowAgent): return _build_workflow_graph(agent.workflow) - return _build_agent_graph(agent) + raise TypeError( + f"Only WorkflowAgent is supported for graph extraction, got {type(agent).__name__}" + ) def _build_workflow_graph(workflow: Workflow) -> UiPathRuntimeGraph: @@ -204,11 +206,7 @@ def _add_executor_tool_nodes( if not tools: return - regular_tools = [t for t in tools if extract_agent_from_tool(t) is None] - if not regular_tools: - return - - tool_names = [_get_tool_name(t) for t in regular_tools] + tool_names = [get_tool_name(t) for t in tools] tool_names = [n for n in tool_names if n] if tool_names: @@ -233,210 +231,6 @@ def _add_executor_tool_nodes( ) -def _build_agent_graph(agent: BaseAgent) -> 
UiPathRuntimeGraph: - """Build graph from a regular BaseAgent with tools. - - Traverses the agent tree, inspecting tools for agent-as-tool instances - (created via BaseAgent.as_tool()). For each agent-as-tool, creates a - separate node and recursively processes its own tools. - """ - nodes: list[UiPathRuntimeNode] = [] - edges: list[UiPathRuntimeEdge] = [] - visited: set[str] = set() - - def _add_agent_and_tools(current_agent: BaseAgent) -> None: - """Recursively add agent, its tools, and nested agents to the graph.""" - agent_name = current_agent.name or "agent" - - if agent_name in visited: - return - visited.add(agent_name) - - # Add agent node - nodes.append( - UiPathRuntimeNode( - id=agent_name, - name=agent_name, - type="node", - subgraph=None, - metadata=None, - ) - ) - - # Process tools: separate agent-as-tool from regular tools - _process_tools(current_agent, agent_name, nodes, edges, visited) - - # Add __start__ node - nodes.append( - UiPathRuntimeNode( - id="__start__", - name="__start__", - type="__start__", - subgraph=None, - metadata=None, - ) - ) - - _add_agent_and_tools(agent) - - agent_name = agent.name or "agent" - - # Add __end__ node - nodes.append( - UiPathRuntimeNode( - id="__end__", - name="__end__", - type="__end__", - subgraph=None, - metadata=None, - ) - ) - - # Connect start → agent → end - edges.append( - UiPathRuntimeEdge(source="__start__", target=agent_name, label="input") - ) - edges.append(UiPathRuntimeEdge(source=agent_name, target="__end__", label="output")) - - return UiPathRuntimeGraph(nodes=nodes, edges=edges) - - -def _process_tools( - agent: BaseAgent, - agent_name: str, - nodes: list[UiPathRuntimeNode], - edges: list[UiPathRuntimeEdge], - visited: set[str], -) -> None: - """Process an agent's tools list, separating agent-as-tools from regular tools.""" - tools = get_agent_tools(agent) - - agent_tools: list[tuple[str, BaseAgent]] = [] - regular_tools: list[Any] = [] - - for tool in tools: - inner_agent = 
extract_agent_from_tool(tool) - if inner_agent is not None: - tool_name = _get_tool_name(tool) or (inner_agent.name or "agent") - agent_tools.append((tool_name, inner_agent)) - else: - regular_tools.append(tool) - - # Agent-as-tool: add the wrapped agent as a node and recurse - for tool_name, tool_agent in agent_tools: - tool_agent_name = tool_agent.name or "agent" - if tool_agent_name not in visited: - # Recursively add the sub-agent and its own tools - _add_agent_node(tool_agent, nodes, edges, visited) - - edges.append( - UiPathRuntimeEdge( - source=agent_name, target=tool_agent_name, label=tool_name - ) - ) - edges.append( - UiPathRuntimeEdge(source=tool_agent_name, target=agent_name, label=None) - ) - - # Regular tools — aggregate into single tools node - if regular_tools: - tool_names = [_get_tool_name(t) for t in regular_tools] - tool_names = [n for n in tool_names if n] - - if tool_names: - tools_node_id = f"{agent_name}_tools" - nodes.append( - UiPathRuntimeNode( - id=tools_node_id, - name="tools", - type="tool", - subgraph=None, - metadata={ - "tool_names": tool_names, - "tool_count": len(tool_names), - }, - ) - ) - edges.append( - UiPathRuntimeEdge(source=agent_name, target=tools_node_id, label=None) - ) - edges.append( - UiPathRuntimeEdge(source=tools_node_id, target=agent_name, label=None) - ) - - -def _add_agent_node( - agent: BaseAgent, - nodes: list[UiPathRuntimeNode], - edges: list[UiPathRuntimeEdge], - visited: set[str], -) -> None: - """Add an agent node and recursively process its tools.""" - agent_name = agent.name or "agent" - - if agent_name in visited: - return - visited.add(agent_name) - - nodes.append( - UiPathRuntimeNode( - id=agent_name, - name=agent_name, - type="node", - subgraph=None, - metadata=None, - ) - ) - - _process_tools(agent, agent_name, nodes, edges, visited) - - -_extract_cache: dict[int, BaseAgent | None] = {} - - -def extract_agent_from_tool( - tool: FunctionTool | Callable[..., Any], -) -> BaseAgent | None: - """Extract a 
BaseAgent from a tool created via BaseAgent.as_tool(). - - The as_tool() method creates an async agent_wrapper closure that captures - `self` (the BaseAgent instance). We inspect the closure cells to find it. - Results are cached by tool identity to avoid repeated introspection. - """ - tool_id = id(tool) - if tool_id in _extract_cache: - return _extract_cache[tool_id] - - result = _extract_agent_from_closure(tool) - _extract_cache[tool_id] = result - return result - - -def _extract_agent_from_closure( - tool: FunctionTool | Callable[..., Any], -) -> BaseAgent | None: - if not isinstance(tool, FunctionTool): - return None - - func = getattr(tool, "func", None) - if func is None: - return None - - closure = getattr(func, "__closure__", None) - if not closure: - return None - - for cell in closure: - try: - content = cell.cell_contents - if isinstance(content, BaseAgent): - return content - except ValueError: - continue - - return None - - def get_agent_tools(agent: BaseAgent) -> list[Any]: """Extract tools list from an Agent Framework agent. @@ -445,7 +239,7 @@ def get_agent_tools(agent: BaseAgent) -> list[Any]: return getattr(agent, "default_options", {}).get("tools", []) -def _get_tool_name(tool: FunctionTool | Callable[..., Any]) -> str | None: +def get_tool_name(tool: FunctionTool | Callable[..., Any]) -> str | None: """Extract the name of a tool. Tools in Agent Framework are either FunctionTool instances or plain callables. 
@@ -460,6 +254,6 @@ def _get_tool_name(tool: FunctionTool | Callable[..., Any]) -> str | None: __all__ = [ "get_entrypoints_schema", "get_agent_graph", - "extract_agent_from_tool", "get_agent_tools", + "get_tool_name", ] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/storage.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/storage.py deleted file mode 100644 index f56bbb8b..00000000 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/storage.py +++ /dev/null @@ -1,116 +0,0 @@ -"""SQLite session store for Agent Framework agents. - -Persists AgentSession state between turns using SQLite, keyed by runtime_id. -Each runtime_id maps to an isolated session — conversation history accumulates -across calls via the InMemoryHistoryProvider that Agent Framework auto-injects. -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import os -from typing import Any - -import aiosqlite - -logger = logging.getLogger(__name__) - - -class SqliteSessionStore: - """SQLite-backed store for Agent Framework session state. - - Stores serialized AgentSession dicts (via to_dict/from_dict) in a single - table, keyed by runtime_id. Thread-safe via asyncio lock. 
- """ - - def __init__(self, db_path: str) -> None: - self.db_path = db_path - self._conn: aiosqlite.Connection | None = None - self._lock = asyncio.Lock() - self._initialized = False - - async def setup(self) -> None: - """Ensure storage directory and database table exist.""" - dir_name = os.path.dirname(self.db_path) - if dir_name: - os.makedirs(dir_name, exist_ok=True) - - conn = await self._get_conn() - async with self._lock: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS sessions ( - runtime_id TEXT PRIMARY KEY, - session_data TEXT NOT NULL - ) - """ - ) - await conn.commit() - self._initialized = True - logger.debug("Session store initialized at %s", self.db_path) - - async def _get_conn(self) -> aiosqlite.Connection: - """Get or create the database connection.""" - if self._conn is None: - self._conn = await aiosqlite.connect(self.db_path, timeout=30.0) - await self._conn.execute("PRAGMA journal_mode=WAL") - await self._conn.execute("PRAGMA busy_timeout=30000") - await self._conn.execute("PRAGMA synchronous=NORMAL") - await self._conn.commit() - return self._conn - - async def load_session(self, runtime_id: str) -> dict[str, Any] | None: - """Load a serialized session dict for the given runtime_id. - - Returns None if no session exists for this runtime_id. 
- """ - if not self._initialized: - await self.setup() - - conn = await self._get_conn() - async with self._lock: - cursor = await conn.execute( - "SELECT session_data FROM sessions WHERE runtime_id = ?", - (runtime_id,), - ) - row = await cursor.fetchone() - - if not row: - logger.debug("No session found for runtime_id=%s", runtime_id) - return None - - logger.debug("Loaded session for runtime_id=%s", runtime_id) - return json.loads(row[0]) - - async def save_session(self, runtime_id: str, session_data: dict[str, Any]) -> None: - """Save a serialized session dict for the given runtime_id.""" - if not self._initialized: - await self.setup() - - data_json = json.dumps(session_data) - conn = await self._get_conn() - async with self._lock: - await conn.execute( - """ - INSERT INTO sessions (runtime_id, session_data) - VALUES (?, ?) - ON CONFLICT(runtime_id) DO UPDATE SET - session_data = excluded.session_data - """, - (runtime_id, data_json), - ) - await conn.commit() - - logger.debug("Saved session for runtime_id=%s", runtime_id) - - async def dispose(self) -> None: - """Close the database connection.""" - if self._conn: - await self._conn.close() - self._conn = None - self._initialized = False - - -__all__ = ["SqliteSessionStore"] diff --git a/packages/uipath-agent-framework/tests/test_breakpoints.py b/packages/uipath-agent-framework/tests/test_breakpoints.py new file mode 100644 index 00000000..6c38e87a --- /dev/null +++ b/packages/uipath-agent-framework/tests/test_breakpoints.py @@ -0,0 +1,926 @@ +"""Tests for executor-level breakpoints. + +Verifies that breakpoints pause execution before the targeted executor runs +and that resume correctly skips the breakpointed executor. + +Includes integration tests with UiPathDebugRuntime to simulate the full +debug flow: breakpoints → pause → resume → continue. 
+""" + +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, Mock + +from agent_framework import RawAgent, WorkflowAgent, WorkflowBuilder +from uipath.runtime.debug import ( + UiPathBreakpointResult, + UiPathDebugProtocol, + UiPathDebugQuitError, + UiPathDebugRuntime, +) +from uipath.runtime.events import UiPathRuntimeStateEvent +from uipath.runtime.result import UiPathRuntimeResult, UiPathRuntimeStatus + +from uipath_agent_framework.runtime.breakpoints import ( + _resolve_to_executor_ids, + create_breakpoint_result, + inject_breakpoint_middleware, + remove_breakpoint_middleware, +) +from uipath_agent_framework.runtime.interrupt import AgentInterruptException +from uipath_agent_framework.runtime.runtime import UiPathAgentFrameworkRuntime + + +_mock_client = MagicMock() + + +class _MockWorkflowStream: + """Fake workflow stream for testing runtime orchestration without LLM. + + Simulates the async iterable returned by ``workflow.run(stream=True)`` + so we can control when breakpoint exceptions fire and what the final + response looks like. 
+ """ + + def __init__( + self, + events: list[Any] | None = None, + exception: Exception | None = None, + final_output: str = "done", + ): + self._events = events or [] + self._exception = exception + self._final_output = final_output + + def __aiter__(self): + return self._aiter_impl() + + async def _aiter_impl(self): + for event in self._events: + yield event + if self._exception: + raise self._exception + + async def get_final_response(self): + mock_result = MagicMock() + mock_result.get_outputs.return_value = [self._final_output] + return mock_result + + +def _make_debug_bridge(**overrides: Any) -> UiPathDebugProtocol: + """Create a mock debug bridge with sensible defaults.""" + bridge: Mock = Mock(spec=UiPathDebugProtocol) + bridge.connect = AsyncMock() + bridge.disconnect = AsyncMock() + bridge.emit_execution_started = AsyncMock() + bridge.emit_execution_completed = AsyncMock() + bridge.emit_execution_error = AsyncMock() + bridge.emit_execution_suspended = AsyncMock() + bridge.emit_breakpoint_hit = AsyncMock() + bridge.emit_state_update = AsyncMock() + bridge.emit_execution_resumed = AsyncMock() + bridge.wait_for_resume = AsyncMock(return_value=None) + bridge.wait_for_terminate = AsyncMock() + bridge.get_breakpoints = Mock(return_value=[]) + for k, v in overrides.items(): + setattr(bridge, k, v) + return cast(UiPathDebugProtocol, bridge) + + +def calculator(expression: str) -> str: + """Evaluate a math expression.""" + return str(eval(expression)) + + +def search_web(query: str) -> str: + """Search the web.""" + return f"Results for: {query}" + + +# --------------------------------------------------------------------------- +# Resolution tests +# --------------------------------------------------------------------------- + + +class TestResolveBreakpoints: + """Verify graph node IDs are correctly resolved to executor IDs.""" + + def test_wildcard_resolves_to_all_executors(self): + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, 
name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, "*") + assert result == set(workflow.executors.keys()) + + def test_executor_id_resolves_directly(self): + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["agent_a"]) + assert result == {"agent_a"} + + def test_tools_suffix_resolves_to_parent_executor(self): + a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["agent_a_tools"]) + assert result == {"agent_a"} + + def test_tool_name_resolves_to_owning_executor(self): + a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) + b = RawAgent(_mock_client, name="agent_b", tools=[search_web]) + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["calculator"]) + assert result == {"agent_a"} + + result = _resolve_to_executor_ids(agent, ["search_web"]) + assert result == {"agent_b"} + + def test_wildcard_in_list_resolves_to_all(self): + """Wildcard passed as ["*"] (list) also resolves to all executors.""" + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["*"]) + assert result == set(workflow.executors.keys()) + + def test_unknown_node_id_ignored(self): + a = RawAgent(_mock_client, name="agent_a") + workflow = 
WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["nonexistent"]) + assert result == set() + + def test_mixed_breakpoints(self): + a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) + b = RawAgent(_mock_client, name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + result = _resolve_to_executor_ids(agent, ["agent_b", "calculator"]) + assert result == {"agent_a", "agent_b"} + + +# --------------------------------------------------------------------------- +# Injection tests +# --------------------------------------------------------------------------- + + +class TestInjectBreakpoints: + """Verify executor wrapping for breakpoint injection.""" + + def test_inject_wraps_executor_execute(self): + """Injecting breakpoints replaces executor.execute with a wrapper.""" + a = RawAgent(_mock_client, name="agent_a") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + original = workflow.executors["agent_a"].execute + inject_breakpoint_middleware(agent, ["agent_a"]) + + assert workflow.executors["agent_a"].execute is not original + assert hasattr(workflow.executors["agent_a"], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + + def test_remove_restores_original_execute(self): + """Removing breakpoints restores the original execute method.""" + a = RawAgent(_mock_client, name="agent_a") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware(agent, ["agent_a"]) + assert hasattr(workflow.executors["agent_a"], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + assert not hasattr(workflow.executors["agent_a"], "_bp_original_execute") + + async def test_wrapped_execute_raises_interrupt(self): + """Wrapped executor 
raises AgentInterruptException on execute.""" + a = RawAgent(_mock_client, name="agent_a") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware(agent, ["agent_a"]) + + executor = workflow.executors["agent_a"] + try: + await executor.execute("msg", [], MagicMock(), MagicMock()) + assert False, "Should have raised AgentInterruptException" + except AgentInterruptException as e: + assert e.is_breakpoint is True + assert e.suspend_value["type"] == "breakpoint" + assert e.suspend_value["node_id"] == "agent_a" + finally: + remove_breakpoint_middleware(agent) + + def test_skip_nodes_not_wrapped(self): + """Executors in skip_nodes should not be wrapped (resume scenario).""" + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware( + agent, ["agent_a", "agent_b"], skip_nodes={"agent_a"} + ) + + # agent_a should NOT be wrapped (it's in skip_nodes) + assert not hasattr(workflow.executors["agent_a"], "_bp_original_execute") + # agent_b SHOULD be wrapped + assert hasattr(workflow.executors["agent_b"], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + + def test_skip_nodes_multiple(self): + """Multiple skip_nodes are all excluded from wrapping.""" + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, name="agent_b") + c = RawAgent(_mock_client, name="agent_c") + workflow = ( + WorkflowBuilder(start_executor=a) + .add_edge(a, b) + .add_edge(a, c) + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware( + agent, "*", skip_nodes={"agent_a", "agent_b"} + ) + + assert not hasattr(workflow.executors["agent_a"], "_bp_original_execute") + assert not hasattr(workflow.executors["agent_b"], "_bp_original_execute") + assert 
hasattr(workflow.executors["agent_c"], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + + def test_no_double_wrap(self): + """Calling inject twice doesn't double-wrap executors.""" + a = RawAgent(_mock_client, name="agent_a") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware(agent, ["agent_a"]) + first_wrapped = workflow.executors["agent_a"].execute + + inject_breakpoint_middleware(agent, ["agent_a"]) + second_wrapped = workflow.executors["agent_a"].execute + + # Should be the same wrapper, not double-wrapped + assert first_wrapped is second_wrapped + + remove_breakpoint_middleware(agent) + + def test_wildcard_wraps_all_executors(self): + """Wildcard breakpoint wraps every executor.""" + a = RawAgent(_mock_client, name="agent_a") + b = RawAgent(_mock_client, name="agent_b") + workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware(agent, "*") + + for exec_id in workflow.executors: + assert hasattr(workflow.executors[exec_id], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + + def test_agents_without_tools_can_be_breakpointed(self): + """Executors with no tools (pure chat agents) can be breakpointed.""" + # This was the original bug: pure chat agents had no tools, + # so FunctionMiddleware never fired. 
+ a = RawAgent(_mock_client, name="chat_agent") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + inject_breakpoint_middleware(agent, ["chat_agent"]) + assert hasattr(workflow.executors["chat_agent"], "_bp_original_execute") + + remove_breakpoint_middleware(agent) + + +# --------------------------------------------------------------------------- +# Result creation tests +# --------------------------------------------------------------------------- + + +class TestBreakpointResult: + """Verify breakpoint result creation.""" + + def test_create_breakpoint_result_with_node_id(self): + exc = AgentInterruptException( + interrupt_id="int-1", + suspend_value={"type": "breakpoint", "node_id": "my_agent"}, + is_breakpoint=True, + ) + result = create_breakpoint_result(exc) + assert isinstance(result, UiPathBreakpointResult) + assert result.breakpoint_node == "my_agent" + assert result.breakpoint_type == "before" + assert result.next_nodes == ["my_agent"] + + def test_create_breakpoint_result_empty_suspend_value(self): + exc = AgentInterruptException( + interrupt_id="int-2", + suspend_value="unexpected", + is_breakpoint=True, + ) + result = create_breakpoint_result(exc) + assert result.breakpoint_node == "" + assert result.next_nodes == [] + + +# --------------------------------------------------------------------------- +# Integration tests: UiPathDebugRuntime ← UiPathAgentFrameworkRuntime +# --------------------------------------------------------------------------- + + +def _make_agent_runtime(agent: WorkflowAgent) -> UiPathAgentFrameworkRuntime: + """Create a runtime with mocked chat mapper (no LLM needed).""" + runtime = UiPathAgentFrameworkRuntime(agent=agent) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "hello" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + return runtime + + +class 
TestDebugRuntimeBreakpointIntegration: + """Integration tests: UiPathDebugRuntime wrapping our runtime with real workflows. + + These tests verify the full breakpoint flow that the debug UI exercises: + debug bridge sends breakpoint node IDs → UiPathDebugRuntime passes them + as options.breakpoints → our runtime wraps executor.execute() → workflow + runs → wrapped executor raises → our runtime yields UiPathBreakpointResult + → UiPathDebugRuntime sees it and notifies the bridge. + """ + + async def test_breakpoint_fires_on_start_executor(self): + """Breakpoint on the start executor pauses before it runs.""" + worker = RawAgent(_mock_client, name="worker") + workflow = WorkflowBuilder(start_executor=worker).build() + agent = WorkflowAgent(workflow=workflow, name="test_wf") + + runtime = _make_agent_runtime(agent) + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = ["worker"] + # Initial resume + quit after breakpoint (don't try to actually resume) + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, + UiPathDebugQuitError("quit"), + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + + result = await debug_runtime.execute({"messages": []}) + + # Breakpoint should have been hit + cast(AsyncMock, bridge.emit_breakpoint_hit).assert_awaited_once() + bp_result = cast(AsyncMock, bridge.emit_breakpoint_hit).call_args[0][0] + assert isinstance(bp_result, UiPathBreakpointResult) + assert bp_result.breakpoint_node == "worker" + + # Quit produces a successful result + assert result.status == UiPathRuntimeStatus.SUCCESSFUL + + async def test_breakpoint_fires_on_toolless_agent(self): + """Breakpoint works on agents with no tools (the original bug).""" + # This is the concurrent sample scenario: pure chat agents, no tools + chat_agent = RawAgent(_mock_client, name="sentiment") + workflow = WorkflowBuilder(start_executor=chat_agent).build() + agent = WorkflowAgent(workflow=workflow, 
name="concurrent_wf") + + runtime = _make_agent_runtime(agent) + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = ["sentiment"] + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, + UiPathDebugQuitError("quit"), + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + + result = await debug_runtime.execute({"messages": []}) + + cast(AsyncMock, bridge.emit_breakpoint_hit).assert_awaited_once() + bp_result = cast(AsyncMock, bridge.emit_breakpoint_hit).call_args[0][0] + assert bp_result.breakpoint_node == "sentiment" + + async def test_state_events_emitted_before_breakpoint(self): + """Debug bridge should receive state events (STARTED) before the breakpoint.""" + worker = RawAgent(_mock_client, name="agent_x") + workflow = WorkflowBuilder(start_executor=worker).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + runtime = _make_agent_runtime(agent) + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = ["agent_x"] + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, + UiPathDebugQuitError("quit"), + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + + # Collect all events from the stream + events: list[Any] = [] + async for event in debug_runtime.stream({"messages": []}): + events.append(event) + + # Should have state events (STARTED for wf and agent_x) before the breakpoint + state_events = [e for e in events if isinstance(e, UiPathRuntimeStateEvent)] + bp_events = [e for e in events if isinstance(e, UiPathBreakpointResult)] + + assert len(state_events) >= 1, "Should have at least one state event" + assert len(bp_events) >= 1, "Should have a breakpoint result" + + # The workflow STARTED should come before the breakpoint + wf_started = [e for e in state_events if e.node_name == "wf"] + assert len(wf_started) >= 1, "Workflow STARTED event should be emitted" + + async def 
test_no_breakpoints_runs_to_completion(self): + """With no breakpoints set, the workflow should run normally (or fail + trying to call LLM, but not hit any breakpoint).""" + worker = RawAgent(_mock_client, name="worker") + workflow = WorkflowBuilder(start_executor=worker).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + runtime = _make_agent_runtime(agent) + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = [] # no breakpoints + cast(AsyncMock, bridge.wait_for_resume).return_value = None + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + + # Without breakpoints, the workflow tries to actually execute the agent. + # Since we have a mock client, this will fail — but NOT as a breakpoint. + try: + await debug_runtime.execute({"messages": []}) + except Exception: + pass # Expected — no LLM to call + + # No breakpoint should have been hit + cast(AsyncMock, bridge.emit_breakpoint_hit).assert_not_awaited() + + async def test_breakpoint_resume_preserves_original_input_and_session(self): + """After breakpoint → continue, the original user input and session + must be restored so the agent doesn't lose context. + + This was a real bug: on resume, stream() was passing message="" to + workflow.run() and not loading the session, so the agent acted like + it never received the user's message. 
+ """ + worker = RawAgent(_mock_client, name="weather_agent") + workflow = WorkflowBuilder(start_executor=worker).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + # Track what workflow.run() receives on each call + captured_run_kwargs: list[dict[str, Any]] = [] + original_run = workflow.run + + def tracking_run(**kwargs: Any) -> Any: + captured_run_kwargs.append(kwargs) + return original_run(**kwargs) + + workflow.run = tracking_run # type: ignore[assignment] + + # Mock resumable storage for session + breakpoint state persistence + kv_store: dict[str, Any] = {} + + mock_storage = AsyncMock() + + async def mock_set_value( + runtime_id: str, namespace: str, key: str, value: Any + ) -> None: + kv_store[f"{runtime_id}:{namespace}:{key}"] = value + + async def mock_get_value( + runtime_id: str, namespace: str, key: str + ) -> Any: + return kv_store.get(f"{runtime_id}:{namespace}:{key}") + + mock_storage.set_value = mock_set_value + mock_storage.get_value = mock_get_value + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-bp-resume", + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = ( + "what's the weather in San Francisco" + ) + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = ["weather_agent"] + # Initial resume → None; after breakpoint → resume (None = continue) + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, # initial resume + None, # continue after breakpoint + UiPathDebugQuitError("quit"), # quit after second run fails (no LLM) + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + + # Execute — first run hits breakpoint, resume continues + try: + await debug_runtime.execute({"messages": []}) + except Exception: + pass # Expected: no real LLM on the resume run + + # 
First call: fresh run that hit the breakpoint + assert len(captured_run_kwargs) >= 1 + first_call = captured_run_kwargs[0] + assert first_call.get("message") == "what's the weather in San Francisco" + assert first_call.get("stream") is True + + # Second call: breakpoint resume — must use original input, NOT "" + assert len(captured_run_kwargs) >= 2 + resume_call = captured_run_kwargs[1] + assert resume_call.get("message") == "what's the weather in San Francisco" + assert resume_call.get("stream") is True + + # Breakpoint state should have been persisted with original input + bp_state = kv_store.get("test-bp-resume:breakpoint:state") + assert bp_state is not None + assert bp_state["original_input"] == "what's the weather in San Francisco" + assert "weather_agent" in bp_state["skip_nodes"] + + # Session should have been saved + session_data = kv_store.get("test-bp-resume:session:data") + assert session_data is not None + + async def test_two_sequential_breakpoints_with_resumes(self): + """Two agents (agent_a → agent_b), both breakpointed. 
Verifies: + - Fresh run: BP fires on agent_a, both executors wrapped + - Resume 1: agent_a skipped, BP fires on agent_b + - Resume 2: both skipped, completes normally + """ + agent_a = RawAgent(_mock_client, name="agent_a") + agent_b = RawAgent(_mock_client, name="agent_b") + workflow = ( + WorkflowBuilder(start_executor=agent_a) + .add_edge(agent_a, agent_b) + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="wf") + + call_log: list[dict[str, Any]] = [] + call_count = 0 + checkpoint_counter = [0] + + def mock_run(**kwargs: Any) -> _MockWorkflowStream: + nonlocal call_count + call_count += 1 + call_log.append({ + "call_number": call_count, + "kwargs": dict(kwargs), + "agent_a_wrapped": hasattr( + workflow.executors["agent_a"], "_bp_original_execute" + ), + "agent_b_wrapped": hasattr( + workflow.executors["agent_b"], "_bp_original_execute" + ), + }) + if call_count == 1: + # Fresh run → breakpoint on agent_a + checkpoint_counter[0] = 1 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-1", + suspend_value={ + "type": "breakpoint", + "node_id": "agent_a", + }, + is_breakpoint=True, + ) + ) + elif call_count == 2: + # Resume 1 → breakpoint on agent_b + checkpoint_counter[0] = 2 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-2", + suspend_value={ + "type": "breakpoint", + "node_id": "agent_b", + }, + is_breakpoint=True, + ) + ) + else: + return _MockWorkflowStream(final_output="completed") + + workflow.run = mock_run # type: ignore[assignment] + + # KV store + kv_store: dict[str, Any] = {} + mock_storage = AsyncMock() + + async def mock_set_value( + runtime_id: str, namespace: str, key: str, value: Any + ) -> None: + kv_store[f"{runtime_id}:{namespace}:{key}"] = value + + async def mock_get_value( + runtime_id: str, namespace: str, key: str + ) -> Any: + return kv_store.get(f"{runtime_id}:{namespace}:{key}") + + mock_storage.set_value = mock_set_value + mock_storage.get_value = 
mock_get_value + + # Mock checkpoint storage + mock_cs = AsyncMock() + + async def mock_get_latest(**kwargs: Any) -> Any: + if checkpoint_counter[0] == 0: + return None + cp = MagicMock() + cp.checkpoint_id = f"cp-{checkpoint_counter[0]}" + return cp + + mock_cs.get_latest = mock_get_latest + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-2bp", + checkpoint_storage=mock_cs, + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "hello world" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = ["agent_a", "agent_b"] + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, # initial resume + None, # continue after BP on agent_a + None, # continue after BP on agent_b + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + result = await debug_runtime.execute({"messages": []}) + + # --- 3 workflow.run() calls --- + assert len(call_log) == 3 + + # Call 1: Fresh run — both executors wrapped + c1 = call_log[0] + assert c1["kwargs"].get("message") == "hello world" + assert c1["kwargs"].get("stream") is True + assert c1["agent_a_wrapped"] is True + assert c1["agent_b_wrapped"] is True + + # Call 2: Resume after agent_a breakpoint — checkpoint-based + c2 = call_log[1] + assert c2["kwargs"].get("checkpoint_id") == "cp-1" + assert c2["kwargs"].get("stream") is True + assert "message" not in c2["kwargs"] + assert c2["agent_a_wrapped"] is False # Skipped + assert c2["agent_b_wrapped"] is True # Still breakpointed + + # Call 3: Resume after agent_b breakpoint — checkpoint-based + c3 = call_log[2] + assert c3["kwargs"].get("checkpoint_id") == "cp-2" + assert c3["kwargs"].get("stream") is True + assert "message" not in c3["kwargs"] + assert c3["agent_a_wrapped"] is False # Still skipped (accumulated) + assert 
c3["agent_b_wrapped"] is False # Now also skipped + + # --- Debug bridge interactions --- + assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 2 + bp_calls = cast( + AsyncMock, bridge.emit_breakpoint_hit + ).call_args_list + assert bp_calls[0].args[0].breakpoint_node == "agent_a" + assert bp_calls[1].args[0].breakpoint_node == "agent_b" + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL + + # --- KV state --- + bp_state = kv_store.get("test-2bp:breakpoint:state") + assert bp_state is not None + assert set(bp_state["skip_nodes"]) == {"agent_a", "agent_b"} + assert bp_state["checkpoint_id"] == "cp-2" + assert bp_state["original_input"] == "hello world" + + async def test_concurrent_breakpoints_accumulate_skip_nodes(self): + """Concurrent executors: breakpoints fire one-at-a-time, skip_nodes + accumulate so previously-resumed executors aren't re-breakpointed. + + Graph: dispatcher → [worker_a, worker_b] (fan-out) + With breakpoints on all nodes: + - Fresh: BP on dispatcher + - Resume 1: dispatcher skipped, BP on worker_a + - Resume 2: dispatcher+worker_a skipped, BP on worker_b + - Resume 3: all skipped, completes + + This was the infinite loop bug: without accumulating skip_nodes, + worker_a and worker_b kept trading breakpoints forever. 
+ """ + dispatcher = RawAgent(_mock_client, name="dispatcher") + worker_a = RawAgent(_mock_client, name="worker_a") + worker_b = RawAgent(_mock_client, name="worker_b") + workflow = ( + WorkflowBuilder(start_executor=dispatcher) + .add_edge(dispatcher, worker_a) + .add_edge(dispatcher, worker_b) + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="concurrent_wf") + + call_log: list[dict[str, Any]] = [] + call_count = 0 + checkpoint_counter = [0] + + def mock_run(**kwargs: Any) -> _MockWorkflowStream: + nonlocal call_count + call_count += 1 + call_log.append({ + "call_number": call_count, + "kwargs": dict(kwargs), + "dispatcher_wrapped": hasattr( + workflow.executors["dispatcher"], "_bp_original_execute" + ), + "worker_a_wrapped": hasattr( + workflow.executors["worker_a"], "_bp_original_execute" + ), + "worker_b_wrapped": hasattr( + workflow.executors["worker_b"], "_bp_original_execute" + ), + }) + if call_count == 1: + checkpoint_counter[0] = 1 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-1", + suspend_value={ + "type": "breakpoint", + "node_id": "dispatcher", + }, + is_breakpoint=True, + ) + ) + elif call_count == 2: + # After dispatcher skipped, worker_a fires + checkpoint_counter[0] = 2 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-2", + suspend_value={ + "type": "breakpoint", + "node_id": "worker_a", + }, + is_breakpoint=True, + ) + ) + elif call_count == 3: + # After dispatcher+worker_a skipped, worker_b fires + checkpoint_counter[0] = 3 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-3", + suspend_value={ + "type": "breakpoint", + "node_id": "worker_b", + }, + is_breakpoint=True, + ) + ) + else: + # All skipped, completes + return _MockWorkflowStream(final_output="done") + + workflow.run = mock_run # type: ignore[assignment] + + kv_store: dict[str, Any] = {} + mock_storage = AsyncMock() + + async def mock_set_value( + runtime_id: 
str, namespace: str, key: str, value: Any + ) -> None: + kv_store[f"{runtime_id}:{namespace}:{key}"] = value + + async def mock_get_value( + runtime_id: str, namespace: str, key: str + ) -> Any: + return kv_store.get(f"{runtime_id}:{namespace}:{key}") + + mock_storage.set_value = mock_set_value + mock_storage.get_value = mock_get_value + + mock_cs = AsyncMock() + + async def mock_get_latest(**kwargs: Any) -> Any: + if checkpoint_counter[0] == 0: + return None + cp = MagicMock() + cp.checkpoint_id = f"cp-{checkpoint_counter[0]}" + return cp + + mock_cs.get_latest = mock_get_latest + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-concurrent-bp", + checkpoint_storage=mock_cs, + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "analyze this text" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = [ + "dispatcher", "worker_a", "worker_b", + ] + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, # initial + None, # continue after dispatcher BP + None, # continue after worker_a BP + None, # continue after worker_b BP + ] + + debug_runtime = UiPathDebugRuntime( + delegate=runtime, debug_bridge=bridge + ) + result = await debug_runtime.execute({"messages": []}) + + # --- 4 workflow.run() calls --- + assert len(call_log) == 4 + + # Call 1: Fresh — all wrapped + c1 = call_log[0] + assert c1["dispatcher_wrapped"] is True + assert c1["worker_a_wrapped"] is True + assert c1["worker_b_wrapped"] is True + + # Call 2: dispatcher skipped, workers wrapped + c2 = call_log[1] + assert c2["dispatcher_wrapped"] is False + assert c2["worker_a_wrapped"] is True + assert c2["worker_b_wrapped"] is True + + # Call 3: dispatcher+worker_a skipped, worker_b wrapped + c3 = call_log[2] + assert c3["dispatcher_wrapped"] is False + assert 
c3["worker_a_wrapped"] is False + assert c3["worker_b_wrapped"] is True + + # Call 4: all skipped — completes + c4 = call_log[3] + assert c4["dispatcher_wrapped"] is False + assert c4["worker_a_wrapped"] is False + assert c4["worker_b_wrapped"] is False + + # 3 breakpoints hit + assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 3 + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast( + AsyncMock, bridge.emit_breakpoint_hit + ).call_args_list + ] + assert bp_nodes == ["dispatcher", "worker_a", "worker_b"] + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL + + # skip_nodes accumulated all three + bp_state = kv_store.get("test-concurrent-bp:breakpoint:state") + assert bp_state is not None + assert set(bp_state["skip_nodes"]) == { + "dispatcher", "worker_a", "worker_b", + } diff --git a/packages/uipath-agent-framework/tests/test_graph.py b/packages/uipath-agent-framework/tests/test_graph.py index 976c94f7..7096523e 100644 --- a/packages/uipath-agent-framework/tests/test_graph.py +++ b/packages/uipath-agent-framework/tests/test_graph.py @@ -23,103 +23,6 @@ def _make_agent(name="test_agent", tools=None) -> BaseAgent: return agent -class TestGraphStructure: - """Tests for graph structure correctness.""" - - def test_start_and_end_nodes_present(self): - """Graph always has __start__ and __end__ nodes.""" - agent = _make_agent(name="my_agent") - graph = get_agent_graph(agent) - - node_types = {n.type for n in graph.nodes} - assert "__start__" in node_types - assert "__end__" in node_types - - def test_agent_node_type_is_node(self): - """Agent nodes have type 'node'.""" - agent = _make_agent(name="my_agent") - graph = get_agent_graph(agent) - - agent_node = next(n for n in graph.nodes if n.id == "my_agent") - assert agent_node.type == "node" - - def test_tools_node_type_is_tool(self): - """Tools nodes have type 'tool'.""" - - def search(): - pass - - search.__name__ = "search" - tool = search - - agent = _make_agent(name="my_agent", 
tools=[tool]) - graph = get_agent_graph(agent) - - tools_node = next(n for n in graph.nodes if n.id == "my_agent_tools") - assert tools_node.type == "tool" - - def test_start_connects_to_agent(self): - """__start__ connects to the agent with 'input' label.""" - agent = _make_agent(name="my_agent") - graph = get_agent_graph(agent) - - start_edge = next(e for e in graph.edges if e.source == "__start__") - assert start_edge.target == "my_agent" - assert start_edge.label == "input" - - def test_agent_connects_to_end(self): - """Agent connects to __end__ with 'output' label.""" - agent = _make_agent(name="my_agent") - graph = get_agent_graph(agent) - - end_edge = next(e for e in graph.edges if e.target == "__end__") - assert end_edge.source == "my_agent" - assert end_edge.label == "output" - - def test_tools_metadata_contains_names(self): - """Tools node metadata includes tool names and count.""" - - def search(): - pass - - search.__name__ = "search" - - def calculator(): - pass - - calculator.__name__ = "calculator" - tool1 = search - tool2 = calculator - - agent = _make_agent(name="agent", tools=[tool1, tool2]) - graph = get_agent_graph(agent) - - tools_node = next(n for n in graph.nodes if n.id == "agent_tools") - assert tools_node.metadata is not None - assert tools_node.metadata["tool_names"] == ["search", "calculator"] - assert tools_node.metadata["tool_count"] == 2 - - def test_agent_name_fallback(self): - """Agent without name falls back to 'agent'.""" - agent = MagicMock(spec=BaseAgent) - agent.name = None - agent.default_options = {"tools": []} - - graph = get_agent_graph(agent) - node_ids = [n.id for n in graph.nodes] - assert "agent" in node_ids - - def test_no_subgraph_or_metadata_on_simple_nodes(self): - """Simple agent/start/end nodes have no subgraph or metadata.""" - agent = _make_agent(name="test") - graph = get_agent_graph(agent) - - for node in graph.nodes: - assert node.subgraph is None - if node.type != "tool": - assert node.metadata is None - - 
def _make_edge( source_id: str, target_id: str, condition_name: str | None = None ) -> Edge: diff --git a/packages/uipath-agent-framework/tests/test_schema.py b/packages/uipath-agent-framework/tests/test_schema.py index 84e42eb4..0c95d2bc 100644 --- a/packages/uipath-agent-framework/tests/test_schema.py +++ b/packages/uipath-agent-framework/tests/test_schema.py @@ -5,7 +5,6 @@ from agent_framework import BaseAgent from uipath_agent_framework.runtime.schema import ( - get_agent_graph, get_entrypoints_schema, ) @@ -76,95 +75,3 @@ def test_input_output_schema_match(self): schema = get_entrypoints_schema(agent) assert schema["input"] == schema["output"] - - -class TestGetAgentGraph: - """Tests for get_agent_graph function.""" - - def test_single_agent_graph(self): - """Test graph for a single agent with no tools.""" - agent = _make_agent(name="root") - graph = get_agent_graph(agent) - - node_ids = [n.id for n in graph.nodes] - assert "__start__" in node_ids - assert "__end__" in node_ids - assert "root" in node_ids - - # Check start->root and root->end edges - edge_pairs = [(e.source, e.target) for e in graph.edges] - assert ("__start__", "root") in edge_pairs - assert ("root", "__end__") in edge_pairs - - def test_agent_with_tools(self): - """Test graph for agent with regular tools.""" - - def search(): - pass - - search.__name__ = "search" - - def calculator(): - pass - - calculator.__name__ = "calculator" - - tool1 = search - tool2 = calculator - - agent = _make_agent(name="root", tools=[tool1, tool2]) - graph = get_agent_graph(agent) - - node_ids = [n.id for n in graph.nodes] - assert "root_tools" in node_ids - - # Find tools node and check metadata - tools_node = next(n for n in graph.nodes if n.id == "root_tools") - assert tools_node.type == "tool" - assert tools_node.metadata is not None - assert tools_node.metadata["tool_count"] == 2 - assert "search" in tools_node.metadata["tool_names"] - assert "calculator" in tools_node.metadata["tool_names"] - - def 
test_agent_without_tools_has_no_tools_node(self): - """Agent without tools has no tools node.""" - agent = _make_agent(name="root", tools=[]) - graph = get_agent_graph(agent) - - node_ids = [n.id for n in graph.nodes] - assert "root_tools" not in node_ids - - def test_graph_edges_are_bidirectional_for_tools(self): - """Tools node has bidirectional edges with agent.""" - - def my_tool(): - pass - - my_tool.__name__ = "my_tool" - tool = my_tool - - agent = _make_agent(name="root", tools=[tool]) - graph = get_agent_graph(agent) - - edge_pairs = [(e.source, e.target) for e in graph.edges] - assert ("root", "root_tools") in edge_pairs - assert ("root_tools", "root") in edge_pairs - - def test_graph_has_correct_node_count_no_tools(self): - """Graph with no tools has 3 nodes: __start__, agent, __end__.""" - agent = _make_agent(name="root") - graph = get_agent_graph(agent) - assert len(graph.nodes) == 3 - - def test_graph_has_correct_node_count_with_tools(self): - """Graph with tools has 4 nodes: __start__, agent, tools, __end__.""" - - def my_tool(): - pass - - my_tool.__name__ = "my_tool" - tool = my_tool - - agent = _make_agent(name="root", tools=[tool]) - graph = get_agent_graph(agent) - assert len(graph.nodes) == 4 diff --git a/packages/uipath-agent-framework/tests/test_storage.py b/packages/uipath-agent-framework/tests/test_storage.py index 72a46590..a5acac30 100644 --- a/packages/uipath-agent-framework/tests/test_storage.py +++ b/packages/uipath-agent-framework/tests/test_storage.py @@ -1,170 +1,361 @@ -"""Tests for SQLite session store.""" +"""Tests for SQLite checkpoint storage.""" import os import tempfile -from uipath_agent_framework.runtime.storage import SqliteSessionStore - - -class TestSqliteSessionStore: - """Tests for SqliteSessionStore.""" +from agent_framework import WorkflowCheckpoint + +from uipath_agent_framework.runtime.resumable_storage import ( + ScopedCheckpointStorage, + SqliteCheckpointStorage, + SqliteResumableStorage, +) + + +def 
_make_checkpoint( + workflow_name: str = "test_workflow", + checkpoint_id: str | None = None, + timestamp: str | None = None, +) -> WorkflowCheckpoint: + """Create a WorkflowCheckpoint for testing.""" + return WorkflowCheckpoint( + workflow_name=workflow_name, + graph_signature_hash="abc123", + checkpoint_id=checkpoint_id or f"cp-{id(object())}", + timestamp=timestamp or "2026-01-01T00:00:00+00:00", + state={"key": "value"}, + messages={}, + pending_request_info_events={}, + iteration_count=1, + metadata={"test": True}, + ) + + +class TestSqliteCheckpointStorage: + """Tests for SqliteCheckpointStorage via SqliteResumableStorage.""" async def test_setup_creates_db_file(self): - """Setup creates the SQLite database file.""" + """Setup creates the SQLite database file with checkpoints table.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() + storage = SqliteResumableStorage(db_path) + await storage.setup() assert os.path.exists(db_path) - await store.dispose() + assert storage.checkpoint_storage is not None + await storage.dispose() async def test_setup_creates_nested_directories(self): """Setup creates parent directories if they don't exist.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "nested", "dir", "test.db") - store = SqliteSessionStore(db_path) - await store.setup() + storage = SqliteResumableStorage(db_path) + await storage.setup() assert os.path.exists(db_path) - await store.dispose() + await storage.dispose() - async def test_load_returns_none_for_missing_session(self): - """Loading a non-existent session returns None.""" + async def test_save_and_load_checkpoint(self): + """Saved checkpoint can be loaded back.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() - - result = await store.load_session("nonexistent") - assert 
result is None - await store.dispose() - - async def test_save_and_load_session(self): - """Saved session data can be loaded back.""" + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + checkpoint = _make_checkpoint(checkpoint_id="cp-1") + await cs.save(checkpoint) + loaded = await cs.load("cp-1") + + assert loaded.checkpoint_id == "cp-1" + assert loaded.workflow_name == "test_workflow" + assert loaded.graph_signature_hash == "abc123" + assert loaded.state == {"key": "value"} + assert loaded.metadata == {"test": True} + await storage.dispose() + + async def test_load_nonexistent_raises(self): + """Loading a non-existent checkpoint raises an exception.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() - - session_data = { - "session_id": "runtime-123", - "state": {"memory": {"messages": [{"role": "user", "content": "hi"}]}}, - } - await store.save_session("runtime-123", session_data) - loaded = await store.load_session("runtime-123") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + try: + await cs.load("nonexistent") + assert False, "Should have raised" + except Exception: + pass + await storage.dispose() + + async def test_get_latest_returns_most_recent(self): + """get_latest returns the checkpoint with the most recent timestamp.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + cp1 = _make_checkpoint( + checkpoint_id="cp-old", timestamp="2026-01-01T00:00:00+00:00" + ) + cp2 = _make_checkpoint( + checkpoint_id="cp-new", timestamp="2026-01-02T00:00:00+00:00" + ) + await cs.save(cp1) + await cs.save(cp2) + + latest = await cs.get_latest(workflow_name="test_workflow") + assert latest is not None + assert 
latest.checkpoint_id == "cp-new" + await storage.dispose() + + async def test_get_latest_returns_none_for_empty(self): + """get_latest returns None when no checkpoints exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage - assert loaded == session_data - await store.dispose() + latest = await cs.get_latest(workflow_name="nonexistent") + assert latest is None + await storage.dispose() - async def test_save_overwrites_existing_session(self): - """Saving with the same runtime_id overwrites the previous data.""" + async def test_delete_checkpoint(self): + """Deleted checkpoint is no longer loadable.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + cp = _make_checkpoint(checkpoint_id="cp-del") + await cs.save(cp) - await store.save_session("rt-1", {"version": 1}) - await store.save_session("rt-1", {"version": 2}) - loaded = await store.load_session("rt-1") + result = await cs.delete("cp-del") + assert result is True - assert loaded == {"version": 2} - await store.dispose() + # Verify it's gone + result = await cs.delete("cp-del") + assert result is False + await storage.dispose() - async def test_sessions_isolated_by_runtime_id(self): - """Different runtime_ids have independent sessions.""" + async def test_list_checkpoints_filtered_by_workflow_name(self): + """list_checkpoints only returns checkpoints for the given workflow.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + cp1 = _make_checkpoint( + workflow_name="wf_a", checkpoint_id="cp-a1" + ) + cp2 = 
_make_checkpoint( + workflow_name="wf_a", checkpoint_id="cp-a2" + ) + cp3 = _make_checkpoint( + workflow_name="wf_b", checkpoint_id="cp-b1" + ) + await cs.save(cp1) + await cs.save(cp2) + await cs.save(cp3) + + wf_a_cps = await cs.list_checkpoints(workflow_name="wf_a") + assert len(wf_a_cps) == 2 + assert {cp.checkpoint_id for cp in wf_a_cps} == {"cp-a1", "cp-a2"} + + wf_b_cps = await cs.list_checkpoints(workflow_name="wf_b") + assert len(wf_b_cps) == 1 + assert wf_b_cps[0].checkpoint_id == "cp-b1" + await storage.dispose() + + async def test_list_checkpoint_ids(self): + """list_checkpoint_ids returns IDs filtered by workflow_name.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + cp1 = _make_checkpoint( + workflow_name="wf_x", checkpoint_id="cp-x1" + ) + cp2 = _make_checkpoint( + workflow_name="wf_x", checkpoint_id="cp-x2" + ) + cp3 = _make_checkpoint( + workflow_name="wf_y", checkpoint_id="cp-y1" + ) + await cs.save(cp1) + await cs.save(cp2) + await cs.save(cp3) + + ids = await cs.list_checkpoint_ids(workflow_name="wf_x") + assert set(ids) == {"cp-x1", "cp-x2"} + await storage.dispose() + + async def test_save_overwrites_existing_checkpoint(self): + """Saving with the same checkpoint_id overwrites the previous data.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage - await store.save_session("rt-a", {"agent": "alpha"}) - await store.save_session("rt-b", {"agent": "beta"}) + cp1 = _make_checkpoint(checkpoint_id="cp-ow") + cp1.state = {"version": 1} + await cs.save(cp1) - assert await store.load_session("rt-a") == {"agent": "alpha"} - assert await store.load_session("rt-b") == {"agent": "beta"} - assert await 
store.load_session("rt-c") is None - await store.dispose() + cp2 = _make_checkpoint(checkpoint_id="cp-ow") + cp2.state = {"version": 2} + await cs.save(cp2) + + loaded = await cs.load("cp-ow") + assert loaded.state == {"version": 2} + await storage.dispose() async def test_dispose_allows_reconnect(self): """After dispose, the store can be set up again and data persists.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() - await store.save_session("rt-1", {"data": "persisted"}) - await store.dispose() + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + cp = _make_checkpoint(checkpoint_id="cp-persist") + await cs.save(cp) + await storage.dispose() # Reconnect to the same DB - store2 = SqliteSessionStore(db_path) - await store2.setup() - loaded = await store2.load_session("rt-1") + storage2 = SqliteResumableStorage(db_path) + await storage2.setup() + cs2 = storage2.checkpoint_storage - assert loaded == {"data": "persisted"} - await store2.dispose() + loaded = await cs2.load("cp-persist") + assert loaded.checkpoint_id == "cp-persist" + await storage2.dispose() - async def test_auto_setup_on_load(self): - """Loading without explicit setup triggers auto-setup.""" - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - # No explicit setup() call - result = await store.load_session("any-id") - assert result is None - await store.dispose() +class TestScopedCheckpointStorage: + """Tests for ScopedCheckpointStorage prefix isolation.""" - async def test_auto_setup_on_save(self): - """Saving without explicit setup triggers auto-setup.""" + async def test_scoped_storage_isolates_by_runtime_id(self): + """Different runtime scopes produce isolated checkpoint namespaces.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, 
"test.db") - store = SqliteSessionStore(db_path) + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + scoped_a = ScopedCheckpointStorage(cs, "runtime-a") + scoped_b = ScopedCheckpointStorage(cs, "runtime-b") + + cp_a = _make_checkpoint( + workflow_name="my_wf", checkpoint_id="cp-a" + ) + cp_b = _make_checkpoint( + workflow_name="my_wf", checkpoint_id="cp-b" + ) - # No explicit setup() call - await store.save_session("rt-1", {"key": "value"}) - loaded = await store.load_session("rt-1") + await scoped_a.save(cp_a) + await scoped_b.save(cp_b) - assert loaded == {"key": "value"} - await store.dispose() + # Each scope sees only its own checkpoints + a_cps = await scoped_a.list_checkpoints(workflow_name="my_wf") + b_cps = await scoped_b.list_checkpoints(workflow_name="my_wf") - async def test_complex_session_data_roundtrip(self): - """Complex nested session data survives serialization roundtrip.""" + assert len(a_cps) == 1 + assert a_cps[0].checkpoint_id == "cp-a" + + assert len(b_cps) == 1 + assert b_cps[0].checkpoint_id == "cp-b" + await storage.dispose() + + async def test_scoped_get_latest_respects_scope(self): + """get_latest only returns checkpoints within the runtime scope.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + scoped_a = ScopedCheckpointStorage(cs, "rt-a") + scoped_b = ScopedCheckpointStorage(cs, "rt-b") + + cp_a = _make_checkpoint( + workflow_name="wf", + checkpoint_id="cp-a", + timestamp="2026-01-01T00:00:00+00:00", + ) + cp_b = _make_checkpoint( + workflow_name="wf", + checkpoint_id="cp-b", + timestamp="2026-01-02T00:00:00+00:00", + ) + + await scoped_a.save(cp_a) + await scoped_b.save(cp_b) + + latest_a = await scoped_a.get_latest(workflow_name="wf") + assert latest_a is not None + assert latest_a.checkpoint_id == "cp-a" + + latest_b = await 
scoped_b.get_latest(workflow_name="wf") + assert latest_b is not None + assert latest_b.checkpoint_id == "cp-b" + await storage.dispose() + + async def test_scoped_load_and_delete_are_global(self): + """load and delete operate on checkpoint_id which is globally unique.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - store = SqliteSessionStore(db_path) - await store.setup() - - session_data = { - "session_id": "abc-123", - "state": { - "memory": { - "messages": [ - { - "role": "user", - "content": "What is the weather?", - "metadata": {"timestamp": "2025-01-01T00:00:00Z"}, - }, - { - "role": "assistant", - "content": "It's sunny!", - "tool_calls": [ - { - "id": "call_1", - "name": "get_weather", - "arguments": {"city": "SF"}, - } - ], - }, - ] - }, - "custom_key": [1, 2, 3], - "nested": {"a": {"b": {"c": True}}}, - }, - } - - await store.save_session("rt-complex", session_data) - loaded = await store.load_session("rt-complex") - - assert loaded == session_data - await store.dispose() + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + scoped = ScopedCheckpointStorage(cs, "rt-x") + cp = _make_checkpoint( + workflow_name="wf", checkpoint_id="cp-global" + ) + await scoped.save(cp) + + # Load from any scope + loaded = await scoped.load("cp-global") + assert loaded.checkpoint_id == "cp-global" + + # Delete from any scope + result = await scoped.delete("cp-global") + assert result is True + await storage.dispose() + + async def test_scoped_list_checkpoint_ids(self): + """list_checkpoint_ids respects scope.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + storage = SqliteResumableStorage(db_path) + await storage.setup() + cs = storage.checkpoint_storage + + scoped_a = ScopedCheckpointStorage(cs, "rt-a") + scoped_b = ScopedCheckpointStorage(cs, "rt-b") + + cp_a = _make_checkpoint( + workflow_name="wf", checkpoint_id="cp-a" + ) + 
cp_b = _make_checkpoint( + workflow_name="wf", checkpoint_id="cp-b" + ) + + await scoped_a.save(cp_a) + await scoped_b.save(cp_b) + + ids_a = await scoped_a.list_checkpoint_ids(workflow_name="wf") + ids_b = await scoped_b.list_checkpoint_ids(workflow_name="wf") + + assert ids_a == ["cp-a"] + assert ids_b == ["cp-b"] + await storage.dispose() diff --git a/packages/uipath-agent-framework/tests/test_streaming.py b/packages/uipath-agent-framework/tests/test_streaming.py index 1b061682..4b0c4e0b 100644 --- a/packages/uipath-agent-framework/tests/test_streaming.py +++ b/packages/uipath-agent-framework/tests/test_streaming.py @@ -10,7 +10,9 @@ from unittest.mock import AsyncMock, MagicMock, patch from agent_framework import ( + AgentExecutor, AgentResponseUpdate, + AgentSession, BaseAgent, Content, RawAgent, @@ -59,27 +61,11 @@ async def get_final_response(self): # --------------------------------------------------------------------------- -def _update(*contents: Content) -> AgentResponseUpdate: - return AgentResponseUpdate(contents=list(contents)) - - -def _fc(name: str, call_id: str = "c1") -> Content: - return Content(type="function_call", name=name, call_id=call_id) - - -def _fr(name: str = "", call_id: str = "c1", result: Any = "ok") -> Content: - return Content(type="function_result", name=name, call_id=call_id, result=result) - - -def _text(text: str = "hi") -> Content: - return Content(type="text", text=text) - - -def _wf_event(event_type: str, executor_id: str) -> MagicMock: +def _wf_event(event_type: str, executor_id: str, data: Any = None) -> MagicMock: evt = MagicMock() evt.type = event_type evt.executor_id = executor_id - evt.data = None + evt.data = data return evt @@ -162,168 +148,6 @@ def _assert_started_before_completed( assert first_started < first_completed, f"{node}: COMPLETED before STARTED" -# =========================================================================== -# Agent streaming tests -# 
=========================================================================== - - -class TestAgentStreamingEvents: - """Verify STARTED/COMPLETED pairing for agent streaming.""" - - async def test_simple_agent_no_tools(self): - """Agent with no tools: root STARTED then COMPLETED.""" - agent = RawAgent(_mock_client, name="root") - agent.run = MagicMock(return_value=_MockAsyncStream([_update(_text())])) # type: ignore[method-assign] - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - - runtime = _make_runtime(agent) - events = await _collect_events(runtime) - - se = _state_events(events) - _assert_all_completed(se) - _assert_started_before_completed(se, "root") - assert isinstance(events[-1], UiPathRuntimeResult) - - async def test_agent_with_regular_tools(self): - """Agent with regular tools: tools node gets STARTED/COMPLETED.""" - agent = RawAgent(_mock_client, name="researcher", tools=[search_wikipedia]) - agent.run = MagicMock( # type: ignore[method-assign] - return_value=_MockAsyncStream( - [ - _update(_fc("search_wikipedia", "c1")), - _update(_fr("search_wikipedia", "c1", "wiki result")), - _update(_text("here's what I found")), - ] - ) - ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - - runtime = _make_runtime(agent) - events = await _collect_events(runtime) - - se = _state_events(events) - _assert_all_completed(se) - _assert_started_before_completed(se, "researcher") - _assert_started_before_completed(se, "researcher_tools") - - async def test_multi_agent_with_sub_agents(self): - """Coordinator with sub-agents via as_tool(): all nodes paired.""" - research = RawAgent( - _mock_client, - name="research_agent", - tools=[search_wikipedia], - ) - coder = RawAgent( - _mock_client, - name="code_agent", - tools=[run_python], - ) - coordinator = RawAgent( - _mock_client, - name="coordinator", - tools=[research.as_tool(), coder.as_tool()], - ) - - # Get actual tool names assigned by 
as_tool() - tools = coordinator.default_options.get("tools", []) - research_tool_name = tools[0].name - code_tool_name = tools[1].name - - coordinator.run = MagicMock( # type: ignore[method-assign] - return_value=_MockAsyncStream( - [ - _update(_fc(research_tool_name, "c1")), - _update(_fr("", "c1", "research done")), - _update(_fc(code_tool_name, "c2")), - _update(_fr("", "c2", "code done")), - _update(_text("final answer")), - ] - ) - ) - coordinator.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - - runtime = _make_runtime(coordinator) - events = await _collect_events(runtime) - - se = _state_events(events) - _assert_all_completed(se) - _assert_started_before_completed(se, "coordinator") - _assert_started_before_completed(se, "research_agent") - _assert_started_before_completed(se, "research_agent_tools") - _assert_started_before_completed(se, "code_agent") - _assert_started_before_completed(se, "code_agent_tools") - - async def test_sub_agent_completed_via_call_id(self): - """Sub-agent COMPLETED even when function_result has empty name. - - The original bug: as_tool() wrappers produce function_result with - empty content.name. We match by call_id instead. 
- """ - inner = RawAgent(_mock_client, name="inner_agent", tools=[calculator]) - outer = RawAgent( - _mock_client, - name="outer", - tools=[inner.as_tool()], - ) - - tool_name = outer.default_options["tools"][0].name - - outer.run = MagicMock( # type: ignore[method-assign] - return_value=_MockAsyncStream( - [ - _update(_fc(tool_name, "call_xyz")), - # empty name on result — must still complete inner_agent - _update(_fr("", "call_xyz", "42")), - _update(_text("done")), - ] - ) - ) - outer.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - - runtime = _make_runtime(outer) - events = await _collect_events(runtime) - - se = _state_events(events) - _assert_all_completed(se) - _assert_started_before_completed(se, "inner_agent") - _assert_started_before_completed(se, "inner_agent_tools") - - async def test_mixed_regular_tools_and_sub_agents(self): - """Agent with both regular tools and agent-as-tool.""" - inner = RawAgent(_mock_client, name="helper") - agent = RawAgent( - _mock_client, - name="main", - tools=[search_wikipedia, inner.as_tool()], - ) - - agent_tool_name = next( - t.name for t in agent.default_options["tools"] if hasattr(t, "func") - ) - - agent.run = MagicMock( # type: ignore[method-assign] - return_value=_MockAsyncStream( - [ - _update(_fc("search_wikipedia", "c1")), - _update(_fr("search_wikipedia", "c1", "wiki")), - _update(_fc(agent_tool_name, "c2")), - _update(_fr("", "c2", "helped")), - _update(_text("done")), - ] - ) - ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - - runtime = _make_runtime(agent) - events = await _collect_events(runtime) - - se = _state_events(events) - _assert_all_completed(se) - _assert_started_before_completed(se, "main") - _assert_started_before_completed(se, "main_tools") - _assert_started_before_completed(se, "helper") - - # =========================================================================== # Workflow streaming tests # 
=========================================================================== @@ -451,7 +275,7 @@ async def _fake_load_agent(entrypoint: str) -> BaseAgent: return agent with patch.object(factory, "_load_agent", side_effect=_fake_load_agent): - with patch.object(factory, "_get_session_store", new_callable=AsyncMock): + with patch.object(factory, "_get_storage", new_callable=AsyncMock): runtimes = await asyncio.gather( factory.new_runtime("agent", "runtime_1"), factory.new_runtime("agent", "runtime_2"), @@ -460,5 +284,475 @@ async def _fake_load_agent(entrypoint: str) -> BaseAgent: # Each runtime must have gotten a separate agent instance assert len(loaded_agents) == 3 - agents = [r.agent for r in runtimes] # type: ignore[attr-defined] + # Factory wraps in UiPathResumableRuntime; access delegate.agent + agents = [r.delegate.agent for r in runtimes] # type: ignore[attr-defined] assert len(set(id(a) for a in agents)) == 3, "Runtimes share agent instances!" + + +# =========================================================================== +# Tool state event tests +# =========================================================================== + + +class TestToolStateEvents: + """Verify that output events with function_call/function_result Content + emit STARTED/COMPLETED state events for tool nodes.""" + + async def test_tool_call_emits_state_events(self): + """Output event with function_call + function_result Content should + produce STARTED and COMPLETED state events for '{executor}_tools'.""" + worker = RawAgent(_mock_client, name="weather_agent", tools=[calculator]) + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + # Simulate a tool call cycle: function_call then function_result + call_content = Content( + type="function_call", + name="calculator", + call_id="call_1", + arguments='{"expression": "2+2"}', + ) + result_content = Content( + type="function_result", + 
call_id="call_1", + result="4", + ) + + call_update = AgentResponseUpdate(contents=[call_content]) + result_update = AgentResponseUpdate(contents=[result_content]) + + final = MagicMock() + final.get_outputs.return_value = [] + workflow.run = MagicMock( # type: ignore[method-assign] + return_value=_MockAsyncStream( + [ + _wf_event("executor_invoked", "weather_agent"), + _wf_event("output", "weather_agent", data=call_update), + _wf_event("output", "weather_agent", data=result_update), + _wf_event("executor_completed", "weather_agent"), + ], + final, + ) + ) + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime = _make_runtime(agent) + events = await _collect_events(runtime) + + se = _state_events(events) + + # Tool node should have STARTED and COMPLETED + tool_events = [(n, p) for n, p in se if n == "weather_agent_tools"] + assert ("weather_agent_tools", STARTED) in tool_events + assert ("weather_agent_tools", COMPLETED) in tool_events + _assert_started_before_completed(se, "weather_agent_tools") + + # The STARTED event should carry the tool_name payload + tool_started = [ + e + for e in events + if isinstance(e, UiPathRuntimeStateEvent) + and e.node_name == "weather_agent_tools" + and e.phase == STARTED + ] + assert len(tool_started) == 1 + assert tool_started[0].payload == {"tool_name": "calculator"} + + async def test_multiple_tool_calls_emit_paired_events(self): + """Multiple tool call cycles should each produce a STARTED/COMPLETED pair.""" + worker = RawAgent( + _mock_client, + name="multi_tool_agent", + tools=[calculator, search_wikipedia], + ) + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + updates = [ + AgentResponseUpdate( + contents=[ + Content( + type="function_call", + name="calculator", + call_id="c1", + arguments="{}", + ) + ] + ), + AgentResponseUpdate( + contents=[ + Content(type="function_result", 
call_id="c1", result="42") + ] + ), + AgentResponseUpdate( + contents=[ + Content( + type="function_call", + name="search_wikipedia", + call_id="c2", + arguments="{}", + ) + ] + ), + AgentResponseUpdate( + contents=[ + Content(type="function_result", call_id="c2", result="found") + ] + ), + ] + + wf_events = [_wf_event("executor_invoked", "multi_tool_agent")] + for upd in updates: + wf_events.append(_wf_event("output", "multi_tool_agent", data=upd)) + wf_events.append(_wf_event("executor_completed", "multi_tool_agent")) + + final = MagicMock() + final.get_outputs.return_value = [] + workflow.run = MagicMock(return_value=_MockAsyncStream(wf_events, final)) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime = _make_runtime(agent) + events = await _collect_events(runtime) + + se = _state_events(events) + tool_events = [(n, p) for n, p in se if n == "multi_tool_agent_tools"] + + # Two STARTED + two COMPLETED + assert tool_events.count(("multi_tool_agent_tools", STARTED)) == 2 + assert tool_events.count(("multi_tool_agent_tools", COMPLETED)) == 2 + + async def test_no_tool_events_for_text_content(self): + """Output events with only text Content should not emit tool state events.""" + worker = RawAgent(_mock_client, name="text_agent") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + text_update = AgentResponseUpdate( + contents=[Content(type="text", text="Hello world")] + ) + + final = MagicMock() + final.get_outputs.return_value = [] + workflow.run = MagicMock( # type: ignore[method-assign] + return_value=_MockAsyncStream( + [ + _wf_event("executor_invoked", "text_agent"), + _wf_event("output", "text_agent", data=text_update), + _wf_event("executor_completed", "text_agent"), + ], + final, + ) + ) + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime 
= _make_runtime(agent) + events = await _collect_events(runtime) + + se = _state_events(events) + tool_events = [(n, p) for n, p in se if "_tools" in n] + assert tool_events == [], "Text-only output should not emit tool state events" + + async def test_executor_completed_payload_excludes_streaming_updates(self): + """executor_completed StateEvent payload must NOT contain + AgentResponseUpdate streaming chunks — only the summary data.""" + worker = RawAgent(_mock_client, name="agent_x", tools=[calculator]) + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + # Simulate streaming: many AgentResponseUpdate chunks arrive as output + # events, then executor_completed carries sent_messages + yielded_outputs + text_chunks = [ + AgentResponseUpdate(contents=[Content(type="text", text=tok)]) + for tok in ["Hello", " ", "world"] + ] + call_chunk = AgentResponseUpdate( + contents=[ + Content( + type="function_call", + name="calculator", + call_id="c1", + arguments="{}", + ) + ] + ) + result_chunk = AgentResponseUpdate( + contents=[ + Content(type="function_result", call_id="c1", result="42") + ] + ) + + # The framework packs sent_messages + yielded_outputs into completed data + summary = MagicMock() # represents AgentExecutorResponse (not an update) + completed_data = [summary] + text_chunks + [call_chunk, result_chunk] + + final = MagicMock() + final.get_outputs.return_value = [] + + stream_events = [ + _wf_event("executor_invoked", "agent_x"), + ] + for chunk in text_chunks + [call_chunk, result_chunk]: + stream_events.append(_wf_event("output", "agent_x", data=chunk)) + stream_events.append( + _wf_event("executor_completed", "agent_x", data=completed_data) + ) + + workflow.run = MagicMock( # type: ignore[method-assign] + return_value=_MockAsyncStream(stream_events, final) + ) + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime = 
_make_runtime(agent) + events = await _collect_events(runtime) + + # Find the executor_completed state event for agent_x + completed_events = [ + e + for e in events + if isinstance(e, UiPathRuntimeStateEvent) + and e.node_name == "agent_x" + and e.phase == COMPLETED + ] + assert len(completed_events) == 1 + + # The payload must NOT contain any AgentResponseUpdate data + payload = completed_events[0].payload + payload_str = str(payload) + assert "agent_response_update" not in payload_str.lower() + # Text token chunks should not appear in the completed payload + assert "Hello" not in payload_str or "world" not in payload_str + + +# =========================================================================== +# Checkpoint propagation tests +# =========================================================================== + + +class TestCheckpointPropagation: + """Verify that checkpoint_storage is passed to workflow.run().""" + + async def test_checkpoint_storage_passed_to_workflow_run_stream(self): + """Streaming: workflow.run() should receive checkpoint_storage parameter.""" + worker = RawAgent(_mock_client, name="assistant") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="chat_wf") + + mock_checkpoint_storage = MagicMock() + captured_kwargs: list[dict[str, Any]] = [] + + def mock_run(**kwargs): + captured_kwargs.append(kwargs) + final = MagicMock() + final.get_outputs.return_value = [] + return _MockAsyncStream( + [ + _wf_event("executor_invoked", "assistant"), + _wf_event("executor_completed", "assistant"), + ], + final, + ) + + workflow.run = mock_run # type: ignore[method-assign] + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-session", + checkpoint_storage=mock_checkpoint_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "hello" + runtime.chat.map_streaming_content.return_value = [] + 
runtime.chat.close_message.return_value = [] + + await _collect_events(runtime) + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["checkpoint_storage"] is mock_checkpoint_storage + assert captured_kwargs[0]["stream"] is True + + async def test_checkpoint_storage_passed_to_workflow_run_execute(self): + """Non-streaming execute() should pass checkpoint_storage to workflow.run().""" + worker = RawAgent(_mock_client, name="assistant") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="exec_wf") + + mock_checkpoint_storage = MagicMock() + captured_kwargs: list[dict[str, Any]] = [] + + async def mock_run(**kwargs): + captured_kwargs.append(kwargs) + result = MagicMock() + result.get_outputs.return_value = ["done"] + return result + + workflow.run = mock_run # type: ignore[method-assign] + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="exec-session", + checkpoint_storage=mock_checkpoint_storage, + ) + + await runtime.execute(input={"messages": []}) + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["checkpoint_storage"] is mock_checkpoint_storage + + async def test_hitl_resume_passes_checkpoint_id_and_responses(self): + """HITL resume: workflow.run() should receive responses, checkpoint_id, and checkpoint_storage.""" + worker = RawAgent(_mock_client, name="assistant") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="resume_wf") + workflow.name = "resume_wf" + + mock_checkpoint_storage = MagicMock() + mock_checkpoint_storage.get_latest = AsyncMock(return_value=MagicMock(checkpoint_id="cp-123")) + captured_kwargs: list[dict[str, Any]] = [] + + def mock_run(**kwargs): + captured_kwargs.append(kwargs) + final = MagicMock() + final.get_outputs.return_value = [] + return _MockAsyncStream( + [ + _wf_event("executor_invoked", "assistant"), + 
_wf_event("executor_completed", "assistant"), + ], + final, + ) + + workflow.run = mock_run # type: ignore[method-assign] + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="resume-session", + checkpoint_storage=mock_checkpoint_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + options = MagicMock() + options.resume = True + options.breakpoints = None + + responses = {"req-1": "approved"} + await _collect_events(runtime) + + # Trigger a resume via stream with responses + runtime._resume_responses = responses + events = [] + async for event in runtime._stream_workflow("", "resume_wf"): + events.append(event) + + # The second call (resume) should have responses and checkpoint_id + assert len(captured_kwargs) == 2 + resume_call = captured_kwargs[1] + assert resume_call["responses"] == responses + assert resume_call["checkpoint_id"] == "cp-123" + assert resume_call["checkpoint_storage"] is mock_checkpoint_storage + + +# =========================================================================== +# Session propagation tests (multi-turn conversation history) +# =========================================================================== + + +class TestSessionPropagation: + """Verify that sessions are loaded from storage and propagated to executors.""" + + async def test_session_propagated_to_executors_in_stream(self): + """Session loaded from resumable_storage should be set on each + AgentExecutor before workflow.run(), so inner agents see conversation history.""" + worker = RawAgent(_mock_client, name="assistant") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="chat_wf") + + # Create a session with pre-existing state (simulating a prior turn) + session = agent.create_session(session_id="test-session") + 
session.state["prior_turn_data"] = "previous conversation" + + # Mock resumable_storage that returns our pre-populated session via KV + mock_storage = AsyncMock() + mock_storage.get_value = AsyncMock(return_value=session.to_dict()) + mock_storage.set_value = AsyncMock() + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-session", + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "hello" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + # Track which session gets set on the executor + captured_sessions: list[AgentSession] = [] + + def mock_run(**kwargs): + # At the point workflow.run() is called, capture the executor's session + executor = workflow.executors["assistant"] + if isinstance(executor, AgentExecutor): + captured_sessions.append(executor._session) + final = MagicMock() + final.get_outputs.return_value = [] + return _MockAsyncStream( + [ + _wf_event("executor_invoked", "assistant"), + _wf_event("executor_completed", "assistant"), + ], + final, + ) + + workflow.run = mock_run # type: ignore[method-assign] + + await _collect_events(runtime) + + # Session should have been captured with prior turn data + assert len(captured_sessions) == 1 + assert captured_sessions[0].state.get("prior_turn_data") == "previous conversation" + + # Session should have been loaded from KV storage + mock_storage.get_value.assert_called_once_with("test-session", "session", "data") + + # Session should have been saved after execution + mock_storage.set_value.assert_called_once() + + async def test_session_propagated_in_execute_path(self): + """Non-streaming execute() should also propagate session to executors.""" + worker = RawAgent(_mock_client, name="assistant") + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="exec_wf") + + session = 
agent.create_session(session_id="exec-session") + session.state["history_key"] = "turn1_data" + + mock_storage = AsyncMock() + mock_storage.get_value = AsyncMock(return_value=session.to_dict()) + mock_storage.set_value = AsyncMock() + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="exec-session", + resumable_storage=mock_storage, + ) + + captured_sessions: list[AgentSession] = [] + + async def mock_run(**kwargs): + executor = workflow.executors["assistant"] + if isinstance(executor, AgentExecutor): + captured_sessions.append(executor._session) + result = MagicMock() + result.get_outputs.return_value = ["done"] + return result + + workflow.run = mock_run # type: ignore[method-assign] + + await runtime.execute(input={"messages": []}) + + assert len(captured_sessions) == 1 + assert captured_sessions[0].state.get("history_key") == "turn1_data" + mock_storage.set_value.assert_called_once() diff --git a/packages/uipath-agent-framework/uv.lock b/packages/uipath-agent-framework/uv.lock index cda770e3..f28f5001 100644 --- a/packages/uipath-agent-framework/uv.lock +++ b/packages/uipath-agent-framework/uv.lock @@ -2448,7 +2448,7 @@ wheels = [ [[package]] name = "uipath-agent-framework" -version = "0.0.3" +version = "0.0.4" source = { editable = "." } dependencies = [ { name = "agent-framework-core" }, From 899bd22414246386d9ee0d2834ae0276516c9e8a Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Fri, 20 Feb 2026 15:06:01 +0200 Subject: [PATCH 2/4] feat: exclude handoff tools from graph tool nodes Handoff tools (handoff_to_) are already represented as edges between agent nodes, so showing them as tool nodes is redundant. Filter them out so the graph only shows non-handoff tools in tool nodes. 
Co-Authored-By: Claude Opus 4.6 --- .../uipath_agent_framework/runtime/schema.py | 23 ++++++- .../tests/test_graph.py | 68 +++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py index 03324976..73072135 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py @@ -137,6 +137,7 @@ def _build_workflow_graph(workflow: Workflow) -> UiPathRuntimeGraph: executors: dict[str, Executor] = workflow.executors start_id: str = workflow.start_executor_id + executor_ids: set[str] = set(executors.keys()) # Add a node for each executor for exec_id, executor in executors.items(): @@ -154,7 +155,9 @@ def _build_workflow_graph(workflow: Workflow) -> UiPathRuntimeGraph: if isinstance(executor, AgentExecutor): inner_agent: BaseAgent | None = getattr(executor, "_agent", None) if inner_agent is not None: - _add_executor_tool_nodes(exec_id, inner_agent, nodes, edges) + _add_executor_tool_nodes( + exec_id, inner_agent, nodes, edges, executor_ids + ) # Connect __start__ → start executor edges.append(UiPathRuntimeEdge(source="__start__", target=start_id, label="input")) @@ -195,13 +198,26 @@ def _build_workflow_graph(workflow: Workflow) -> UiPathRuntimeGraph: return UiPathRuntimeGraph(nodes=nodes, edges=edges) +def _is_handoff_tool(tool_name: str, executor_ids: set[str]) -> bool: + """Check if a tool is a handoff tool by matching against executor IDs.""" + if not tool_name.startswith("handoff_to_"): + return False + target = tool_name[len("handoff_to_"):] + return target in executor_ids + + def _add_executor_tool_nodes( executor_id: str, agent: BaseAgent, nodes: list[UiPathRuntimeNode], edges: list[UiPathRuntimeEdge], + executor_ids: set[str], ) -> None: - """Add tool nodes for an executor's wrapped 
agent's tools.""" + """Add tool nodes for an executor's wrapped agent's tools. + + Handoff tools (tools that transfer control to another executor) are + excluded since they are already represented as edges between nodes. + """ tools = get_agent_tools(agent) if not tools: return @@ -209,6 +225,9 @@ def _add_executor_tool_nodes( tool_names = [get_tool_name(t) for t in tools] tool_names = [n for n in tool_names if n] + # Filter out handoff tools — they are represented as edges, not tool nodes + tool_names = [n for n in tool_names if not _is_handoff_tool(n, executor_ids)] + if tool_names: tools_node_id = f"{executor_id}_tools" nodes.append( diff --git a/packages/uipath-agent-framework/tests/test_graph.py b/packages/uipath-agent-framework/tests/test_graph.py index 7096523e..dda86e7b 100644 --- a/packages/uipath-agent-framework/tests/test_graph.py +++ b/packages/uipath-agent-framework/tests/test_graph.py @@ -300,6 +300,74 @@ def search_wikipedia(): assert ("researcher", "researcher_tools") in edge_pairs assert ("researcher_tools", "researcher") in edge_pairs + def test_handoff_tools_excluded_from_tool_nodes(self): + """Handoff tools are not shown as tool nodes — they are edges.""" + + def handoff_to_billing(): + pass + + handoff_to_billing.__name__ = "handoff_to_billing" + + def handoff_to_tech(): + pass + + handoff_to_tech.__name__ = "handoff_to_tech" + + triage_agent = _make_agent( + name="triage", tools=[handoff_to_billing, handoff_to_tech] + ) + executors = { + "triage": _make_executor("triage", agent=triage_agent), + "billing": _make_executor("billing"), + "tech": _make_executor("tech"), + } + workflow = _make_workflow( + executors=executors, + edge_groups=[], + start_executor_id="triage", + ) + agent = _make_workflow_agent(workflow) + graph = get_agent_graph(agent) + + node_ids = {n.id for n in graph.nodes} + assert "triage_tools" not in node_ids + + def test_mixed_handoff_and_regular_tools(self): + """Only non-handoff tools appear in tool nodes when mixed with 
handoffs.""" + + def handoff_to_billing(): + pass + + handoff_to_billing.__name__ = "handoff_to_billing" + + def search_docs(): + pass + + search_docs.__name__ = "search_docs" + + triage_agent = _make_agent( + name="triage", tools=[handoff_to_billing, search_docs] + ) + executors = { + "triage": _make_executor("triage", agent=triage_agent), + "billing": _make_executor("billing"), + } + workflow = _make_workflow( + executors=executors, + edge_groups=[], + start_executor_id="triage", + ) + agent = _make_workflow_agent(workflow) + graph = get_agent_graph(agent) + + node_ids = {n.id for n in graph.nodes} + assert "triage_tools" in node_ids + + tools_node = next(n for n in graph.nodes if n.id == "triage_tools") + assert tools_node.metadata is not None + assert tools_node.metadata["tool_names"] == ["search_docs"] + assert tools_node.metadata["tool_count"] == 1 + def test_workflow_edge_condition_labels(self): """Conditional edges include condition_name as label.""" executors = { From 807feed90f94e37e785762a53bf9c95484037fb5 Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Fri, 20 Feb 2026 15:25:35 +0200 Subject: [PATCH 3/4] fix: extract tool state events from executor_completed when output events are filtered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The agent framework's Workflow filters output events to only include those from "output executors". In multi-agent workflows like GroupChat, participant agents are not output executors, so their tool call output events get dropped — causing missing tool state events in the runtime. Fall back to extracting tool events from executor_completed data when no output events were seen for that executor, with dedup tracking to avoid double-emitting in the normal (unfiltered) case. 
Co-Authored-By: Claude Opus 4.6 --- .../uipath_agent_framework/runtime/runtime.py | 19 ++- .../tests/test_streaming.py | 112 ++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index 0a960e10..20f3062c 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -349,6 +349,10 @@ async def _stream_workflow( request_info_map: dict[str, Any] = {} is_suspended = False + # Track executors whose tool events were emitted via output events. + # When the workflow filters output events (e.g. GroupChat), tool events + # are instead extracted from executor_completed data as a fallback. + executors_with_tool_outputs: set[str] = set() # Emit an early STARTED event for the start executor so the graph # visualization shows it immediately rather than after it finishes. @@ -380,6 +384,14 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) elif event.type == "executor_completed": + # When output events were filtered by the workflow (e.g. + # GroupChat where participants are not output executors), + # extract tool state events from the completed data instead. 
+ if event.executor_id not in executors_with_tool_outputs: + for tool_event in self._extract_tool_state_events( + event.data, event.executor_id + ): + yield tool_event yield UiPathRuntimeStateEvent( payload=self._serialize_event_data( self._filter_completed_data(event.data) @@ -389,9 +401,12 @@ async def _stream_workflow( ) elif event.type == "output": executor_id = getattr(event, "executor_id", None) or "" - for tool_event in self._extract_tool_state_events( + tool_events = self._extract_tool_state_events( event.data, executor_id - ): + ) + if tool_events: + executors_with_tool_outputs.add(executor_id) + for tool_event in tool_events: yield tool_event for msg_event in self._extract_workflow_messages(event.data): yield UiPathRuntimeMessageEvent(payload=msg_event) diff --git a/packages/uipath-agent-framework/tests/test_streaming.py b/packages/uipath-agent-framework/tests/test_streaming.py index 4b0c4e0b..97f38873 100644 --- a/packages/uipath-agent-framework/tests/test_streaming.py +++ b/packages/uipath-agent-framework/tests/test_streaming.py @@ -522,6 +522,118 @@ async def test_executor_completed_payload_excludes_streaming_updates(self): # Text token chunks should not appear in the completed payload assert "Hello" not in payload_str or "world" not in payload_str + async def test_tool_events_from_executor_completed_when_output_filtered(self): + """When the workflow filters output events (e.g. GroupChat), tool state + events should still appear — extracted from executor_completed data.""" + worker = RawAgent(_mock_client, name="researcher", tools=[search_wikipedia]) + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + # Simulate a workflow that does NOT emit output events (they were + # filtered by _should_yield_output_event), but executor_completed + # carries the full data including AgentResponseUpdate chunks. 
+ call_update = AgentResponseUpdate( + contents=[ + Content( + type="function_call", + name="search_wikipedia", + call_id="c1", + arguments='{"query": "test"}', + ) + ] + ) + result_update = AgentResponseUpdate( + contents=[ + Content(type="function_result", call_id="c1", result="found") + ] + ) + summary = MagicMock() # AgentExecutorResponse + completed_data = [summary, call_update, result_update] + + final = MagicMock() + final.get_outputs.return_value = [] + workflow.run = MagicMock( # type: ignore[method-assign] + return_value=_MockAsyncStream( + [ + # No "output" events — simulating the filter + _wf_event("executor_invoked", "researcher"), + _wf_event("executor_completed", "researcher", data=completed_data), + ], + final, + ) + ) + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime = _make_runtime(agent) + events = await _collect_events(runtime) + + se = _state_events(events) + tool_events = [(n, p) for n, p in se if n == "researcher_tools"] + assert ("researcher_tools", STARTED) in tool_events + assert ("researcher_tools", COMPLETED) in tool_events + + # Tool STARTED should carry the tool_name + tool_started = [ + e + for e in events + if isinstance(e, UiPathRuntimeStateEvent) + and e.node_name == "researcher_tools" + and e.phase == STARTED + ] + assert len(tool_started) == 1 + assert tool_started[0].payload == {"tool_name": "search_wikipedia"} + + async def test_no_duplicate_tool_events_when_output_present(self): + """When output events ARE emitted (normal case), tool events should NOT + be extracted again from executor_completed to avoid duplicates.""" + worker = RawAgent(_mock_client, name="agent_y", tools=[calculator]) + workflow = WorkflowBuilder(start_executor=worker).build() # type: ignore[arg-type] + agent = WorkflowAgent(workflow=workflow, name="wf") + + call_update = AgentResponseUpdate( + contents=[ + Content( + type="function_call", + name="calculator", + call_id="c1", + arguments="{}", + ) + ] + ) 
+ result_update = AgentResponseUpdate( + contents=[ + Content(type="function_result", call_id="c1", result="42") + ] + ) + summary = MagicMock() + completed_data = [summary, call_update, result_update] + + final = MagicMock() + final.get_outputs.return_value = [] + workflow.run = MagicMock( # type: ignore[method-assign] + return_value=_MockAsyncStream( + [ + _wf_event("executor_invoked", "agent_y"), + # Output events ARE present (normal path) + _wf_event("output", "agent_y", data=call_update), + _wf_event("output", "agent_y", data=result_update), + _wf_event("executor_completed", "agent_y", data=completed_data), + ], + final, + ) + ) + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + runtime = _make_runtime(agent) + events = await _collect_events(runtime) + + se = _state_events(events) + tool_events = [(n, p) for n, p in se if n == "agent_y_tools"] + + # Should have exactly 1 STARTED + 1 COMPLETED (not duplicated) + assert tool_events.count(("agent_y_tools", STARTED)) == 1 + assert tool_events.count(("agent_y_tools", COMPLETED)) == 1 + # =========================================================================== # Checkpoint propagation tests From 99f40862adea73a1418a34e295be51122fbd112c Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Fri, 20 Feb 2026 18:30:44 +0200 Subject: [PATCH 4/4] fix: resolve breakpoint infinite loop in debug streaming with cyclic topologies - Add checkpoint advancement detection to _get_breakpoint_skip() to correctly handle cyclic graphs (GroupChat, handoffs) where executors are visited on every cycle - Merge interrupt.py contents into breakpoints.py (AgentInterruptException, BreakpointMiddleware) - Add breakpoint integration tests for all sample topologies (group-chat, quickstart-workflow, concurrent, handoff, hitl-workflow) - Move agent-framework-orchestrations to main dependencies - Bump version to 0.0.5 - Fix all mypy errors across src and tests - Add LangChain integration to root README.md 
- Add .vscode/ to .gitignore Co-Authored-By: Claude Opus 4.6 --- .gitignore | 3 +- README.md | 5 +- .../uipath-agent-framework/pyproject.toml | 3 +- .../samples/concurrent/.vscode/settings.json | 3 - .../samples/concurrent/pyproject.toml | 4 - .../samples/group-chat/agent.mermaid | 18 + .../samples/handoff/agent.mermaid | 24 + .../quickstart-workflow/.vscode/settings.json | 3 - .../samples/quickstart-workflow/agent.mermaid | 9 + .../quickstart-workflow/pyproject.toml | 4 - .../runtime/__init__.py | 2 +- .../runtime/breakpoints.py | 163 ++- .../uipath_agent_framework/runtime/factory.py | 1 + .../runtime/interrupt.py | 131 -- .../runtime/resumable_storage.py | 42 +- .../uipath_agent_framework/runtime/runtime.py | 102 +- .../uipath_agent_framework/runtime/schema.py | 12 +- .../tests/test_breakpoints.py | 681 +++++++-- .../tests/test_group_chat_breakpoints.py | 1220 +++++++++++++++++ .../tests/test_storage.py | 61 +- .../tests/test_streaming.py | 80 +- packages/uipath-agent-framework/uv.lock | 16 +- 22 files changed, 2157 insertions(+), 430 deletions(-) delete mode 100644 packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json create mode 100644 packages/uipath-agent-framework/samples/group-chat/agent.mermaid create mode 100644 packages/uipath-agent-framework/samples/handoff/agent.mermaid delete mode 100644 packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json create mode 100644 packages/uipath-agent-framework/samples/quickstart-workflow/agent.mermaid delete mode 100644 packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py create mode 100644 packages/uipath-agent-framework/tests/test_group_chat_breakpoints.py diff --git a/.gitignore b/.gitignore index fc1b8ec3..559b6b3a 100644 --- a/.gitignore +++ b/.gitignore @@ -179,13 +179,14 @@ cython_debug/ **/__uipath/ .claude/settings.local.json -/.vscode/launch.json +.vscode/ playground.py # Samples generated files **/samples/**/.agent/ 
**/samples/**/.claude/ +**/samples/**/.vscode/ **/samples/**/AGENTS.md **/samples/**/CLAUDE.md **/samples/**/bindings.json diff --git a/README.md b/README.md index d51a8750..6f1fa8b4 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,11 @@ All packages extend the [UiPath Python SDK](https://github.com/UiPath/uipath-pyt | Framework | Version | Downloads | Links | |---|---|---|---| -| [LlamaIndex](https://www.llamaindex.ai/) | [![PyPI](https://img.shields.io/pypi/v/uipath-llamaindex)](https://pypi.org/project/uipath-llamaindex/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-llamaindex.svg)](https://pypi.org/project/uipath-llamaindex/) | [README](packages/uipath-llamaindex/README.md) · [Docs](https://uipath.github.io/uipath-python/llamaindex/quick_start/) · [Samples](packages/uipath-llamaindex/samples/) | -| [OpenAI Agents](https://github.com/openai/openai-agents-python) | [![PyPI](https://img.shields.io/pypi/v/uipath-openai-agents)](https://pypi.org/project/uipath-openai-agents/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-openai-agents.svg)](https://pypi.org/project/uipath-openai-agents/) | [README](packages/uipath-openai-agents/README.md) · [Docs](https://uipath.github.io/uipath-python/openai-agents/quick_start/) · [Samples](packages/uipath-openai-agents/samples/) | | [Google ADK](https://github.com/google/adk-python) | [![PyPI](https://img.shields.io/pypi/v/uipath-google-adk)](https://pypi.org/project/uipath-google-adk/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-google-adk.svg)](https://pypi.org/project/uipath-google-adk/) | [README](packages/uipath-google-adk/README.md) · [Samples](packages/uipath-google-adk/samples/) | +| [LangChain](https://github.com/langchain-ai/langchain) | [![PyPI](https://img.shields.io/pypi/v/uipath-langchain)](https://pypi.org/project/uipath-langchain/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-langchain.svg)](https://pypi.org/project/uipath-langchain/) | 
[README](https://github.com/UiPath/uipath-langchain-python#readme) · [Docs](https://uipath.github.io/uipath-python/langchain/quick_start/) · [Samples](https://github.com/UiPath/uipath-langchain-python/tree/main/samples) | +| [LlamaIndex](https://www.llamaindex.ai/) | [![PyPI](https://img.shields.io/pypi/v/uipath-llamaindex)](https://pypi.org/project/uipath-llamaindex/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-llamaindex.svg)](https://pypi.org/project/uipath-llamaindex/) | [README](packages/uipath-llamaindex/README.md) · [Docs](https://uipath.github.io/uipath-python/llamaindex/quick_start/) · [Samples](packages/uipath-llamaindex/samples/) | | [Microsoft Agent Framework](https://github.com/microsoft/agent-framework) | [![PyPI](https://img.shields.io/pypi/v/uipath-agent-framework)](https://pypi.org/project/uipath-agent-framework/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-agent-framework.svg)](https://pypi.org/project/uipath-agent-framework/) | [README](packages/uipath-agent-framework/README.md) · [Samples](packages/uipath-agent-framework/samples/) | +| [OpenAI Agents](https://github.com/openai/openai-agents-python) | [![PyPI](https://img.shields.io/pypi/v/uipath-openai-agents)](https://pypi.org/project/uipath-openai-agents/) | [![Downloads](https://img.shields.io/pypi/dm/uipath-openai-agents.svg)](https://pypi.org/project/uipath-openai-agents/) | [README](packages/uipath-openai-agents/README.md) · [Docs](https://uipath.github.io/uipath-python/openai-agents/quick_start/) · [Samples](packages/uipath-openai-agents/samples/) | ## Structure diff --git a/packages/uipath-agent-framework/pyproject.toml b/packages/uipath-agent-framework/pyproject.toml index 03b85298..33320645 100644 --- a/packages/uipath-agent-framework/pyproject.toml +++ b/packages/uipath-agent-framework/pyproject.toml @@ -1,11 +1,12 @@ [project] name = "uipath-agent-framework" -version = "0.0.4" +version = "0.0.5" description = "Python SDK that enables developers to build and deploy 
Microsoft Agent Framework agents to the UiPath Cloud Platform" readme = "README.md" requires-python = ">=3.11" dependencies = [ "agent-framework-core>=1.0.0b260212", + "agent-framework-orchestrations>=1.0.0b260212", "aiosqlite>=0.20.0", "openinference-instrumentation-agent-framework>=0.1.0", "uipath>=2.8.41, <2.9.0", diff --git a/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json b/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json deleted file mode 100644 index af690fcd..00000000 --- a/packages/uipath-agent-framework/samples/concurrent/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python-envs.pythonProjects": [] -} diff --git a/packages/uipath-agent-framework/samples/concurrent/pyproject.toml b/packages/uipath-agent-framework/samples/concurrent/pyproject.toml index a3252cf0..44ae5b2b 100644 --- a/packages/uipath-agent-framework/samples/concurrent/pyproject.toml +++ b/packages/uipath-agent-framework/samples/concurrent/pyproject.toml @@ -19,7 +19,3 @@ dev = [ [tool.uv] prerelease = "allow" - -[tool.uv.sources] -uipath-dev = { path = "../../../../../uipath-dev-python", editable = true } -uipath-agent-framework = { path = "../../", editable = true } diff --git a/packages/uipath-agent-framework/samples/group-chat/agent.mermaid b/packages/uipath-agent-framework/samples/group-chat/agent.mermaid new file mode 100644 index 00000000..94f17d7d --- /dev/null +++ b/packages/uipath-agent-framework/samples/group-chat/agent.mermaid @@ -0,0 +1,18 @@ +flowchart TB + __start__(__start__) + __end__(__end__) + orchestrator(orchestrator) + researcher(researcher) + researcher_tools(tools) + critic(critic) + writer(writer) + researcher --> researcher_tools + researcher_tools --> researcher + __start__ --> |input|orchestrator + orchestrator --> researcher + researcher --> orchestrator + orchestrator --> critic + critic --> orchestrator + orchestrator --> writer + writer --> orchestrator + orchestrator --> |output|__end__ diff --git 
a/packages/uipath-agent-framework/samples/handoff/agent.mermaid b/packages/uipath-agent-framework/samples/handoff/agent.mermaid new file mode 100644 index 00000000..e2466d18 --- /dev/null +++ b/packages/uipath-agent-framework/samples/handoff/agent.mermaid @@ -0,0 +1,24 @@ +flowchart TB + __start__(__start__) + __end__(__end__) + triage(triage) + billing_agent(billing_agent) + tech_agent(tech_agent) + returns_agent(returns_agent) + __start__ --> |input|triage + triage --> billing_agent + triage --> tech_agent + triage --> returns_agent + billing_agent --> triage + billing_agent --> tech_agent + billing_agent --> returns_agent + tech_agent --> triage + tech_agent --> billing_agent + tech_agent --> returns_agent + returns_agent --> triage + returns_agent --> billing_agent + returns_agent --> tech_agent + triage --> |output|__end__ + billing_agent --> |output|__end__ + tech_agent --> |output|__end__ + returns_agent --> |output|__end__ diff --git a/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json b/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json deleted file mode 100644 index af690fcd..00000000 --- a/packages/uipath-agent-framework/samples/quickstart-workflow/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python-envs.pythonProjects": [] -} diff --git a/packages/uipath-agent-framework/samples/quickstart-workflow/agent.mermaid b/packages/uipath-agent-framework/samples/quickstart-workflow/agent.mermaid new file mode 100644 index 00000000..cbb6b288 --- /dev/null +++ b/packages/uipath-agent-framework/samples/quickstart-workflow/agent.mermaid @@ -0,0 +1,9 @@ +flowchart TB + __start__(__start__) + __end__(__end__) + weather_agent(weather_agent) + weather_agent_tools(tools) + weather_agent --> weather_agent_tools + weather_agent_tools --> weather_agent + __start__ --> |input|weather_agent + weather_agent --> |output|__end__ diff --git 
a/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml b/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml index a7559704..a42ca93b 100644 --- a/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml +++ b/packages/uipath-agent-framework/samples/quickstart-workflow/pyproject.toml @@ -18,7 +18,3 @@ dev = [ [tool.uv] prerelease = "allow" - -[tool.uv.sources] -uipath-dev = { path = "../../../../../uipath-dev-python", editable = true } -uipath-agent-framework = { path = "../../", editable = true } diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py index 667479ee..99d5b632 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/__init__.py @@ -6,8 +6,8 @@ UiPathRuntimeFactoryRegistry, ) +from .breakpoints import AgentInterruptException, BreakpointMiddleware from .factory import UiPathAgentFrameworkRuntimeFactory -from .interrupt import AgentInterruptException, BreakpointMiddleware from .resumable_storage import ( ScopedCheckpointStorage, SqliteCheckpointStorage, diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py index 8fc85ebd..f5b6306c 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/breakpoints.py @@ -1,5 +1,9 @@ """Breakpoint management for the Agent Framework runtime. 
+Provides: +- AgentInterruptException: raised by middleware to suspend agent execution +- BreakpointMiddleware: intercepts tools matching breakpoint configuration + Implements breakpoints by wrapping executor.execute() methods so that execution pauses BEFORE the executor runs. This works regardless of the inner agent type (RawAgent, Agent, etc.) because interception @@ -14,15 +18,128 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from typing import Any from uuid import uuid4 from agent_framework import AgentExecutor, WorkflowAgent +from agent_framework._middleware import ( + FunctionInvocationContext, + FunctionMiddleware, +) from uipath.runtime.debug import UiPathBreakpointResult -from .interrupt import AgentInterruptException from .schema import get_agent_tools, get_tool_name + +class AgentInterruptException(Exception): + """Raised by middleware to suspend agent execution for HITL. + + Carries an interrupt_id and suspend_value that the runtime uses + to create a UiPathRuntimeResult with SUSPENDED status. + When is_breakpoint is True, the runtime returns UiPathBreakpointResult + instead, which bypasses trigger management and is handled by the + debug runtime layer. + """ + + def __init__( + self, + interrupt_id: str, + suspend_value: Any, + *, + is_breakpoint: bool = False, + ) -> None: + self.interrupt_id = interrupt_id + self.suspend_value = suspend_value + self.is_breakpoint = is_breakpoint + super().__init__(f"Agent interrupted: {interrupt_id}") + + +class BreakpointMiddleware(FunctionMiddleware): + """Intercepts tools matching breakpoint configuration. + + Breakpoint flow (orchestrated by UiPathDebugRuntime): + + 1. UiPathDebugRuntime gets breakpoints from debug bridge and passes + them via ``options.breakpoints`` to the integration runtime. + 2. The integration runtime injects this middleware into the agent's + middleware chain with the breakpoint list. + 3. 
When the agent calls a matching tool, this middleware raises + ``AgentInterruptException(is_breakpoint=True)`` BEFORE the tool runs. + 4. The runtime catches the exception and returns + ``UiPathBreakpointResult`` (a SUSPENDED result subclass). + 5. ``UiPathResumableRuntime`` passes the breakpoint result through + (no trigger management — breakpoints bypass the trigger system). + 6. ``UiPathDebugRuntime`` sees ``UiPathBreakpointResult``, notifies + the debug bridge, and waits for a resume command. + 7. On resume, ``UiPathDebugRuntime`` re-invokes the runtime with + ``options.resume=True, input=None``. The runtime re-injects this + middleware with ``skip_tool`` set to the previously-interrupted + tool name so the first matching call is let through (one-shot). + 8. After the skipped call completes, subsequent breakpoint-matching + tool calls will pause again. + """ + + def __init__( + self, + breakpoints: list[str] | str, + skip_tool: str | None = None, + ) -> None: + self.breakpoints = breakpoints + self._skip_tool = skip_tool + + def _matches(self, tool_name: str) -> bool: + if self.breakpoints == "*": + return True + if isinstance(self.breakpoints, list): + return tool_name in self.breakpoints + return False + + async def process( + self, + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + tool = context.function + tool_name = getattr(tool, "name", "") + + if not self._matches(tool_name): + await call_next() + return + + # One-shot skip for the tool we just resumed from + if self._skip_tool and tool_name == self._skip_tool: + self._skip_tool = None + await call_next() + return + + # Legacy metadata-based resume (kept for backward compatibility) + if context.metadata.get("_breakpoint_continue"): + await call_next() + return + + interrupt_id = str(uuid4()) + + input_value: Any = None + if context.arguments is not None: + try: + input_value = context.arguments.model_dump() + except Exception: + input_value = 
str(context.arguments) + + suspend_value = { + "type": "breakpoint", + "tool_name": tool_name, + "input_value": input_value, + } + + raise AgentInterruptException( + interrupt_id=interrupt_id, + suspend_value=suspend_value, + is_breakpoint=True, + ) + + _ORIGINAL_EXECUTE_ATTR = "_bp_original_execute" @@ -34,7 +151,7 @@ def _build_executor_tool_map(agent: WorkflowAgent) -> dict[str, set[str]]: inner = getattr(executor, "_agent", None) if inner is not None: tools = get_agent_tools(inner) - names = {get_tool_name(t) for t in tools if get_tool_name(t)} + names = {n for t in tools if (n := get_tool_name(t)) is not None} tool_map[exec_id] = names return tool_map @@ -81,22 +198,35 @@ def _resolve_to_executor_ids( def inject_breakpoint_middleware( agent: WorkflowAgent, breakpoints: list[str] | str, - skip_nodes: set[str] | None = None, + skip_nodes: dict[str, int] | None = None, ) -> None: """Wrap executor.execute() to pause before breakpointed executors run. Replaces each matching executor's execute() with a wrapper that raises AgentInterruptException(is_breakpoint=True) before the executor runs. + For executors in *skip_nodes*, the wrapper allows *N* pass-through calls + (running the original execute) before re-arming the breakpoint. The + count *N* equals the number of times that executor has previously been + breakpointed and resumed — this correctly handles both: + + * **GroupChat star topology** where the orchestrator is called multiple + times per workflow run (initial + once per participant response). + * **Cyclic graphs** where each executor is visited on every cycle. + + Each resume increments the count so the executor passes through all the + calls that happened *before* the breakpoint, then breaks on the next new + call. + Args: agent: The workflow agent whose executors to wrap. breakpoints: ``"*"`` or a list of node IDs from the debug UI. - skip_nodes: Executor IDs to skip (for resume after breakpoint). 
- In concurrent workflows multiple executors may have been - breakpointed across sequential resumes within the same - superstep, so all of them must be skipped. + skip_nodes: Mapping of executor_id → pass-through count. + Each value is the number of calls to let through before + re-arming the breakpoint on that executor. """ executor_ids = _resolve_to_executor_ids(agent, breakpoints) + skip = skip_nodes or {} for exec_id in executor_ids: executor = agent.workflow.executors.get(exec_id) @@ -107,11 +237,8 @@ def inject_breakpoint_middleware( if hasattr(executor, _ORIGINAL_EXECUTE_ATTR): continue - # Skip executors already resumed past - if skip_nodes and exec_id in skip_nodes: - continue - original = executor.execute + pass_count = skip.get(exec_id, 0) async def wrapped_execute( message: Any, @@ -121,8 +248,20 @@ async def wrapped_execute( trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, *, + _original: Any = original, _exec_id: str = exec_id, + _remaining: list[int] = [pass_count], # noqa: B006 ) -> None: + if _remaining[0] > 0: + _remaining[0] -= 1 + return await _original( + message, + source_executor_ids, + state, + runner_context, + trace_contexts, + source_span_ids, + ) raise AgentInterruptException( interrupt_id=str(uuid4()), suspend_value={ @@ -162,6 +301,8 @@ def create_breakpoint_result( __all__ = [ + "AgentInterruptException", + "BreakpointMiddleware", "create_breakpoint_result", "inject_breakpoint_middleware", "remove_breakpoint_middleware", diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py index 611f8bae..1b3bfb42 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/factory.py @@ -210,6 +210,7 @@ async def _create_runtime_instance( UiPathResumableRuntime for resume trigger lifecycle 
handling. """ storage = await self._get_storage() + assert storage.checkpoint_storage is not None checkpoint_storage = ScopedCheckpointStorage( storage.checkpoint_storage, runtime_id ) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py deleted file mode 100644 index c004eba0..00000000 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/interrupt.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Interrupt infrastructure for human-in-the-loop (HITL) support. - -Provides: -- AgentInterruptException: raised by middleware to suspend agent execution -- BreakpointMiddleware: intercepts tools matching breakpoint configuration -""" - -from __future__ import annotations - -from collections.abc import Awaitable, Callable -from typing import Any -from uuid import uuid4 - -from agent_framework._middleware import ( - FunctionInvocationContext, - FunctionMiddleware, -) - - -class AgentInterruptException(Exception): - """Raised by middleware to suspend agent execution for HITL. - - Carries an interrupt_id and suspend_value that the runtime uses - to create a UiPathRuntimeResult with SUSPENDED status. - When is_breakpoint is True, the runtime returns UiPathBreakpointResult - instead, which bypasses trigger management and is handled by the - debug runtime layer. - """ - - def __init__( - self, - interrupt_id: str, - suspend_value: Any, - *, - is_breakpoint: bool = False, - ) -> None: - self.interrupt_id = interrupt_id - self.suspend_value = suspend_value - self.is_breakpoint = is_breakpoint - super().__init__(f"Agent interrupted: {interrupt_id}") - - -class BreakpointMiddleware(FunctionMiddleware): - """Intercepts tools matching breakpoint configuration. - - Breakpoint flow (orchestrated by UiPathDebugRuntime): - - 1. UiPathDebugRuntime gets breakpoints from debug bridge and passes - them via ``options.breakpoints`` to the integration runtime. - 2. 
The integration runtime injects this middleware into the agent's - middleware chain with the breakpoint list. - 3. When the agent calls a matching tool, this middleware raises - ``AgentInterruptException(is_breakpoint=True)`` BEFORE the tool runs. - 4. The runtime catches the exception and returns - ``UiPathBreakpointResult`` (a SUSPENDED result subclass). - 5. ``UiPathResumableRuntime`` passes the breakpoint result through - (no trigger management — breakpoints bypass the trigger system). - 6. ``UiPathDebugRuntime`` sees ``UiPathBreakpointResult``, notifies - the debug bridge, and waits for a resume command. - 7. On resume, ``UiPathDebugRuntime`` re-invokes the runtime with - ``options.resume=True, input=None``. The runtime re-injects this - middleware with ``skip_tool`` set to the previously-interrupted - tool name so the first matching call is let through (one-shot). - 8. After the skipped call completes, subsequent breakpoint-matching - tool calls will pause again. - """ - - def __init__( - self, - breakpoints: list[str] | str, - skip_tool: str | None = None, - ) -> None: - self.breakpoints = breakpoints - self._skip_tool = skip_tool - - def _matches(self, tool_name: str) -> bool: - if self.breakpoints == "*": - return True - if isinstance(self.breakpoints, list): - return tool_name in self.breakpoints - return False - - async def process( - self, - context: FunctionInvocationContext, - call_next: Callable[[], Awaitable[None]], - ) -> None: - tool = context.function - tool_name = getattr(tool, "name", "") - - if not self._matches(tool_name): - await call_next() - return - - # One-shot skip for the tool we just resumed from - if self._skip_tool and tool_name == self._skip_tool: - self._skip_tool = None - await call_next() - return - - # Legacy metadata-based resume (kept for backward compatibility) - if context.metadata.get("_breakpoint_continue"): - await call_next() - return - - interrupt_id = str(uuid4()) - - input_value = None - if context.arguments is not 
None: - try: - input_value = context.arguments.model_dump() - except Exception: - input_value = str(context.arguments) - - suspend_value = { - "type": "breakpoint", - "tool_name": tool_name, - "input_value": input_value, - } - - raise AgentInterruptException( - interrupt_id=interrupt_id, - suspend_value=suspend_value, - is_breakpoint=True, - ) - - -__all__ = [ - "AgentInterruptException", - "BreakpointMiddleware", -] diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py index ddcb2dbd..24a87d65 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/resumable_storage.py @@ -142,13 +142,9 @@ async def save_triggers( ) await conn.commit() - logger.debug( - "Saved %d triggers for runtime_id=%s", len(triggers), runtime_id - ) + logger.debug("Saved %d triggers for runtime_id=%s", len(triggers), runtime_id) - async def get_triggers( - self, runtime_id: str - ) -> list[UiPathResumeTrigger] | None: + async def get_triggers(self, runtime_id: str) -> list[UiPathResumeTrigger] | None: """Retrieve all resume triggers for this runtime_id.""" conn = await self._get_conn() async with self._lock: @@ -214,9 +210,7 @@ async def set_value( ) await conn.commit() - async def get_value( - self, runtime_id: str, namespace: str, key: str - ) -> Any: + async def get_value(self, runtime_id: str, namespace: str, key: str) -> Any: """Get arbitrary key-value pair scoped by runtime_id + namespace.""" conn = await self._get_conn() async with self._lock: @@ -368,7 +362,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: row = await cursor.fetchone() if not row: - from agent_framework._workflows._checkpoint import ( + from agent_framework._workflows._checkpoint import ( # type: ignore[attr-defined] WorkflowCheckpointException, ) 
@@ -380,9 +374,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: decoded = decode_checkpoint_value(encoded) return WorkflowCheckpoint.from_dict(decoded) - async def list_checkpoints( - self, *, workflow_name: str - ) -> list[WorkflowCheckpoint]: + async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: """List checkpoint objects for a given workflow name.""" conn = await self._storage._get_conn() async with self._storage._lock: @@ -410,9 +402,7 @@ async def delete(self, checkpoint_id: str) -> bool: await conn.commit() return cursor.rowcount > 0 - async def get_latest( - self, *, workflow_name: str - ) -> WorkflowCheckpoint | None: + async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: """Get the latest checkpoint for a given workflow name.""" conn = await self._storage._get_conn() async with self._storage._lock: @@ -429,9 +419,7 @@ async def get_latest( decoded = decode_checkpoint_value(encoded) return WorkflowCheckpoint.from_dict(decoded) - async def list_checkpoint_ids( - self, *, workflow_name: str - ) -> list[str]: + async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]: """List checkpoint IDs for a given workflow name.""" conn = await self._storage._get_conn() async with self._storage._lock: @@ -452,9 +440,7 @@ class ScopedCheckpointStorage: ``{runtime_id}::``. 
""" - def __init__( - self, delegate: SqliteCheckpointStorage, runtime_id: str - ) -> None: + def __init__(self, delegate: SqliteCheckpointStorage, runtime_id: str) -> None: self._delegate = delegate self._scope = f"{runtime_id}::" @@ -470,9 +456,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: """Load by checkpoint_id (globally unique).""" return await self._delegate.load(checkpoint_id) - async def list_checkpoints( - self, *, workflow_name: str - ) -> list[WorkflowCheckpoint]: + async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: """List checkpoints with scoped workflow_name.""" return await self._delegate.list_checkpoints( workflow_name=self._scoped_name(workflow_name) @@ -482,17 +466,13 @@ async def delete(self, checkpoint_id: str) -> bool: """Delete by checkpoint_id (globally unique).""" return await self._delegate.delete(checkpoint_id) - async def get_latest( - self, *, workflow_name: str - ) -> WorkflowCheckpoint | None: + async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: """Get latest checkpoint with scoped workflow_name.""" return await self._delegate.get_latest( workflow_name=self._scoped_name(workflow_name) ) - async def list_checkpoint_ids( - self, *, workflow_name: str - ) -> list[str]: + async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]: """List checkpoint IDs with scoped workflow_name.""" return await self._delegate.list_checkpoint_ids( workflow_name=self._scoped_name(workflow_name) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index 20f3062c..84a8dd77 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -32,12 +32,12 @@ from uipath.runtime.schema import UiPathRuntimeSchema from .breakpoints import ( + 
AgentInterruptException, create_breakpoint_result, inject_breakpoint_middleware, remove_breakpoint_middleware, ) from .errors import UiPathAgentFrameworkErrorCode, UiPathAgentFrameworkRuntimeError -from .interrupt import AgentInterruptException from .messages import AgentFrameworkChatMessagesMapper from .resumable_storage import ScopedCheckpointStorage, SqliteResumableStorage from .schema import get_agent_graph, get_entrypoints_schema @@ -61,8 +61,10 @@ def __init__( self._checkpoint_storage = checkpoint_storage self._resumable_storage = resumable_storage self._resume_responses: dict[str, Any] | None = None - self._breakpoint_skip_nodes: set[str] = set() + self._breakpoint_skip_nodes: dict[str, int] = {} + self._last_breakpoint_node: str | None = None self._last_checkpoint_id: str | None = None + self._resumed_from_checkpoint_id: str | None = None # ------------------------------------------------------------------ # Checkpoint helpers @@ -78,18 +80,24 @@ async def _get_latest_checkpoint_id(self) -> str | None: ) return checkpoint.checkpoint_id if checkpoint else None - async def _save_breakpoint_state(self, original_input: str) -> None: + async def _save_breakpoint_state( + self, original_input: str, checkpoint_id: str | None = None + ) -> None: """Persist breakpoint state to KV storage for resume. - The skip_nodes set accumulates across resumes so that concurrent - executors breakpointed in the same superstep are all skipped on - subsequent resumes (prevents the infinite-cycle bug). + skip_nodes is a dict mapping executor_id → pass-through count. + Each count records how many times the executor must be allowed + to run before re-arming its breakpoint. The count is incremented + every time the same executor hits a breakpoint again (cyclic + graphs, GroupChat orchestrators). 
""" if not self._resumable_storage: return - checkpoint_id = await self._get_latest_checkpoint_id() + if checkpoint_id is None: + checkpoint_id = await self._get_latest_checkpoint_id() state = { - "skip_nodes": list(self._breakpoint_skip_nodes), + "skip_nodes": dict(self._breakpoint_skip_nodes), + "last_breakpoint_node": self._last_breakpoint_node, "checkpoint_id": checkpoint_id, "original_input": original_input, } @@ -105,11 +113,23 @@ async def _load_breakpoint_state(self) -> dict[str, Any] | None: self.runtime_id, "breakpoint", "state" ) if state and isinstance(state, dict): - self._breakpoint_skip_nodes = set(state.get("skip_nodes", [])) + self._breakpoint_skip_nodes = dict(state.get("skip_nodes", {})) + self._last_breakpoint_node = state.get("last_breakpoint_node") self._last_checkpoint_id = state.get("checkpoint_id") return state return None + def _get_breakpoint_skip(self) -> dict[str, int]: + """Get the skip_nodes dict for breakpoint injection. + + Returns accumulated skip counts. The counts are reset whenever + a checkpoint advancement is detected (see the breakpoint handlers + in execute / _stream_workflow) so they stay correct regardless + of whether checkpoints advance per-executor or stay at a coarser + granularity. 
+ """ + return dict(self._breakpoint_skip_nodes) + # ------------------------------------------------------------------ # Session helpers (multi-turn conversation history) # ------------------------------------------------------------------ @@ -171,6 +191,7 @@ async def execute( if self._resume_responses: checkpoint_id = await self._get_latest_checkpoint_id() + self._resumed_from_checkpoint_id = checkpoint_id result = await workflow.run( responses=self._resume_responses, checkpoint_id=checkpoint_id, @@ -178,6 +199,7 @@ async def execute( ) self._resume_responses = None else: + self._resumed_from_checkpoint_id = None result = await workflow.run( message="", checkpoint_storage=self._checkpoint_storage, @@ -186,12 +208,13 @@ async def execute( # Breakpoint resume: restore from checkpoint bp_state = await self._load_breakpoint_state() checkpoint_id = self._last_checkpoint_id + self._resumed_from_checkpoint_id = checkpoint_id original_input = bp_state.get("original_input", "") if bp_state else "" - # Inject breakpoints, skipping all previously-resumed executors + # Inject breakpoints with accumulated skip counts if options and options.breakpoints: inject_breakpoint_middleware( - self.agent, options.breakpoints, self._breakpoint_skip_nodes + self.agent, options.breakpoints, self._get_breakpoint_skip() ) if checkpoint_id: @@ -206,6 +229,7 @@ async def execute( ) else: # Fresh run: load session for multi-turn conversation history + self._resumed_from_checkpoint_id = None session = await self._load_session() self._apply_session_to_executors(session) @@ -232,9 +256,22 @@ async def execute( if isinstance(e.suspend_value, dict) else "" ) - self._breakpoint_skip_nodes.add(node_id) + # Detect checkpoint advancement and reset counts if needed + latest_checkpoint = await self._get_latest_checkpoint_id() + if ( + latest_checkpoint + and self._resumed_from_checkpoint_id + and latest_checkpoint != self._resumed_from_checkpoint_id + ): + self._breakpoint_skip_nodes = {} + 
self._breakpoint_skip_nodes[node_id] = ( + self._breakpoint_skip_nodes.get(node_id, 0) + 1 + ) + self._last_breakpoint_node = node_id original_input = self._prepare_input(input) if not is_resuming else "" - await self._save_breakpoint_state(original_input) + await self._save_breakpoint_state( + original_input, checkpoint_id=latest_checkpoint + ) return create_breakpoint_result(e) return self._create_suspended_result(e) except Exception as e: @@ -271,10 +308,10 @@ async def stream( session = await self._load_session() self._apply_session_to_executors(session) - # Inject breakpoints, skipping all previously-resumed executors + # Inject breakpoints — skip strategy depends on resume mode if options and options.breakpoints: inject_breakpoint_middleware( - self.agent, options.breakpoints, self._breakpoint_skip_nodes + self.agent, options.breakpoints, self._get_breakpoint_skip() ) else: @@ -323,6 +360,7 @@ async def _stream_workflow( if self._resume_responses: # HITL resume: pass responses to workflow with checkpoint checkpoint_id = await self._get_latest_checkpoint_id() + self._resumed_from_checkpoint_id = checkpoint_id response_stream = workflow.run( responses=self._resume_responses, checkpoint_id=checkpoint_id, @@ -333,6 +371,7 @@ async def _stream_workflow( elif self._last_checkpoint_id: # Breakpoint resume with checkpoint: restore and continue checkpoint_id = self._last_checkpoint_id + self._resumed_from_checkpoint_id = checkpoint_id self._last_checkpoint_id = None response_stream = workflow.run( checkpoint_id=checkpoint_id, @@ -341,6 +380,7 @@ async def _stream_workflow( ) else: # Fresh run (or breakpoint resume without checkpoint — uses original_input) + self._resumed_from_checkpoint_id = None response_stream = workflow.run( message=user_input, checkpoint_storage=self._checkpoint_storage, @@ -375,7 +415,10 @@ async def _stream_workflow( request_info_map[event.request_id] = event.data elif event.type == "executor_invoked": # Skip the duplicate for the start executor 
we already emitted - if pre_emitted_executor and event.executor_id == pre_emitted_executor: + if ( + pre_emitted_executor + and event.executor_id == pre_emitted_executor + ): pre_emitted_executor = None continue yield UiPathRuntimeStateEvent( @@ -387,7 +430,10 @@ async def _stream_workflow( # When output events were filtered by the workflow (e.g. # GroupChat where participants are not output executors), # extract tool state events from the completed data instead. - if event.executor_id not in executors_with_tool_outputs: + if ( + event.executor_id + and event.executor_id not in executors_with_tool_outputs + ): for tool_event in self._extract_tool_state_events( event.data, event.executor_id ): @@ -412,7 +458,10 @@ async def _stream_workflow( yield UiPathRuntimeMessageEvent(payload=msg_event) # Detect workflow suspension via state - if event.type == "status" and str(event.state) == "IDLE_WITH_PENDING_REQUESTS": + if ( + event.type == "status" + and str(event.state) == "IDLE_WITH_PENDING_REQUESTS" + ): is_suspended = True except AgentInterruptException as e: # Breakpoint or HITL interrupt fired inside an inner agent @@ -434,8 +483,21 @@ async def _stream_workflow( if isinstance(e.suspend_value, dict) else "" ) - self._breakpoint_skip_nodes.add(node_id) - await self._save_breakpoint_state(user_input) + # Detect checkpoint advancement and reset counts if needed + latest_checkpoint = await self._get_latest_checkpoint_id() + if ( + latest_checkpoint + and self._resumed_from_checkpoint_id + and latest_checkpoint != self._resumed_from_checkpoint_id + ): + self._breakpoint_skip_nodes = {} + self._breakpoint_skip_nodes[node_id] = ( + self._breakpoint_skip_nodes.get(node_id, 0) + 1 + ) + self._last_breakpoint_node = node_id + await self._save_breakpoint_state( + user_input, checkpoint_id=latest_checkpoint + ) yield create_breakpoint_result(e) else: yield self._create_suspended_result(e) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py 
b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py index 73072135..642e1b42 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/schema.py @@ -202,7 +202,7 @@ def _is_handoff_tool(tool_name: str, executor_ids: set[str]) -> bool: """Check if a tool is a handoff tool by matching against executor IDs.""" if not tool_name.startswith("handoff_to_"): return False - target = tool_name[len("handoff_to_"):] + target = tool_name[len("handoff_to_") :] return target in executor_ids @@ -222,11 +222,11 @@ def _add_executor_tool_nodes( if not tools: return - tool_names = [get_tool_name(t) for t in tools] - tool_names = [n for n in tool_names if n] - - # Filter out handoff tools — they are represented as edges, not tool nodes - tool_names = [n for n in tool_names if not _is_handoff_tool(n, executor_ids)] + tool_names: list[str] = [ + n + for t in tools + if (n := get_tool_name(t)) is not None and not _is_handoff_tool(n, executor_ids) + ] if tool_names: tools_node_id = f"{executor_id}_tools" diff --git a/packages/uipath-agent-framework/tests/test_breakpoints.py b/packages/uipath-agent-framework/tests/test_breakpoints.py index 6c38e87a..da56c625 100644 --- a/packages/uipath-agent-framework/tests/test_breakpoints.py +++ b/packages/uipath-agent-framework/tests/test_breakpoints.py @@ -18,19 +18,23 @@ UiPathDebugRuntime, ) from uipath.runtime.events import UiPathRuntimeStateEvent -from uipath.runtime.result import UiPathRuntimeResult, UiPathRuntimeStatus +from uipath.runtime.result import UiPathRuntimeStatus from uipath_agent_framework.runtime.breakpoints import ( + AgentInterruptException, _resolve_to_executor_ids, create_breakpoint_result, inject_breakpoint_middleware, remove_breakpoint_middleware, ) -from uipath_agent_framework.runtime.interrupt import AgentInterruptException from uipath_agent_framework.runtime.runtime import 
UiPathAgentFrameworkRuntime +_mock_client: Any = MagicMock() -_mock_client = MagicMock() + +def _agent(name: str, **kwargs: Any) -> Any: + """Create a RawAgent typed as Any to avoid protocol mismatch with WorkflowBuilder.""" + return RawAgent(_mock_client, name=name, **kwargs) class _MockWorkflowStream: @@ -105,8 +109,8 @@ class TestResolveBreakpoints: """Verify graph node IDs are correctly resolved to executor IDs.""" def test_wildcard_resolves_to_all_executors(self): - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") + a = _agent(name="agent_a") + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -114,8 +118,8 @@ def test_wildcard_resolves_to_all_executors(self): assert result == set(workflow.executors.keys()) def test_executor_id_resolves_directly(self): - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") + a = _agent(name="agent_a") + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -123,7 +127,7 @@ def test_executor_id_resolves_directly(self): assert result == {"agent_a"} def test_tools_suffix_resolves_to_parent_executor(self): - a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) + a = _agent(name="agent_a", tools=[calculator]) workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -131,8 +135,8 @@ def test_tools_suffix_resolves_to_parent_executor(self): assert result == {"agent_a"} def test_tool_name_resolves_to_owning_executor(self): - a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) - b = RawAgent(_mock_client, name="agent_b", tools=[search_web]) + a = _agent(name="agent_a", tools=[calculator]) + b = _agent(name="agent_b", tools=[search_web]) workflow = WorkflowBuilder(start_executor=a).add_edge(a, 
b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -144,8 +148,8 @@ def test_tool_name_resolves_to_owning_executor(self): def test_wildcard_in_list_resolves_to_all(self): """Wildcard passed as ["*"] (list) also resolves to all executors.""" - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") + a = _agent(name="agent_a") + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -153,7 +157,7 @@ def test_wildcard_in_list_resolves_to_all(self): assert result == set(workflow.executors.keys()) def test_unknown_node_id_ignored(self): - a = RawAgent(_mock_client, name="agent_a") + a = _agent(name="agent_a") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -161,8 +165,8 @@ def test_unknown_node_id_ignored(self): assert result == set() def test_mixed_breakpoints(self): - a = RawAgent(_mock_client, name="agent_a", tools=[calculator]) - b = RawAgent(_mock_client, name="agent_b") + a = _agent(name="agent_a", tools=[calculator]) + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -180,7 +184,7 @@ class TestInjectBreakpoints: def test_inject_wraps_executor_execute(self): """Injecting breakpoints replaces executor.execute with a wrapper.""" - a = RawAgent(_mock_client, name="agent_a") + a = _agent(name="agent_a") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -194,7 +198,7 @@ def test_inject_wraps_executor_execute(self): def test_remove_restores_original_execute(self): """Removing breakpoints restores the original execute method.""" - a = RawAgent(_mock_client, name="agent_a") + a = _agent(name="agent_a") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -206,7 
+210,7 @@ def test_remove_restores_original_execute(self): async def test_wrapped_execute_raises_interrupt(self): """Wrapped executor raises AgentInterruptException on execute.""" - a = RawAgent(_mock_client, name="agent_a") + a = _agent(name="agent_a") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -215,7 +219,7 @@ async def test_wrapped_execute_raises_interrupt(self): executor = workflow.executors["agent_a"] try: await executor.execute("msg", [], MagicMock(), MagicMock()) - assert False, "Should have raised AgentInterruptException" + raise AssertionError("Should have raised AgentInterruptException") except AgentInterruptException as e: assert e.is_breakpoint is True assert e.suspend_value["type"] == "breakpoint" @@ -223,50 +227,108 @@ async def test_wrapped_execute_raises_interrupt(self): finally: remove_breakpoint_middleware(agent) - def test_skip_nodes_not_wrapped(self): - """Executors in skip_nodes should not be wrapped (resume scenario).""" - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") + def test_skip_nodes_wrapped_with_pass_through(self): + """Executors in skip_nodes are wrapped with pass-through-N. + + On resume, skip_nodes executors get a wrapper that allows N calls + (the count) before re-arming. This ensures breakpoints fire on + every visit in cyclic graphs while handling GroupChat orchestrators + that are called multiple times per workflow run. 
+ """ + a = _agent(name="agent_a") + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") inject_breakpoint_middleware( - agent, ["agent_a", "agent_b"], skip_nodes={"agent_a"} + agent, ["agent_a", "agent_b"], skip_nodes={"agent_a": 1} ) - # agent_a should NOT be wrapped (it's in skip_nodes) - assert not hasattr(workflow.executors["agent_a"], "_bp_original_execute") - # agent_b SHOULD be wrapped + # Both are wrapped (skip_nodes get pass-through wrappers) + assert hasattr(workflow.executors["agent_a"], "_bp_original_execute") assert hasattr(workflow.executors["agent_b"], "_bp_original_execute") remove_breakpoint_middleware(agent) + async def test_skip_node_passes_through_then_breaks(self): + """A skip_node with count=1 passes through one call, then breaks.""" + a = _agent(name="agent_a") + workflow = WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + # Replace real execute with a trackable mock + original_mock = AsyncMock(return_value=None) + workflow.executors["agent_a"].execute = original_mock # type: ignore[method-assign] + + inject_breakpoint_middleware(agent, ["agent_a"], skip_nodes={"agent_a": 1}) + + executor = workflow.executors["agent_a"] + + # First call: passes through to original + await executor.execute("msg", [], MagicMock(), MagicMock()) + assert original_mock.await_count == 1 + + # Second call: breakpoint fires + try: + await executor.execute("msg", [], MagicMock(), MagicMock()) + raise AssertionError("Should have raised AgentInterruptException") + except AgentInterruptException as e: + assert e.is_breakpoint is True + assert e.suspend_value["node_id"] == "agent_a" + + remove_breakpoint_middleware(agent) + + async def test_skip_node_count_two_passes_through_twice(self): + """A skip_node with count=2 passes through two calls, then breaks.""" + a = _agent(name="agent_a") + workflow = 
WorkflowBuilder(start_executor=a).build() + agent = WorkflowAgent(workflow=workflow, name="wf") + + original_mock = AsyncMock(return_value=None) + workflow.executors["agent_a"].execute = original_mock # type: ignore[method-assign] + + inject_breakpoint_middleware(agent, ["agent_a"], skip_nodes={"agent_a": 2}) + + executor = workflow.executors["agent_a"] + + # First two calls pass through + await executor.execute("msg", [], MagicMock(), MagicMock()) + await executor.execute("msg", [], MagicMock(), MagicMock()) + assert original_mock.await_count == 2 + + # Third call: breakpoint fires + try: + await executor.execute("msg", [], MagicMock(), MagicMock()) + raise AssertionError("Should have raised AgentInterruptException") + except AgentInterruptException as e: + assert e.is_breakpoint is True + + remove_breakpoint_middleware(agent) + def test_skip_nodes_multiple(self): - """Multiple skip_nodes are all excluded from wrapping.""" - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") - c = RawAgent(_mock_client, name="agent_c") + """Multiple skip_nodes with different counts are all wrapped.""" + a = _agent(name="agent_a") + b = _agent(name="agent_b") + c = _agent(name="agent_c") workflow = ( - WorkflowBuilder(start_executor=a) - .add_edge(a, b) - .add_edge(a, c) - .build() + WorkflowBuilder(start_executor=a).add_edge(a, b).add_edge(a, c).build() ) agent = WorkflowAgent(workflow=workflow, name="wf") inject_breakpoint_middleware( - agent, "*", skip_nodes={"agent_a", "agent_b"} + agent, "*", skip_nodes={"agent_a": 1, "agent_b": 2} ) - assert not hasattr(workflow.executors["agent_a"], "_bp_original_execute") - assert not hasattr(workflow.executors["agent_b"], "_bp_original_execute") + # All are wrapped (skip_nodes get pass-through wrappers, others get BP wrappers) + assert hasattr(workflow.executors["agent_a"], "_bp_original_execute") + assert hasattr(workflow.executors["agent_b"], "_bp_original_execute") assert 
hasattr(workflow.executors["agent_c"], "_bp_original_execute") remove_breakpoint_middleware(agent) def test_no_double_wrap(self): """Calling inject twice doesn't double-wrap executors.""" - a = RawAgent(_mock_client, name="agent_a") + a = _agent(name="agent_a") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -283,8 +345,8 @@ def test_no_double_wrap(self): def test_wildcard_wraps_all_executors(self): """Wildcard breakpoint wraps every executor.""" - a = RawAgent(_mock_client, name="agent_a") - b = RawAgent(_mock_client, name="agent_b") + a = _agent(name="agent_a") + b = _agent(name="agent_b") workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -299,7 +361,7 @@ def test_agents_without_tools_can_be_breakpointed(self): """Executors with no tools (pure chat agents) can be breakpointed.""" # This was the original bug: pure chat agents had no tools, # so FunctionMiddleware never fired. 
- a = RawAgent(_mock_client, name="chat_agent") + a = _agent(name="chat_agent") workflow = WorkflowBuilder(start_executor=a).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -367,7 +429,7 @@ class TestDebugRuntimeBreakpointIntegration: async def test_breakpoint_fires_on_start_executor(self): """Breakpoint on the start executor pauses before it runs.""" - worker = RawAgent(_mock_client, name="worker") + worker = _agent(name="worker") workflow = WorkflowBuilder(start_executor=worker).build() agent = WorkflowAgent(workflow=workflow, name="test_wf") @@ -381,9 +443,7 @@ async def test_breakpoint_fires_on_start_executor(self): UiPathDebugQuitError("quit"), ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) result = await debug_runtime.execute({"messages": []}) @@ -399,7 +459,7 @@ async def test_breakpoint_fires_on_start_executor(self): async def test_breakpoint_fires_on_toolless_agent(self): """Breakpoint works on agents with no tools (the original bug).""" # This is the concurrent sample scenario: pure chat agents, no tools - chat_agent = RawAgent(_mock_client, name="sentiment") + chat_agent = _agent(name="sentiment") workflow = WorkflowBuilder(start_executor=chat_agent).build() agent = WorkflowAgent(workflow=workflow, name="concurrent_wf") @@ -412,11 +472,9 @@ async def test_breakpoint_fires_on_toolless_agent(self): UiPathDebugQuitError("quit"), ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) - result = await debug_runtime.execute({"messages": []}) + await debug_runtime.execute({"messages": []}) cast(AsyncMock, bridge.emit_breakpoint_hit).assert_awaited_once() bp_result = cast(AsyncMock, bridge.emit_breakpoint_hit).call_args[0][0] @@ -424,7 +482,7 @@ async def test_breakpoint_fires_on_toolless_agent(self): async def 
test_state_events_emitted_before_breakpoint(self): """Debug bridge should receive state events (STARTED) before the breakpoint.""" - worker = RawAgent(_mock_client, name="agent_x") + worker = _agent(name="agent_x") workflow = WorkflowBuilder(start_executor=worker).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -437,9 +495,7 @@ async def test_state_events_emitted_before_breakpoint(self): UiPathDebugQuitError("quit"), ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) # Collect all events from the stream events: list[Any] = [] @@ -460,7 +516,7 @@ async def test_state_events_emitted_before_breakpoint(self): async def test_no_breakpoints_runs_to_completion(self): """With no breakpoints set, the workflow should run normally (or fail trying to call LLM, but not hit any breakpoint).""" - worker = RawAgent(_mock_client, name="worker") + worker = _agent(name="worker") workflow = WorkflowBuilder(start_executor=worker).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -470,9 +526,7 @@ async def test_no_breakpoints_runs_to_completion(self): cast(Mock, bridge.get_breakpoints).return_value = [] # no breakpoints cast(AsyncMock, bridge.wait_for_resume).return_value = None - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) # Without breakpoints, the workflow tries to actually execute the agent. # Since we have a mock client, this will fail — but NOT as a breakpoint. @@ -492,7 +546,7 @@ async def test_breakpoint_resume_preserves_original_input_and_session(self): workflow.run() and not loading the session, so the agent acted like it never received the user's message. 
""" - worker = RawAgent(_mock_client, name="weather_agent") + worker = _agent(name="weather_agent") workflow = WorkflowBuilder(start_executor=worker).build() agent = WorkflowAgent(workflow=workflow, name="wf") @@ -516,9 +570,7 @@ async def mock_set_value( ) -> None: kv_store[f"{runtime_id}:{namespace}:{key}"] = value - async def mock_get_value( - runtime_id: str, namespace: str, key: str - ) -> Any: + async def mock_get_value(runtime_id: str, namespace: str, key: str) -> Any: return kv_store.get(f"{runtime_id}:{namespace}:{key}") mock_storage.set_value = mock_set_value @@ -545,9 +597,7 @@ async def mock_get_value( UiPathDebugQuitError("quit"), # quit after second run fails (no LLM) ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) # Execute — first run hits breakpoint, resume continues try: @@ -583,12 +633,10 @@ async def test_two_sequential_breakpoints_with_resumes(self): - Resume 1: agent_a skipped, BP fires on agent_b - Resume 2: both skipped, completes normally """ - agent_a = RawAgent(_mock_client, name="agent_a") - agent_b = RawAgent(_mock_client, name="agent_b") + agent_a = _agent(name="agent_a") + agent_b = _agent(name="agent_b") workflow = ( - WorkflowBuilder(start_executor=agent_a) - .add_edge(agent_a, agent_b) - .build() + WorkflowBuilder(start_executor=agent_a).add_edge(agent_a, agent_b).build() ) agent = WorkflowAgent(workflow=workflow, name="wf") @@ -599,16 +647,18 @@ async def test_two_sequential_breakpoints_with_resumes(self): def mock_run(**kwargs: Any) -> _MockWorkflowStream: nonlocal call_count call_count += 1 - call_log.append({ - "call_number": call_count, - "kwargs": dict(kwargs), - "agent_a_wrapped": hasattr( - workflow.executors["agent_a"], "_bp_original_execute" - ), - "agent_b_wrapped": hasattr( - workflow.executors["agent_b"], "_bp_original_execute" - ), - }) + call_log.append( + { + "call_number": call_count, + "kwargs": 
dict(kwargs), + "agent_a_wrapped": hasattr( + workflow.executors["agent_a"], "_bp_original_execute" + ), + "agent_b_wrapped": hasattr( + workflow.executors["agent_b"], "_bp_original_execute" + ), + } + ) if call_count == 1: # Fresh run → breakpoint on agent_a checkpoint_counter[0] = 1 @@ -649,9 +699,7 @@ async def mock_set_value( ) -> None: kv_store[f"{runtime_id}:{namespace}:{key}"] = value - async def mock_get_value( - runtime_id: str, namespace: str, key: str - ) -> Any: + async def mock_get_value(runtime_id: str, namespace: str, key: str) -> Any: return kv_store.get(f"{runtime_id}:{namespace}:{key}") mock_storage.set_value = mock_set_value @@ -688,9 +736,7 @@ async def mock_get_latest(**kwargs: Any) -> Any: None, # continue after BP on agent_b ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) result = await debug_runtime.execute({"messages": []}) # --- 3 workflow.run() calls --- @@ -708,22 +754,21 @@ async def mock_get_latest(**kwargs: Any) -> Any: assert c2["kwargs"].get("checkpoint_id") == "cp-1" assert c2["kwargs"].get("stream") is True assert "message" not in c2["kwargs"] - assert c2["agent_a_wrapped"] is False # Skipped - assert c2["agent_b_wrapped"] is True # Still breakpointed + # All executors are wrapped (skip_nodes get pass-through wrappers) + assert c2["agent_a_wrapped"] is True # Pass-through wrapper (count=1) + assert c2["agent_b_wrapped"] is True # BP wrapper # Call 3: Resume after agent_b breakpoint — checkpoint-based c3 = call_log[2] assert c3["kwargs"].get("checkpoint_id") == "cp-2" assert c3["kwargs"].get("stream") is True assert "message" not in c3["kwargs"] - assert c3["agent_a_wrapped"] is False # Still skipped (accumulated) - assert c3["agent_b_wrapped"] is False # Now also skipped + assert c3["agent_a_wrapped"] is True # Pass-through wrapper (count=1) + assert c3["agent_b_wrapped"] is True # Pass-through wrapper (count=1) # --- 
Debug bridge interactions --- assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 2 - bp_calls = cast( - AsyncMock, bridge.emit_breakpoint_hit - ).call_args_list + bp_calls = cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list assert bp_calls[0].args[0].breakpoint_node == "agent_a" assert bp_calls[1].args[0].breakpoint_node == "agent_b" @@ -732,7 +777,9 @@ async def mock_get_latest(**kwargs: Any) -> Any: # --- KV state --- bp_state = kv_store.get("test-2bp:breakpoint:state") assert bp_state is not None - assert set(bp_state["skip_nodes"]) == {"agent_a", "agent_b"} + # Checkpoint advanced (cp-1 → cp-2), so skip_nodes were reset: + # only the last breakpointed node remains. + assert bp_state["skip_nodes"] == {"agent_b": 1} assert bp_state["checkpoint_id"] == "cp-2" assert bp_state["original_input"] == "hello world" @@ -750,9 +797,9 @@ async def test_concurrent_breakpoints_accumulate_skip_nodes(self): This was the infinite loop bug: without accumulating skip_nodes, worker_a and worker_b kept trading breakpoints forever. 
""" - dispatcher = RawAgent(_mock_client, name="dispatcher") - worker_a = RawAgent(_mock_client, name="worker_a") - worker_b = RawAgent(_mock_client, name="worker_b") + dispatcher = _agent(name="dispatcher") + worker_a = _agent(name="worker_a") + worker_b = _agent(name="worker_b") workflow = ( WorkflowBuilder(start_executor=dispatcher) .add_edge(dispatcher, worker_a) @@ -768,19 +815,21 @@ async def test_concurrent_breakpoints_accumulate_skip_nodes(self): def mock_run(**kwargs: Any) -> _MockWorkflowStream: nonlocal call_count call_count += 1 - call_log.append({ - "call_number": call_count, - "kwargs": dict(kwargs), - "dispatcher_wrapped": hasattr( - workflow.executors["dispatcher"], "_bp_original_execute" - ), - "worker_a_wrapped": hasattr( - workflow.executors["worker_a"], "_bp_original_execute" - ), - "worker_b_wrapped": hasattr( - workflow.executors["worker_b"], "_bp_original_execute" - ), - }) + call_log.append( + { + "call_number": call_count, + "kwargs": dict(kwargs), + "dispatcher_wrapped": hasattr( + workflow.executors["dispatcher"], "_bp_original_execute" + ), + "worker_a_wrapped": hasattr( + workflow.executors["worker_a"], "_bp_original_execute" + ), + "worker_b_wrapped": hasattr( + workflow.executors["worker_b"], "_bp_original_execute" + ), + } + ) if call_count == 1: checkpoint_counter[0] = 1 return _MockWorkflowStream( @@ -833,9 +882,7 @@ async def mock_set_value( ) -> None: kv_store[f"{runtime_id}:{namespace}:{key}"] = value - async def mock_get_value( - runtime_id: str, namespace: str, key: str - ) -> Any: + async def mock_get_value(runtime_id: str, namespace: str, key: str) -> Any: return kv_store.get(f"{runtime_id}:{namespace}:{key}") mock_storage.set_value = mock_set_value @@ -865,7 +912,9 @@ async def mock_get_latest(**kwargs: Any) -> Any: bridge = _make_debug_bridge() cast(Mock, bridge.get_breakpoints).return_value = [ - "dispatcher", "worker_a", "worker_b", + "dispatcher", + "worker_a", + "worker_b", ] cast(AsyncMock, 
bridge.wait_for_resume).side_effect = [ None, # initial @@ -874,9 +923,7 @@ async def mock_get_latest(**kwargs: Any) -> Any: None, # continue after worker_b BP ] - debug_runtime = UiPathDebugRuntime( - delegate=runtime, debug_bridge=bridge - ) + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) result = await debug_runtime.execute({"messages": []}) # --- 4 workflow.run() calls --- @@ -888,39 +935,405 @@ async def mock_get_latest(**kwargs: Any) -> Any: assert c1["worker_a_wrapped"] is True assert c1["worker_b_wrapped"] is True - # Call 2: dispatcher skipped, workers wrapped + # Call 2: dispatcher has pass-through, workers breakpointed c2 = call_log[1] - assert c2["dispatcher_wrapped"] is False + assert c2["dispatcher_wrapped"] is True # Pass-through (count=1) assert c2["worker_a_wrapped"] is True assert c2["worker_b_wrapped"] is True - # Call 3: dispatcher+worker_a skipped, worker_b wrapped + # Call 3: dispatcher+worker_a have pass-through, worker_b breakpointed c3 = call_log[2] - assert c3["dispatcher_wrapped"] is False - assert c3["worker_a_wrapped"] is False + assert c3["dispatcher_wrapped"] is True # Pass-through (count=1) + assert c3["worker_a_wrapped"] is True # Pass-through (count=1) assert c3["worker_b_wrapped"] is True - # Call 4: all skipped — completes + # Call 4: all have pass-through — completes c4 = call_log[3] - assert c4["dispatcher_wrapped"] is False - assert c4["worker_a_wrapped"] is False - assert c4["worker_b_wrapped"] is False + assert c4["dispatcher_wrapped"] is True # Pass-through (count=1) + assert c4["worker_a_wrapped"] is True # Pass-through (count=1) + assert c4["worker_b_wrapped"] is True # Pass-through (count=1) # 3 breakpoints hit assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 3 bp_nodes = [ call.args[0].breakpoint_node - for call in cast( - AsyncMock, bridge.emit_breakpoint_hit - ).call_args_list + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list ] assert bp_nodes == 
["dispatcher", "worker_a", "worker_b"] assert result.status == UiPathRuntimeStatus.SUCCESSFUL - # skip_nodes accumulated all three + # Checkpoints advanced (cp-1 → cp-2 → cp-3), so skip_nodes were + # reset on each advancement — only the last BP node remains. bp_state = kv_store.get("test-concurrent-bp:breakpoint:state") assert bp_state is not None - assert set(bp_state["skip_nodes"]) == { - "dispatcher", "worker_a", "worker_b", - } + assert bp_state["skip_nodes"] == {"worker_b": 1} + + async def test_group_chat_style_breakall_completes(self): + """GroupChat-style star topology with break-all: every visit breaks. + + Graph: orchestrator ↔ participant_a, orchestrator ↔ participant_b, + orchestrator ↔ participant_c (star topology with cycles) + Breakpoints: ``"*"`` (break all) + + Simulates a GroupChat where the orchestrator is called multiple times + per workflow run (initial + receiving each participant's response). + + The mock simulates real checkpoint-resume: on resume, the stream + starts from the BPed executor (not from the beginning), matching + how workflow.run(checkpoint_id=...) works in production. + + With checkpoint-resume + ``_get_breakpoint_skip()`` returning + ``{last_bp: 1}``, every node visit triggers a breakpoint: + + Execution order: orch → A → orch → B → orch → C → orch + + Expected 7 BPs (one per execution), 8 workflow.run() calls. 
+ """ + orchestrator = _agent(name="orchestrator") + participant_a = _agent(name="participant_a") + participant_b = _agent(name="participant_b") + participant_c = _agent(name="participant_c") + workflow = ( + WorkflowBuilder(start_executor=orchestrator) + .add_edge(orchestrator, participant_a) + .add_edge(participant_a, orchestrator) + .add_edge(orchestrator, participant_b) + .add_edge(participant_b, orchestrator) + .add_edge(orchestrator, participant_c) + .add_edge(participant_c, orchestrator) + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="group_chat") + + # Replace real execute methods with async mocks (simulates LLM calls) + for executor in workflow.executors.values(): + executor.execute = AsyncMock(return_value=None) # type: ignore[method-assign] + + # GroupChat execution order: orchestrator is called after each + # participant response, simulating the selector/response loop. + execution_order = [ + "orchestrator", # 0: initial + "participant_a", # 1 + "orchestrator", # 2: response from A + "participant_b", # 3 + "orchestrator", # 4: response from B + "participant_c", # 5 + "orchestrator", # 6: terminate + ] + + call_log: list[dict[str, Any]] = [] + call_count = 0 + checkpoint_counter = [0] + bp_position = [0] # Index in execution_order where last BP fired + MAX_CALLS = 20 # Safety limit to detect infinite loops + + class _GroupChatMockStream: + """Walks execution_order from a start position (checkpoint-resume). + + On fresh run starts from 0; on checkpoint-resume starts from the + position of the last breakpoint, simulating how workflow.run() + with a checkpoint_id resumes from the saved state. 
+ """ + + def __init__(self, start: int) -> None: + self._start = start + self._final_output = "group chat done" + + def __aiter__(self): + return self._aiter_impl() + + async def _aiter_impl(self): + for i in range(self._start, len(execution_order)): + bp_position[0] = i + exec_id = execution_order[i] + executor = workflow.executors[exec_id] + await executor.execute("msg", [], MagicMock(), MagicMock()) + # Async generator (needs yield even if unreachable) + return + yield # noqa: B901 + + async def get_final_response(self): + mock_result = MagicMock() + mock_result.get_outputs.return_value = [self._final_output] + return mock_result + + def mock_run(**kwargs: Any) -> _GroupChatMockStream: + nonlocal call_count + call_count += 1 + assert call_count <= MAX_CALLS, ( + f"Exceeded {MAX_CALLS} workflow.run() calls — infinite loop!" + ) + call_log.append( + { + "call_number": call_count, + "kwargs": dict(kwargs), + } + ) + checkpoint_counter[0] = call_count + # Checkpoint-resume: start from last BP position + # Fresh run: start from 0 + if "checkpoint_id" in kwargs and kwargs.get("checkpoint_id"): + return _GroupChatMockStream(start=bp_position[0]) + return _GroupChatMockStream(start=0) + + workflow.run = mock_run # type: ignore[assignment] + + # KV store + kv_store: dict[str, Any] = {} + mock_storage = AsyncMock() + + async def mock_set_value( + runtime_id: str, namespace: str, key: str, value: Any + ) -> None: + kv_store[f"{runtime_id}:{namespace}:{key}"] = value + + async def mock_get_value(runtime_id: str, namespace: str, key: str) -> Any: + return kv_store.get(f"{runtime_id}:{namespace}:{key}") + + mock_storage.set_value = mock_set_value + mock_storage.get_value = mock_get_value + + # Mock checkpoint storage + mock_cs = AsyncMock() + + async def mock_get_latest(**kwargs: Any) -> Any: + if checkpoint_counter[0] == 0: + return None + cp = MagicMock() + cp.checkpoint_id = f"cp-{checkpoint_counter[0]}" + return cp + + mock_cs.get_latest = mock_get_latest + + runtime = 
UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-groupchat-bp", + checkpoint_storage=mock_cs, + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "discuss AI safety" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + # Break ALL nodes + cast(Mock, bridge.get_breakpoints).return_value = "*" + # 1 initial + 7 breakpoint resumes + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, # initial + None, # continue after orchestrator BP #1 + None, # continue after participant_a BP + None, # continue after orchestrator BP #2 + None, # continue after participant_b BP + None, # continue after orchestrator BP #3 + None, # continue after participant_c BP + None, # continue after orchestrator BP #4 (terminate) + ] + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + result = await debug_runtime.execute({"messages": []}) + + # --- 8 workflow.run() calls (7 BPs + 1 completion) --- + assert len(call_log) == 8, ( + f"Expected 8 calls but got {len(call_log)} — breakpoints may be looping" + ) + + # --- 7 breakpoints: every node visit --- + assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 7 + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert bp_nodes == [ + "orchestrator", # 0: initial + "participant_a", # 1: after orch passes through + "orchestrator", # 2: response from A + "participant_b", # 3: after orch passes through + "orchestrator", # 4: response from B + "participant_c", # 5: after orch passes through + "orchestrator", # 6: terminate + ] + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL + + # --- KV state --- + bp_state = kv_store.get("test-groupchat-bp:breakpoint:state") + assert bp_state is not None + # Checkpoints advanced on every call, so skip_nodes were reset + # 
each time — only the last BP node remains. + assert bp_state["skip_nodes"] == {"orchestrator": 1} + assert bp_state["last_breakpoint_node"] == "orchestrator" + + # --- Debug bridge lifecycle --- + cast(AsyncMock, bridge.connect).assert_awaited_once() + cast(AsyncMock, bridge.emit_execution_started).assert_awaited_once() + assert cast(AsyncMock, bridge.wait_for_resume).await_count == 8 + cast(AsyncMock, bridge.emit_execution_completed).assert_awaited_once() + + async def test_cyclic_workflow_breakpoints_fire_on_every_visit(self): + """Cyclic graph: breakpoints fire on every executor visit. + + Graph: planner → worker → reviewer → planner (cycle) + Breakpoints: ``"*"`` (break all) + + With pass-through-N, each executor's skip count increments on + each breakpoint. This test uses mock_run to simulate the runtime + seeing BPs on the first cycle; all executors are always wrapped + (skip_nodes get pass-through wrappers). + + Expected (first cycle): + 1. Fresh: BP on planner + 2. Resume (planner:1): planner passes → BP on worker + 3. Resume (planner:1, worker:1): both pass → BP on reviewer + 4. 
Resume (planner:1, worker:1, reviewer:1): all pass → completes + """ + planner = _agent(name="planner") + worker = _agent(name="worker") + reviewer = _agent(name="reviewer") + workflow = ( + WorkflowBuilder(start_executor=planner) + .add_edge(planner, worker) + .add_edge(worker, reviewer) + .add_edge(reviewer, planner) # cycle back + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="cyclic_wf") + + call_log: list[dict[str, Any]] = [] + call_count = 0 + checkpoint_counter = [0] + + def mock_run(**kwargs: Any) -> _MockWorkflowStream: + nonlocal call_count + call_count += 1 + call_log.append( + { + "call_number": call_count, + "kwargs": dict(kwargs), + "planner_wrapped": hasattr( + workflow.executors["planner"], "_bp_original_execute" + ), + "worker_wrapped": hasattr( + workflow.executors["worker"], "_bp_original_execute" + ), + "reviewer_wrapped": hasattr( + workflow.executors["reviewer"], "_bp_original_execute" + ), + } + ) + if call_count == 1: + checkpoint_counter[0] = 1 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-1", + suspend_value={ + "type": "breakpoint", + "node_id": "planner", + }, + is_breakpoint=True, + ) + ) + elif call_count == 2: + checkpoint_counter[0] = 2 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-2", + suspend_value={ + "type": "breakpoint", + "node_id": "worker", + }, + is_breakpoint=True, + ) + ) + elif call_count == 3: + checkpoint_counter[0] = 3 + return _MockWorkflowStream( + exception=AgentInterruptException( + interrupt_id="bp-3", + suspend_value={ + "type": "breakpoint", + "node_id": "reviewer", + }, + is_breakpoint=True, + ) + ) + else: + # All excluded → completes (cycle runs freely) + return _MockWorkflowStream(final_output="review complete") + + workflow.run = mock_run # type: ignore[assignment] + + kv_store: dict[str, Any] = {} + mock_storage = AsyncMock() + + async def mock_set_value( + runtime_id: str, namespace: str, key: str, value: Any + 
) -> None: + kv_store[f"{runtime_id}:{namespace}:{key}"] = value + + async def mock_get_value(runtime_id: str, namespace: str, key: str) -> Any: + return kv_store.get(f"{runtime_id}:{namespace}:{key}") + + mock_storage.set_value = mock_set_value + mock_storage.get_value = mock_get_value + + mock_cs = AsyncMock() + + async def mock_get_latest(**kwargs: Any) -> Any: + if checkpoint_counter[0] == 0: + return None + cp = MagicMock() + cp.checkpoint_id = f"cp-{checkpoint_counter[0]}" + return cp + + mock_cs.get_latest = mock_get_latest + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-cyclic-bp", + checkpoint_storage=mock_cs, + resumable_storage=mock_storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "plan and review" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = "*" + cast(AsyncMock, bridge.wait_for_resume).side_effect = [ + None, # initial + None, # continue after planner BP + None, # continue after worker BP + None, # continue after reviewer BP + ] + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + result = await debug_runtime.execute({"messages": []}) + + # --- 4 workflow.run() calls (3 BPs + 1 completion) --- + assert len(call_log) == 4 + + # All executors are always wrapped (skip_nodes get pass-through wrappers) + for log_entry in call_log: + assert log_entry["planner_wrapped"] is True + assert log_entry["worker_wrapped"] is True + assert log_entry["reviewer_wrapped"] is True + + # --- 3 breakpoints: planner, worker, reviewer --- + assert cast(AsyncMock, bridge.emit_breakpoint_hit).await_count == 3 + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert bp_nodes == ["planner", "worker", "reviewer"] + + assert result.status == 
UiPathRuntimeStatus.SUCCESSFUL + + # Checkpoints advanced (cp-1 → cp-2 → cp-3), so skip_nodes were + # reset each time — only the last BP node remains. + bp_state = kv_store.get("test-cyclic-bp:breakpoint:state") + assert bp_state is not None + assert bp_state["skip_nodes"] == {"reviewer": 1} diff --git a/packages/uipath-agent-framework/tests/test_group_chat_breakpoints.py b/packages/uipath-agent-framework/tests/test_group_chat_breakpoints.py new file mode 100644 index 00000000..610472f2 --- /dev/null +++ b/packages/uipath-agent-framework/tests/test_group_chat_breakpoints.py @@ -0,0 +1,1220 @@ +"""Integration tests: sample topologies with breakpoints on all nodes. + +Reproduces the infinite loop bug when running samples with breakpoints +enabled on all nodes via UiPathDebugRuntime. + +Uses REAL builders, real checkpoint storage, real breakpoint injection — +only LLM calls are mocked. + +Covers all sample topologies: +- group-chat (GroupChatBuilder): orchestrator ↔ participants +- quickstart-workflow (WorkflowBuilder): single agent +- concurrent (ConcurrentBuilder): dispatcher → parallel agents → aggregator +- handoff (HandoffBuilder): triage → specialists +- hitl-workflow (HandoffBuilder): triage → specialists with HITL tools +""" + +import json +import os +import tempfile +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, Mock + +from agent_framework import WorkflowBuilder +from agent_framework.openai import OpenAIChatClient +from agent_framework.orchestrations import ( + ConcurrentBuilder, + GroupChatBuilder, + HandoffBuilder, +) +from openai.types import CompletionUsage +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_chunk import ( + Choice as ChunkChoice, +) +from openai.types.chat.chat_completion_message import 
ChatCompletionMessage +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) +from uipath.runtime.debug import ( + UiPathDebugProtocol, + UiPathDebugRuntime, +) +from uipath.runtime.result import UiPathRuntimeStatus + +from uipath_agent_framework.runtime.resumable_storage import ( + ScopedCheckpointStorage, + SqliteResumableStorage, +) +from uipath_agent_framework.runtime.runtime import UiPathAgentFrameworkRuntime + +# Safety limit: if the debug loop exceeds this many resume calls, +# the test fails — this means breakpoints are stuck in a loop. +MAX_RESUME_CALLS = 50 + + +def _extract_system_text(messages: list[dict[str, Any]]) -> str: + """Extract system/developer message text from OpenAI-format messages. + + The OpenAI client sends content as a list of content parts: + [{"type": "text", "text": "..."}] + """ + for msg in messages: + if not isinstance(msg, dict): + continue + if msg.get("role") not in ("system", "developer"): + continue + content = msg.get("content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + return " ".join(parts) + return "" + + +def _make_chat_completion(text: str) -> ChatCompletion: + """Create a mock OpenAI ChatCompletion response.""" + return ChatCompletion( + id="test-completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage(role="assistant", content=text), + finish_reason="stop", + ) + ], + created=0, + model="mock-model", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20), + ) + + +async def _make_streaming_response(text: str): + """Create an async iterable of ChatCompletionChunks for streaming.""" + # Yield the content in a single chunk + yield ChatCompletionChunk( + id="test-chunk", + choices=[ + ChunkChoice( + index=0, + 
delta=ChoiceDelta(role="assistant", content=text), + finish_reason=None, + ) + ], + created=0, + model="mock-model", + object="chat.completion.chunk", + ) + # Yield the stop chunk + yield ChatCompletionChunk( + id="test-chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(), + finish_reason="stop", + ) + ], + created=0, + model="mock-model", + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20), + ) + + +def _make_mock_response(text: str, stream: bool = False): + """Return either a ChatCompletion or a streaming async iterable.""" + if stream: + return _make_streaming_response(text) + return _make_chat_completion(text) + + +def _make_tool_call_completion(tool_name: str, arguments: str = "{}") -> ChatCompletion: + """Create a mock ChatCompletion with a tool call.""" + return ChatCompletion( + id="test-completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id=f"call_{tool_name}", + function=Function(name=tool_name, arguments=arguments), + type="function", + ) + ], + ), + finish_reason="tool_calls", + ) + ], + created=0, + model="mock-model", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20), + ) + + +async def _make_streaming_tool_call(tool_name: str, arguments: str = "{}"): + """Create streaming chunks for a tool call response.""" + yield ChatCompletionChunk( + id="test-chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=f"call_{tool_name}", + function=ChoiceDeltaToolCallFunction( + name=tool_name, + arguments=arguments, + ), + type="function", + ) + ], + ), + finish_reason=None, + ) + ], + created=0, + model="mock-model", + object="chat.completion.chunk", + ) + yield ChatCompletionChunk( + id="test-chunk", + choices=[ + ChunkChoice( + 
index=0, + delta=ChoiceDelta(), + finish_reason="tool_calls", + ) + ], + created=0, + model="mock-model", + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20), + ) + + +def _make_tool_call_response( + tool_name: str, arguments: str = "{}", stream: bool = False +): + """Return either a tool call ChatCompletion or streaming chunks.""" + if stream: + return _make_streaming_tool_call(tool_name, arguments) + return _make_tool_call_completion(tool_name, arguments) + + +def _make_debug_bridge(**overrides: Any) -> UiPathDebugProtocol: + """Create a mock debug bridge with sensible defaults.""" + bridge: Mock = Mock(spec=UiPathDebugProtocol) + bridge.connect = AsyncMock() + bridge.disconnect = AsyncMock() + bridge.emit_execution_started = AsyncMock() + bridge.emit_execution_completed = AsyncMock() + bridge.emit_execution_error = AsyncMock() + bridge.emit_execution_suspended = AsyncMock() + bridge.emit_breakpoint_hit = AsyncMock() + bridge.emit_state_update = AsyncMock() + bridge.emit_execution_resumed = AsyncMock() + bridge.wait_for_resume = AsyncMock(return_value=None) + bridge.wait_for_terminate = AsyncMock() + bridge.get_breakpoints = Mock(return_value=[]) + for k, v in overrides.items(): + setattr(bridge, k, v) + return cast(UiPathDebugProtocol, bridge) + + +class TestGroupChatBreakpoints: + """Integration test: GroupChat sample with breakpoints on all nodes. + + Uses the exact same topology as the group-chat sample: + - 3 participants: researcher, critic, writer + - 1 orchestrator agent (agent-based) + - max_rounds=6 + - breakpoints="*" (all nodes) + + Only LLM calls are mocked via a fake AsyncOpenAI client. + Everything else (GroupChatBuilder, checkpoint storage, breakpoint + injection, UiPathDebugRuntime) uses real code. + """ + + async def test_group_chat_breakall_completes_without_loop(self): + """GroupChat with breakpoints on ALL nodes must eventually complete. 
        The test fails if MAX_RESUME_CALLS is exceeded, which would show
        the infinite loop still exists.
+ ) + llm_call_log.append({"agent": "critic", "response": response_text[:50]}) + elif "writer" in system_msg.lower() or "synthesize" in system_msg.lower(): + response_text = ( + "In summary, the discussion revealed important nuances " + "in AI safety research." + ) + llm_call_log.append({"agent": "writer", "response": response_text[:50]}) + else: + response_text = "OK" + llm_call_log.append({"agent": "unknown", "response": response_text}) + + return _make_mock_response(response_text, stream=is_stream) + + # --- Build agents exactly like the group-chat sample --- + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_chat_completions_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + researcher = client.as_agent( + name="researcher", + description="Expert at finding facts and data using Wikipedia.", + instructions=( + "You are a research specialist. Use the search_wikipedia tool " + "to find factual information. Provide concise, well-sourced " + "responses." + ), + ) + + critic = client.as_agent( + name="critic", + description="Challenges assumptions and evaluates claims critically.", + instructions=( + "You are a critical thinker. Evaluate the claims made by other " + "participants. Point out gaps, biases, or missing context. " + "Ask probing questions to deepen the discussion." + ), + ) + + writer = client.as_agent( + name="writer", + description="Synthesizes group discussion into clear, structured prose.", + instructions=( + "You are a skilled writer. Synthesize the group discussion into " + "a clear, well-organized summary. Incorporate the researcher's " + "facts and address the critic's concerns." + ), + ) + + orchestrator = client.as_agent( + name="orchestrator", + description="Coordinates the group discussion by selecting the next speaker.", + instructions=( + "You coordinate a team of researcher, critic, and writer. 
" + "Select the next speaker based on the conversation flow:\n" + "- Pick 'researcher' when facts or data are needed.\n" + "- Pick 'critic' to challenge or evaluate claims.\n" + "- Pick 'writer' to synthesize when enough discussion has happened.\n" + "Respond with ONLY the agent name, nothing else." + ), + ) + + # --- Build workflow exactly like the sample --- + workflow = GroupChatBuilder( + participants=[researcher, critic, writer], + orchestrator_agent=orchestrator, + max_rounds=6, + ).build() + + agent = workflow.as_agent(name="group_chat") + + # --- Real SQLite storage (temp file) --- + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-gc-bp" + ) + + # --- Create runtime --- + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-gc-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "Discuss AI safety" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + # --- Debug bridge: breakpoints on ALL nodes --- + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"INFINITE LOOP DETECTED: exceeded {MAX_RESUME_CALLS} " + f"resume calls.\n" + f"LLM calls: {len(llm_call_log)}\n" + f"Orchestrator selections: {orchestrator_call_count[0]}\n" + f"Call log: {llm_call_log[-10:]}" + ) + return None # Continue (resume after breakpoint) + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = "*" + cast(AsyncMock, bridge.wait_for_resume).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime(delegate=runtime, 
debug_bridge=bridge) + + # --- Execute --- + result = await debug_runtime.execute({"messages": []}) + + # --- Assertions --- + bp_count = cast(AsyncMock, bridge.emit_breakpoint_hit).await_count + + # Must complete successfully + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL but got {result.status}. " + f"Resumes: {resume_count[0]}, BPs: {bp_count}, " + f"LLM calls: {len(llm_call_log)}" + ) + + # Must have hit at least one breakpoint + assert bp_count > 0, "Should have hit at least one breakpoint" + + # Must have made LLM calls (agents actually ran) + assert len(llm_call_log) > 0, "LLM should have been called" + + # Verify breakpoint nodes were reported + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert all( + n in ("orchestrator", "researcher", "critic", "writer") + for n in bp_nodes + ), f"Unexpected breakpoint nodes: {bp_nodes}" + + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_group_chat_single_breakpoint_completes(self): + """GroupChat with a breakpoint on only the orchestrator completes.""" + orchestrator_call_count = [0] + participant_names = ["researcher", "critic", "writer"] + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = _extract_system_text(messages) + + if ( + "coordinate" in system_msg.lower() + or "next speaker" in system_msg.lower() + ): + idx = orchestrator_call_count[0] % len(participant_names) + orchestrator_call_count[0] += 1 + text = json.dumps( + { + "terminate": False, + "reason": f"Selecting {participant_names[idx]}.", + "next_speaker": participant_names[idx], + "final_message": None, + } + ) + return _make_mock_response(text, stream=is_stream) + return _make_mock_response("Some response text.", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = 
OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + researcher = client.as_agent( + name="researcher", + description="Research expert.", + instructions="You are a research specialist.", + ) + critic = client.as_agent( + name="critic", + description="Critical thinker.", + instructions="You are a critical thinker.", + ) + writer = client.as_agent( + name="writer", + description="Writer.", + instructions="You are a skilled writer.", + ) + orchestrator = client.as_agent( + name="orchestrator", + description="Coordinator.", + instructions=( + "You coordinate a team of researcher, critic, and writer. " + "Select the next speaker. Respond with ONLY the agent name." + ), + ) + + workflow = GroupChatBuilder( + participants=[researcher, critic, writer], + orchestrator_agent=orchestrator, + max_rounds=3, + ).build() + agent = workflow.as_agent(name="group_chat") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-gc-single-bp" + ) + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-gc-single-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "Discuss AI" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError(f"Loop detected: {resume_count[0]} resumes") + return None + + bridge = _make_debug_bridge() + # Breakpoint ONLY on orchestrator + cast(Mock, bridge.get_breakpoints).return_value = ["orchestrator"] + cast(AsyncMock, bridge.wait_for_resume).side_effect = mock_wait_for_resume + + 
debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, BPs: {bp_count}" + ) + assert bp_count > 0 + + # All BPs should be on orchestrator + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert all(n == "orchestrator" for n in bp_nodes) + + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_group_chat_no_breakpoints_completes(self): + """GroupChat without breakpoints runs to completion (baseline).""" + orchestrator_call_count = [0] + participant_names = ["researcher", "critic", "writer"] + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = _extract_system_text(messages) + + if ( + "coordinate" in system_msg.lower() + or "next speaker" in system_msg.lower() + ): + idx = orchestrator_call_count[0] % len(participant_names) + orchestrator_call_count[0] += 1 + text = json.dumps( + { + "terminate": False, + "reason": f"Selecting {participant_names[idx]}.", + "next_speaker": participant_names[idx], + "final_message": None, + } + ) + return _make_mock_response(text, stream=is_stream) + return _make_mock_response("Some response text.", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + researcher = client.as_agent( + name="researcher", + description="Research expert.", + instructions="You are a research specialist.", + ) + critic = client.as_agent( + name="critic", + description="Critical thinker.", + instructions="You are a critical thinker.", + ) + writer = client.as_agent( 
+ name="writer", + description="Writer.", + instructions="You are a skilled writer.", + ) + orchestrator = client.as_agent( + name="orchestrator", + description="Coordinator.", + instructions=( + "You coordinate a team of researcher, critic, and writer. " + "Select the next speaker. Respond with ONLY the agent name." + ), + ) + + workflow = GroupChatBuilder( + participants=[researcher, critic, writer], + orchestrator_agent=orchestrator, + max_rounds=3, + ).build() + agent = workflow.as_agent(name="group_chat") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-gc-no-bp" + ) + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-gc-no-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "Discuss AI" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + bridge = _make_debug_bridge() + # No breakpoints + cast(Mock, bridge.get_breakpoints).return_value = [] + cast(AsyncMock, bridge.wait_for_resume).return_value = None + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + + result = await debug_runtime.execute({"messages": []}) + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL + cast(AsyncMock, bridge.emit_breakpoint_hit).assert_not_awaited() + + finally: + await storage.dispose() + os.unlink(tmp_path) + + +class TestQuickstartWorkflowBreakpoints: + """Integration test: quickstart-workflow sample with breakpoints. + + Topology: single weather_agent in a WorkflowBuilder. + Breakpoints="*" → breaks on the single agent, then completes. 
+ """ + + async def test_quickstart_breakall_completes_without_loop(self): + """Single-agent workflow with breakpoints="*" completes.""" + llm_call_log: list[str] = [] + + async def mock_create(**kwargs: Any): + is_stream = kwargs.get("stream", False) + llm_call_log.append("weather_agent") + return _make_mock_response( + "The weather in New York is 72°F and sunny.", + stream=is_stream, + ) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + weather_agent = client.as_agent( + name="weather_agent", + description="Provides weather information.", + instructions="You are a weather assistant.", + ) + + workflow = WorkflowBuilder(start_executor=weather_agent).build() + agent = workflow.as_agent(name="weather_assistant") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-qs-bp" + ) + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-qs-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = ( + "What is the weather in New York?" 
+ ) + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError(f"Loop detected: {resume_count[0]} resumes") + return None + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = "*" + cast(AsyncMock, bridge.wait_for_resume).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, BPs: {bp_count}" + ) + + assert bp_count >= 1, "Should have hit at least one breakpoint" + assert len(llm_call_log) > 0, "LLM should have been called" + + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert "weather_agent" in bp_nodes + + finally: + await storage.dispose() + os.unlink(tmp_path) + + +class TestConcurrentBreakpoints: + """Integration test: concurrent sample with breakpoints. + + Topology: dispatcher → [sentiment, topic, summarizer] → aggregator + (ConcurrentBuilder fan-out / fan-in). + Breakpoints="*" → breaks on each executor, then completes. 
+ """ + + async def test_concurrent_breakall_completes_without_loop(self): + """Concurrent workflow with breakpoints="*" completes.""" + llm_call_log: list[str] = [] + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = _extract_system_text(messages) + + if "sentiment" in system_msg.lower(): + llm_call_log.append("sentiment") + return _make_mock_response( + "Sentiment: positive (0.85)", stream=is_stream + ) + elif "topic" in system_msg.lower() or "entit" in system_msg.lower(): + llm_call_log.append("topic") + return _make_mock_response( + "Topics: AI, safety, alignment", stream=is_stream + ) + elif "summar" in system_msg.lower(): + llm_call_log.append("summarizer") + return _make_mock_response( + "Summary: A discussion about AI safety.", + stream=is_stream, + ) + else: + llm_call_log.append("unknown") + return _make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + sentiment_agent = client.as_agent( + name="sentiment", + description="Analyzes text sentiment.", + instructions="You analyze sentiment of the given text.", + ) + topic_agent = client.as_agent( + name="topic", + description="Extracts topics and entities.", + instructions="You extract topics and entities from text.", + ) + summarizer = client.as_agent( + name="summarizer", + description="Creates concise summaries.", + instructions="You summarize the given text concisely.", + ) + + workflow = ConcurrentBuilder( + participants=[sentiment_agent, topic_agent, summarizer], + ).build() + agent = workflow.as_agent(name="multi_analyzer") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + 
            # At least the 3 agent executors must break (dispatcher /
            # aggregator may add more, so we only assert a floor of 3)
+ Breakpoints="*" → breaks on triage, then on the specialist it + hands off to, then completes. + """ + + async def test_handoff_breakall_completes_without_loop(self): + """Handoff workflow with breakpoints="*" completes.""" + llm_call_log: list[str] = [] + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = _extract_system_text(messages) + + if "route" in system_msg.lower() or "triage" in system_msg.lower(): + llm_call_log.append("triage") + # Triage hands off to billing_agent via tool call + return _make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + elif "billing" in system_msg.lower(): + llm_call_log.append("billing_agent") + return _make_mock_response( + "I've resolved your billing issue. Your account " + "has been credited $50.", + stream=is_stream, + ) + elif "tech" in system_msg.lower(): + llm_call_log.append("tech_agent") + return _make_mock_response( + "I can help with your technical issue.", + stream=is_stream, + ) + elif "return" in system_msg.lower() or "refund" in system_msg.lower(): + llm_call_log.append("returns_agent") + return _make_mock_response( + "I'll process your return right away.", + stream=is_stream, + ) + else: + llm_call_log.append("unknown") + return _make_mock_response("How can I help you?", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + triage = client.as_agent( + name="triage", + description="Routes customers to the right specialist.", + instructions=( + "You are a triage agent. Route the customer to the " + "appropriate specialist based on their issue." + ), + ) + billing_agent = client.as_agent( + name="billing_agent", + description="Handles billing and payment issues.", + instructions=( + "You are a billing specialist. Help customers with " + "billing inquiries and payment issues." 
+ ), + ) + tech_agent = client.as_agent( + name="tech_agent", + description="Handles technical support.", + instructions=( + "You are a tech support specialist. Help customers " + "with technical issues." + ), + ) + returns_agent = client.as_agent( + name="returns_agent", + description="Handles returns and refunds.", + instructions=( + "You are a returns specialist. Help customers with " + "returns and refund requests." + ), + ) + + workflow = ( + HandoffBuilder( + name="customer_support", + participants=[triage, billing_agent, tech_agent, returns_agent], + ) + .with_start_agent(triage) + .add_handoff(triage, [billing_agent, tech_agent, returns_agent]) + .add_handoff(billing_agent, [triage]) + .add_handoff(tech_agent, [triage]) + .add_handoff(returns_agent, [triage]) + .build() + ) + agent = workflow.as_agent(name="customer_support") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-handoff-bp" + ) + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-handoff-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "I have a billing issue" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Loop detected: {resume_count[0]} resumes. 
" + f"LLM calls: {llm_call_log}" + ) + return None + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = "*" + cast(AsyncMock, bridge.wait_for_resume).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, BPs: {bp_count}, " + f"LLM calls: {llm_call_log}" + ) + + # At least triage and billing_agent should have been breakpointed + assert bp_count >= 2, f"Expected at least 2 breakpoints, got {bp_count}" + + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert "triage" in bp_nodes, f"triage not in breakpoint nodes: {bp_nodes}" + assert "billing_agent" in bp_nodes, ( + f"billing_agent not in breakpoint nodes: {bp_nodes}" + ) + + finally: + await storage.dispose() + os.unlink(tmp_path) + + +class TestHitlWorkflowBreakpoints: + """Integration test: hitl-workflow sample with breakpoints. + + Topology: triage → [billing_agent, returns_agent] + (HandoffBuilder, same as handoff but specialists have HITL tools). + Breakpoints="*" → breaks on each executor independently of HITL. + """ + + async def test_hitl_breakall_completes_without_loop(self): + """HITL handoff workflow with breakpoints="*" completes. + + The HITL tools (@requires_approval) are present on specialist agents + but the mock LLM does not trigger them. This test verifies that + breakpoints work correctly on the HITL topology. 
+ """ + llm_call_log: list[str] = [] + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = _extract_system_text(messages) + + if "route" in system_msg.lower() or "triage" in system_msg.lower(): + llm_call_log.append("triage") + return _make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + elif "billing" in system_msg.lower(): + llm_call_log.append("billing_agent") + return _make_mock_response( + "Your billing issue has been resolved.", + stream=is_stream, + ) + elif "return" in system_msg.lower() or "refund" in system_msg.lower(): + llm_call_log.append("returns_agent") + return _make_mock_response( + "Your refund has been processed.", + stream=is_stream, + ) + else: + llm_call_log.append("unknown") + return _make_mock_response("How can I help you?", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + client = OpenAIChatClient(model_id="mock-model", async_client=mock_openai) + + triage = client.as_agent( + name="triage", + description="Routes customers to the right specialist.", + instructions=( + "You are a triage agent. Route the customer to the " + "appropriate specialist." + ), + ) + billing_agent = client.as_agent( + name="billing_agent", + description="Handles billing with approval-required tools.", + instructions=( + "You are a billing specialist. Use transfer_funds " + "when needed (requires approval)." + ), + ) + returns_agent = client.as_agent( + name="returns_agent", + description="Handles returns with approval-required tools.", + instructions=( + "You are a returns specialist. Use issue_refund " + "when needed (requires approval)." 
+ ), + ) + + workflow = ( + HandoffBuilder( + name="hitl_support", + participants=[triage, billing_agent, returns_agent], + ) + .with_start_agent(triage) + .add_handoff(triage, [billing_agent, returns_agent]) + .add_handoff(billing_agent, [triage]) + .add_handoff(returns_agent, [triage]) + .build() + ) + agent = workflow.as_agent(name="hitl_support") + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + scoped_cs = ScopedCheckpointStorage( + storage.checkpoint_storage, "test-hitl-bp" + ) + + runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id="test-hitl-bp", + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + runtime.chat = MagicMock() + runtime.chat.map_messages_to_input.return_value = "I need to transfer funds" + runtime.chat.map_streaming_content.return_value = [] + runtime.chat.close_message.return_value = [] + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Loop detected: {resume_count[0]} resumes. " + f"LLM calls: {llm_call_log}" + ) + return None + + bridge = _make_debug_bridge() + cast(Mock, bridge.get_breakpoints).return_value = "*" + cast(AsyncMock, bridge.wait_for_resume).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime(delegate=runtime, debug_bridge=bridge) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. 
" + f"Resumes: {resume_count[0]}, BPs: {bp_count}, " + f"LLM calls: {llm_call_log}" + ) + + assert bp_count >= 2, f"Expected at least 2 breakpoints, got {bp_count}" + + bp_nodes = [ + call.args[0].breakpoint_node + for call in cast(AsyncMock, bridge.emit_breakpoint_hit).call_args_list + ] + assert "triage" in bp_nodes, ( + f"triage not in breakpoint nodes (hitl): {bp_nodes}" + ) + assert "billing_agent" in bp_nodes, ( + f"billing_agent not in breakpoint nodes (hitl): {bp_nodes}" + ) + + finally: + await storage.dispose() + os.unlink(tmp_path) diff --git a/packages/uipath-agent-framework/tests/test_storage.py b/packages/uipath-agent-framework/tests/test_storage.py index a5acac30..d3c38329 100644 --- a/packages/uipath-agent-framework/tests/test_storage.py +++ b/packages/uipath-agent-framework/tests/test_storage.py @@ -7,7 +7,6 @@ from uipath_agent_framework.runtime.resumable_storage import ( ScopedCheckpointStorage, - SqliteCheckpointStorage, SqliteResumableStorage, ) @@ -62,6 +61,7 @@ async def test_save_and_load_checkpoint(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None checkpoint = _make_checkpoint(checkpoint_id="cp-1") await cs.save(checkpoint) @@ -81,10 +81,11 @@ async def test_load_nonexistent_raises(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None try: await cs.load("nonexistent") - assert False, "Should have raised" + raise AssertionError("Should have raised") except Exception: pass await storage.dispose() @@ -96,6 +97,7 @@ async def test_get_latest_returns_most_recent(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None cp1 = _make_checkpoint( checkpoint_id="cp-old", timestamp="2026-01-01T00:00:00+00:00" @@ -118,6 +120,7 @@ async def test_get_latest_returns_none_for_empty(self): storage = SqliteResumableStorage(db_path) await 
storage.setup() cs = storage.checkpoint_storage + assert cs is not None latest = await cs.get_latest(workflow_name="nonexistent") assert latest is None @@ -130,6 +133,7 @@ async def test_delete_checkpoint(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None cp = _make_checkpoint(checkpoint_id="cp-del") await cs.save(cp) @@ -149,16 +153,11 @@ async def test_list_checkpoints_filtered_by_workflow_name(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None - cp1 = _make_checkpoint( - workflow_name="wf_a", checkpoint_id="cp-a1" - ) - cp2 = _make_checkpoint( - workflow_name="wf_a", checkpoint_id="cp-a2" - ) - cp3 = _make_checkpoint( - workflow_name="wf_b", checkpoint_id="cp-b1" - ) + cp1 = _make_checkpoint(workflow_name="wf_a", checkpoint_id="cp-a1") + cp2 = _make_checkpoint(workflow_name="wf_a", checkpoint_id="cp-a2") + cp3 = _make_checkpoint(workflow_name="wf_b", checkpoint_id="cp-b1") await cs.save(cp1) await cs.save(cp2) await cs.save(cp3) @@ -179,16 +178,11 @@ async def test_list_checkpoint_ids(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None - cp1 = _make_checkpoint( - workflow_name="wf_x", checkpoint_id="cp-x1" - ) - cp2 = _make_checkpoint( - workflow_name="wf_x", checkpoint_id="cp-x2" - ) - cp3 = _make_checkpoint( - workflow_name="wf_y", checkpoint_id="cp-y1" - ) + cp1 = _make_checkpoint(workflow_name="wf_x", checkpoint_id="cp-x1") + cp2 = _make_checkpoint(workflow_name="wf_x", checkpoint_id="cp-x2") + cp3 = _make_checkpoint(workflow_name="wf_y", checkpoint_id="cp-y1") await cs.save(cp1) await cs.save(cp2) await cs.save(cp3) @@ -204,6 +198,7 @@ async def test_save_overwrites_existing_checkpoint(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None cp1 = 
_make_checkpoint(checkpoint_id="cp-ow") cp1.state = {"version": 1} @@ -225,6 +220,7 @@ async def test_dispose_allows_reconnect(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None cp = _make_checkpoint(checkpoint_id="cp-persist") await cs.save(cp) @@ -234,6 +230,7 @@ async def test_dispose_allows_reconnect(self): storage2 = SqliteResumableStorage(db_path) await storage2.setup() cs2 = storage2.checkpoint_storage + assert cs2 is not None loaded = await cs2.load("cp-persist") assert loaded.checkpoint_id == "cp-persist" @@ -250,16 +247,13 @@ async def test_scoped_storage_isolates_by_runtime_id(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None scoped_a = ScopedCheckpointStorage(cs, "runtime-a") scoped_b = ScopedCheckpointStorage(cs, "runtime-b") - cp_a = _make_checkpoint( - workflow_name="my_wf", checkpoint_id="cp-a" - ) - cp_b = _make_checkpoint( - workflow_name="my_wf", checkpoint_id="cp-b" - ) + cp_a = _make_checkpoint(workflow_name="my_wf", checkpoint_id="cp-a") + cp_b = _make_checkpoint(workflow_name="my_wf", checkpoint_id="cp-b") await scoped_a.save(cp_a) await scoped_b.save(cp_b) @@ -282,6 +276,7 @@ async def test_scoped_get_latest_respects_scope(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None scoped_a = ScopedCheckpointStorage(cs, "rt-a") scoped_b = ScopedCheckpointStorage(cs, "rt-b") @@ -316,11 +311,10 @@ async def test_scoped_load_and_delete_are_global(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None scoped = ScopedCheckpointStorage(cs, "rt-x") - cp = _make_checkpoint( - workflow_name="wf", checkpoint_id="cp-global" - ) + cp = _make_checkpoint(workflow_name="wf", checkpoint_id="cp-global") await scoped.save(cp) # Load from any scope @@ -339,16 +333,13 @@ 
async def test_scoped_list_checkpoint_ids(self): storage = SqliteResumableStorage(db_path) await storage.setup() cs = storage.checkpoint_storage + assert cs is not None scoped_a = ScopedCheckpointStorage(cs, "rt-a") scoped_b = ScopedCheckpointStorage(cs, "rt-b") - cp_a = _make_checkpoint( - workflow_name="wf", checkpoint_id="cp-a" - ) - cp_b = _make_checkpoint( - workflow_name="wf", checkpoint_id="cp-b" - ) + cp_a = _make_checkpoint(workflow_name="wf", checkpoint_id="cp-a") + cp_b = _make_checkpoint(workflow_name="wf", checkpoint_id="cp-b") await scoped_a.save(cp_a) await scoped_b.save(cp_b) diff --git a/packages/uipath-agent-framework/tests/test_streaming.py b/packages/uipath-agent-framework/tests/test_streaming.py index 97f38873..ac7e890c 100644 --- a/packages/uipath-agent-framework/tests/test_streaming.py +++ b/packages/uipath-agent-framework/tests/test_streaming.py @@ -98,7 +98,7 @@ def calculator(expression: str) -> str: def _make_runtime(agent: BaseAgent) -> UiPathAgentFrameworkRuntime: """Create a runtime with mocked chat mapper.""" - runtime = UiPathAgentFrameworkRuntime(agent=agent) + runtime = UiPathAgentFrameworkRuntime(agent=agent) # type: ignore[arg-type] runtime.chat = MagicMock() runtime.chat.map_messages_to_input.return_value = "test" runtime.chat.map_streaming_content.return_value = [] @@ -168,7 +168,7 @@ async def test_simple_workflow(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ _wf_event("executor_invoked", "triage"), @@ -179,7 +179,7 @@ async def test_simple_workflow(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -209,8 +209,8 @@ async def 
test_multi_executor_workflow(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock(return_value=_MockAsyncStream(wf_events, final)) # type: ignore[method-assign] - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + workflow.run = MagicMock(return_value=_MockAsyncStream(wf_events, final)) # type: ignore[assignment] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -228,7 +228,7 @@ async def test_workflow_root_wraps_executors(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ _wf_event("executor_invoked", "worker"), @@ -237,7 +237,7 @@ async def test_workflow_root_wraps_executors(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -323,7 +323,7 @@ async def test_tool_call_emits_state_events(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ _wf_event("executor_invoked", "weather_agent"), @@ -334,7 +334,7 @@ async def test_tool_call_emits_state_events(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -380,9 +380,7 @@ async def test_multiple_tool_calls_emit_paired_events(self): ] ), AgentResponseUpdate( - contents=[ - 
Content(type="function_result", call_id="c1", result="42") - ] + contents=[Content(type="function_result", call_id="c1", result="42")] ), AgentResponseUpdate( contents=[ @@ -395,9 +393,7 @@ async def test_multiple_tool_calls_emit_paired_events(self): ] ), AgentResponseUpdate( - contents=[ - Content(type="function_result", call_id="c2", result="found") - ] + contents=[Content(type="function_result", call_id="c2", result="found")] ), ] @@ -408,8 +404,8 @@ async def test_multiple_tool_calls_emit_paired_events(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock(return_value=_MockAsyncStream(wf_events, final)) # type: ignore[method-assign] - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + workflow.run = MagicMock(return_value=_MockAsyncStream(wf_events, final)) # type: ignore[assignment] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -433,7 +429,7 @@ async def test_no_tool_events_for_text_content(self): final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ _wf_event("executor_invoked", "text_agent"), @@ -443,7 +439,7 @@ async def test_no_tool_events_for_text_content(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -476,9 +472,7 @@ async def test_executor_completed_payload_excludes_streaming_updates(self): ] ) result_chunk = AgentResponseUpdate( - contents=[ - Content(type="function_result", call_id="c1", result="42") - ] + contents=[Content(type="function_result", call_id="c1", result="42")] ) # The framework packs 
sent_messages + yielded_outputs into completed data @@ -497,10 +491,10 @@ async def test_executor_completed_payload_excludes_streaming_updates(self): _wf_event("executor_completed", "agent_x", data=completed_data) ) - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream(stream_events, final) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -543,16 +537,14 @@ async def test_tool_events_from_executor_completed_when_output_filtered(self): ] ) result_update = AgentResponseUpdate( - contents=[ - Content(type="function_result", call_id="c1", result="found") - ] + contents=[Content(type="function_result", call_id="c1", result="found")] ) summary = MagicMock() # AgentExecutorResponse completed_data = [summary, call_update, result_update] final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ # No "output" events — simulating the filter @@ -562,7 +554,7 @@ async def test_tool_events_from_executor_completed_when_output_filtered(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -601,16 +593,14 @@ async def test_no_duplicate_tool_events_when_output_present(self): ] ) result_update = AgentResponseUpdate( - contents=[ - Content(type="function_result", call_id="c1", result="42") - ] + contents=[Content(type="function_result", call_id="c1", result="42")] ) summary = MagicMock() completed_data = [summary, call_update, result_update] 
final = MagicMock() final.get_outputs.return_value = [] - workflow.run = MagicMock( # type: ignore[method-assign] + workflow.run = MagicMock( # type: ignore[assignment] return_value=_MockAsyncStream( [ _wf_event("executor_invoked", "agent_y"), @@ -622,7 +612,7 @@ async def test_no_duplicate_tool_events_when_output_present(self): final, ) ) - agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + agent.create_session = MagicMock(return_value=MagicMock()) # type: ignore[assignment] runtime = _make_runtime(agent) events = await _collect_events(runtime) @@ -664,7 +654,7 @@ def mock_run(**kwargs): final, ) - workflow.run = mock_run # type: ignore[method-assign] + workflow.run = mock_run # type: ignore[assignment] runtime = UiPathAgentFrameworkRuntime( agent=agent, @@ -697,7 +687,7 @@ async def mock_run(**kwargs): result.get_outputs.return_value = ["done"] return result - workflow.run = mock_run # type: ignore[method-assign] + workflow.run = mock_run # type: ignore[assignment] runtime = UiPathAgentFrameworkRuntime( agent=agent, @@ -718,7 +708,9 @@ async def test_hitl_resume_passes_checkpoint_id_and_responses(self): workflow.name = "resume_wf" mock_checkpoint_storage = MagicMock() - mock_checkpoint_storage.get_latest = AsyncMock(return_value=MagicMock(checkpoint_id="cp-123")) + mock_checkpoint_storage.get_latest = AsyncMock( + return_value=MagicMock(checkpoint_id="cp-123") + ) captured_kwargs: list[dict[str, Any]] = [] def mock_run(**kwargs): @@ -733,7 +725,7 @@ def mock_run(**kwargs): final, ) - workflow.run = mock_run # type: ignore[method-assign] + workflow.run = mock_run # type: ignore[assignment] runtime = UiPathAgentFrameworkRuntime( agent=agent, @@ -818,16 +810,20 @@ def mock_run(**kwargs): final, ) - workflow.run = mock_run # type: ignore[method-assign] + workflow.run = mock_run # type: ignore[assignment] await _collect_events(runtime) # Session should have been captured with prior turn data assert len(captured_sessions) == 1 - 
assert captured_sessions[0].state.get("prior_turn_data") == "previous conversation" + assert ( + captured_sessions[0].state.get("prior_turn_data") == "previous conversation" + ) # Session should have been loaded from KV storage - mock_storage.get_value.assert_called_once_with("test-session", "session", "data") + mock_storage.get_value.assert_called_once_with( + "test-session", "session", "data" + ) # Session should have been saved after execution mock_storage.set_value.assert_called_once() @@ -861,7 +857,7 @@ async def mock_run(**kwargs): result.get_outputs.return_value = ["done"] return result - workflow.run = mock_run # type: ignore[method-assign] + workflow.run = mock_run # type: ignore[assignment] await runtime.execute(input={"messages": []}) diff --git a/packages/uipath-agent-framework/uv.lock b/packages/uipath-agent-framework/uv.lock index f28f5001..b18a2cb5 100644 --- a/packages/uipath-agent-framework/uv.lock +++ b/packages/uipath-agent-framework/uv.lock @@ -40,6 +40,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/b6/6c32def3967496e3688a69fda5bb55066688721c2839bc220de8ca3b306b/agent_framework_core-1.0.0b260212-py3-none-any.whl", hash = "sha256:18fc00c35911bd8a8c49055eda5204dc6dde761cf2379d32b16a91aaf6635dc0", size = 301370, upload-time = "2026-02-13T00:27:32.307Z" }, ] +[[package]] +name = "agent-framework-orchestrations" +version = "1.0.0b260212" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "agent-framework-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/bf/9a1f62bf243157f5ce20e26b1549bdc563334521a1c3155b7eb4093f5b47/agent_framework_orchestrations-1.0.0b260212.tar.gz", hash = "sha256:31c215af8fe7cf954c17306113137df945aabae35e9f3c89052b6983e39516a8", size = 53522, upload-time = "2026-02-13T00:37:58.163Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/2d/87/7af9ea63943555dd133a9aaf3d056de4763cc2e91986763167a2f4dff6e0/agent_framework_orchestrations-1.0.0b260212-py3-none-any.whl", hash = "sha256:c10ed851a33ce46de8ee50f20a08b1db2604ef0a14739a23b42ba1279ddd52d1", size = 59721, upload-time = "2026-02-13T00:37:51.783Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -2448,10 +2460,11 @@ wheels = [ [[package]] name = "uipath-agent-framework" -version = "0.0.4" +version = "0.0.5" source = { editable = "." } dependencies = [ { name = "agent-framework-core" }, + { name = "agent-framework-orchestrations" }, { name = "aiosqlite" }, { name = "openinference-instrumentation-agent-framework" }, { name = "uipath" }, @@ -2479,6 +2492,7 @@ dev = [ requires-dist = [ { name = "agent-framework-anthropic", marker = "extra == 'anthropic'", specifier = ">=1.0.0b260212" }, { name = "agent-framework-core", specifier = ">=1.0.0b260212" }, + { name = "agent-framework-orchestrations", specifier = ">=1.0.0b260212" }, { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.43.0" }, { name = "openinference-instrumentation-agent-framework", specifier = ">=0.1.0" },