diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 823b601b76..3958957596 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -119,7 +119,7 @@ from agent_framework import Agent, AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next) -> None: print(f"Input: {context.messages}") - await call_next(context) + await call_next() print(f"Output: {context.result}") agent = Agent(..., middleware=[LoggingMiddleware()]) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index ac6630a03f..e595be76e3 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -145,7 +145,7 @@ async def process(self, context: AgentContext, call_next): context.metadata["start_time"] = time.time() # Continue execution - await call_next(context) + await call_next() # Access result after execution print(f"Result: {context.result}") @@ -229,7 +229,7 @@ async def process(self, context: FunctionInvocationContext, call_next): raise MiddlewareTermination("Validation failed") # Continue execution - await call_next(context) + await call_next() """ def __init__( @@ -293,7 +293,7 @@ async def process(self, context: ChatContext, call_next): context.metadata["input_tokens"] = self.count_tokens(context.messages) # Continue execution - await call_next(context) + await call_next() # Access result and count output tokens if context.result: @@ -365,7 +365,7 @@ def __init__(self, max_retries: int = 3): async def process(self, context: AgentContext, call_next): for attempt in range(self.max_retries): - await call_next(context) + await call_next() if context.result and not context.result.is_error: break print(f"Retry {attempt + 1}/{self.max_retries}") @@ -379,7 +379,7 @@ async def process(self, context: AgentContext, call_next): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process an agent invocation. @@ -431,7 +431,7 @@ async def process(self, context: FunctionInvocationContext, call_next): raise MiddlewareTermination() # Execute function - await call_next(context) + await call_next() # Cache result if context.result: @@ -446,7 +446,7 @@ async def process(self, context: FunctionInvocationContext, call_next): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process a function invocation. @@ -493,7 +493,7 @@ async def process(self, context: ChatContext, call_next): context.messages.insert(0, Message(role="system", text=self.system_prompt)) # Continue execution - await call_next(context) + await call_next() # Use with an agent @@ -508,7 +508,7 @@ async def process(self, context: ChatContext, call_next): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process a chat client request. @@ -531,15 +531,13 @@ async def process( # Pure function type definitions for convenience -AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]] +AgentMiddlewareCallable = Callable[[AgentContext, Callable[[], Awaitable[None]]], Awaitable[None]] AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable -FunctionMiddlewareCallable = Callable[ - [FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None] -] +FunctionMiddlewareCallable = Callable[[FunctionInvocationContext, Callable[[], Awaitable[None]]], Awaitable[None]] FunctionMiddlewareTypes: TypeAlias = FunctionMiddleware | FunctionMiddlewareCallable -ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]] +ChatMiddlewareCallable = Callable[[ChatContext, Callable[[], Awaitable[None]]], Awaitable[None]] ChatMiddlewareTypes: TypeAlias = ChatMiddleware | ChatMiddlewareCallable ChatAndFunctionMiddlewareTypes: TypeAlias = ( @@ -578,7 +576,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: @agent_middleware async def logging_middleware(context: AgentContext, call_next): print(f"Before: {context.agent.name}") - await call_next(context) + await call_next() print(f"After: {context.result}") @@ -611,7 +609,7 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC @function_middleware async def logging_middleware(context: FunctionInvocationContext, call_next): print(f"Calling: {context.function.name}") - await call_next(context) + await call_next() print(f"Result: {context.result}") @@ -644,7 +642,7 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable: @chat_middleware async def logging_middleware(context: ChatContext, call_next): print(f"Messages: {len(context.messages)}") - await call_next(context) + await call_next() print(f"Response: {context.result}") @@ -666,10 +664,10 @@ class MiddlewareWrapper(Generic[ContextT]): ContextT: The type of context object this middleware operates on. """ - def __init__(self, func: Callable[[ContextT, Callable[[ContextT], Awaitable[None]]], Awaitable[None]]) -> None: + def __init__(self, func: Callable[[ContextT, Callable[[], Awaitable[None]]], Awaitable[None]]) -> None: self.func = func - async def process(self, context: ContextT, call_next: Callable[[ContextT], Awaitable[None]]) -> None: + async def process(self, context: ContextT, call_next: Callable[[], Awaitable[None]]) -> None: await self.func(context, call_next) @@ -772,25 +770,25 @@ async def execute( context.result = await context.result return context.result - def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: AgentContext) -> None: - c.result = final_handler(c) # type: ignore[assignment] - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) # type: ignore[assignment] + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: AgentContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) with contextlib.suppress(MiddlewareTermination): - await first_handler(context) + await first_handler() if context.result and isinstance(context.result, ResponseStream): for hook in context.stream_transform_hooks: @@ -847,25 +845,25 @@ async def execute( if not self._middleware: return await final_handler(context) - def create_next_handler(index: int) -> Callable[[FunctionInvocationContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: FunctionInvocationContext) -> None: - c.result = final_handler(c) - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: FunctionInvocationContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) # Don't suppress MiddlewareTermination - let it propagate to signal loop termination - await first_handler(context) + await first_handler() return context.result @@ -922,25 +920,25 @@ async def execute( raise ValueError("Streaming agent middleware requires a ResponseStream result.") return context.result - def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: ChatContext) -> None: - c.result = final_handler(c) # type: ignore[assignment] - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) # type: ignore[assignment] + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: ChatContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) with contextlib.suppress(MiddlewareTermination): - await first_handler(context) + await first_handler() if context.result and isinstance(context.result, ResponseStream): for hook in context.stream_transform_hooks: diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index b34164b86b..da8e907c40 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -19,12 +19,10 @@ async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -62,11 +60,9 @@ async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: Moc captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -99,12 +95,10 @@ async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockCha captured_kwargs_list: list[dict[str, Any]] = [] @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture kwargs at each level captured_kwargs_list.append(dict(context.kwargs)) - await call_next(context) + await call_next() # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. client.responses = [ @@ -162,11 +156,9 @@ async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClie captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock streaming responses from agent_framework import ChatResponseUpdate @@ -224,11 +216,9 @@ async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -266,16 +256,14 @@ async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClie call_count = 0 @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal call_count call_count += 1 if call_count == 1: first_call_kwargs.update(context.kwargs) elif call_count == 2: second_call_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock responses for both calls client.responses = [ @@ -318,11 +306,9 @@ async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, clie captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index dcc28958f5..e135e2fee6 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -2298,9 +2298,7 @@ def sometimes_fails(arg1: str) -> str: class TerminateLoopMiddleware(FunctionMiddleware): """Middleware that raises MiddlewareTermination to exit the function calling loop.""" - async def process( - self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" raise MiddlewareTermination @@ -2355,14 +2353,12 @@ def ai_func(arg1: str) -> str: class SelectiveTerminateMiddleware(FunctionMiddleware): """Only terminates for terminating_function.""" - async def process( - self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: if context.function.name == "terminating_function": # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" raise MiddlewareTermination - await next_handler(context) + await next_handler() async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: SupportsChatGetResponse): diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 41c15b2c70..e5bd23751f 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -135,12 +135,12 @@ class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" class PreNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Any) -> None: - await call_next(context) + await call_next() raise MiddlewareTermination def test_init_empty(self) -> None: @@ -157,8 +157,8 @@ def test_init_with_class_middleware(self) -> None: def test_init_with_function_middleware(self) -> None: """Test AgentMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: - await call_next(context) + async def test_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = AgentMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -185,11 +185,9 @@ class OrderTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingMiddleware("test") @@ -238,11 +236,9 @@ class StreamOrderTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = StreamOrderTrackingMiddleware("test") @@ -367,12 +363,10 @@ async def test_execute_with_thread_in_context(self, mock_agent: SupportsAgentRun captured_thread = None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread - await call_next(context) + await call_next() middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -394,12 +388,10 @@ async def test_execute_with_no_thread_in_context(self, mock_agent: SupportsAgent captured_thread = "not_none" # Use string to distinguish from None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread - await call_next(context) + await call_next() middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -425,7 +417,7 @@ async def process(self, context: FunctionInvocationContext, call_next: Any) -> N class PostNextTerminateFunctionMiddleware(FunctionMiddleware): async def process(self, context: FunctionInvocationContext, call_next: Any) -> None: - await call_next(context) + await call_next() raise MiddlewareTermination async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any, Any]) -> None: @@ -482,10 +474,8 @@ def test_init_with_class_middleware(self) -> None: def test_init_with_function_middleware(self) -> None: """Test FunctionMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: - await call_next(context) + async def test_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = FunctionMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -515,10 +505,10 @@ def __init__(self, name: str): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingFunctionMiddleware("test") @@ -541,12 +531,12 @@ class TestChatMiddlewarePipeline: """Test cases for ChatMiddlewarePipeline.""" class PreNextTerminateChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() raise MiddlewareTermination def test_init_empty(self) -> None: @@ -563,8 +553,8 @@ def test_init_with_class_middleware(self) -> None: def test_init_with_function_middleware(self) -> None: """Test ChatMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def test_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = ChatMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -592,9 +582,9 @@ class OrderTrackingChatMiddleware(ChatMiddleware): def __init__(self, name: str): self.name = name - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingChatMiddleware("test") @@ -644,9 +634,9 @@ class StreamOrderTrackingChatMiddleware(ChatMiddleware): def __init__(self, name: str): self.name = name - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = StreamOrderTrackingChatMiddleware("test") @@ -774,12 +764,10 @@ async def test_agent_middleware_execution(self, mock_agent: SupportsAgentRun) -> metadata_updates: list[str] = [] class MetadataAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: context.metadata["before"] = True metadata_updates.append("before") - await call_next(context) + await call_next() context.metadata["after"] = True metadata_updates.append("after") @@ -807,11 +795,11 @@ class MetadataFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: context.metadata["before"] = True metadata_updates.append("before") - await call_next(context) + await call_next() context.metadata["after"] = True metadata_updates.append("after") @@ -839,12 +827,10 @@ async def test_agent_function_middleware(self, mock_agent: SupportsAgentRun) -> """Test function-based agent middleware.""" execution_order: list[str] = [] - async def test_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def test_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(test_agent_middleware) @@ -866,11 +852,11 @@ async def test_function_function_middleware(self, mock_function: FunctionTool[An execution_order: list[str] = [] async def test_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = FunctionMiddlewarePipeline(test_function_middleware) @@ -896,18 +882,14 @@ async def test_mixed_agent_middleware(self, mock_agent: SupportsAgentRun) -> Non execution_order: list[str] = [] class ClassMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") - async def function_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware) @@ -931,17 +913,17 @@ class ClassMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") async def function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = FunctionMiddlewarePipeline(ClassMiddleware(), function_middleware) @@ -962,16 +944,14 @@ async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None: execution_order: list[str] = [] class ClassChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") - async def function_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def function_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = ChatMiddlewarePipeline(ClassChatMiddleware(), function_chat_middleware) @@ -997,27 +977,21 @@ async def test_agent_middleware_execution_order(self, mock_agent: SupportsAgentR execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") class ThirdMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("third_before") - await call_next(context) + await call_next() execution_order.append("third_after") middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] @@ -1051,20 +1025,20 @@ class FirstMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") middleware = [FirstMiddleware(), SecondMiddleware()] @@ -1087,21 +1061,21 @@ async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> N execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") class ThirdChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("third_before") - await call_next(context) + await call_next() execution_order.append("third_after") middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] @@ -1136,9 +1110,7 @@ async def test_agent_context_validation(self, mock_agent: SupportsAgentRun) -> N """Test that agent context contains expected data.""" class ContextValidationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "agent") assert hasattr(context, "messages") @@ -1156,7 +1128,7 @@ async def process( # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ContextValidationMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -1178,7 +1150,7 @@ class ContextValidationMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Verify context has all expected attributes assert hasattr(context, "function") @@ -1194,7 +1166,7 @@ async def process( # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ContextValidationMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -1213,7 +1185,7 @@ async def test_chat_context_validation(self, mock_chat_client: Any) -> None: """Test that chat context contains expected data.""" class ChatContextValidationMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "client") assert hasattr(context, "messages") @@ -1235,7 +1207,7 @@ async def process(self, context: ChatContext, call_next: Callable[[ChatContext], # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ChatContextValidationMiddleware() pipeline = ChatMiddlewarePipeline(middleware) @@ -1260,11 +1232,9 @@ async def test_streaming_flag_validation(self, mock_agent: SupportsAgentRun) -> streaming_flags: list[bool] = [] class StreamingFlagMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() middleware = StreamingFlagMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -1302,11 +1272,9 @@ async def test_streaming_middleware_behavior(self, mock_agent: SupportsAgentRun) chunks_processed: list[str] = [] class StreamProcessingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: chunks_processed.append("before_stream") - await call_next(context) + await call_next() chunks_processed.append("after_stream") middleware = StreamProcessingMiddleware() @@ -1345,9 +1313,9 @@ async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> No streaming_flags: list[bool] = [] class ChatStreamingFlagMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() middleware = ChatStreamingFlagMiddleware() pipeline = ChatMiddlewarePipeline(middleware) @@ -1386,9 +1354,9 @@ async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) - chunks_processed: list[str] = [] class ChatStreamProcessingMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: chunks_processed.append("before_stream") - await call_next(context) + await call_next() chunks_processed.append("after_stream") middleware = ChatStreamProcessingMiddleware() @@ -1436,24 +1404,22 @@ class FunctionTestArgs(BaseModel): class TestAgentMiddleware(AgentMiddleware): """Test implementation of AgentMiddleware.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class TestFunctionMiddleware(FunctionMiddleware): """Test implementation of FunctionMiddleware.""" - async def process( - self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: - await call_next(context) + async def process(self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class TestChatMiddleware(ChatMiddleware): """Test implementation of ChatMiddleware.""" - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class MockFunctionArgs(BaseModel): @@ -1469,9 +1435,7 @@ async def test_agent_middleware_no_next_no_execution(self, mock_agent: SupportsA """Test that when agent middleware doesn't call next(), no execution happens.""" class NoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1498,9 +1462,7 @@ async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: """Test that when agent middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1537,7 +1499,7 @@ class NoNextFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Don't call next() - this should prevent any execution pass @@ -1566,18 +1528,14 @@ async def test_multiple_middlewares_early_stop(self, mock_agent: SupportsAgentRu execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second") - await call_next(context) + await call_next() pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware()) messages = [Message(role="user", text="test")] @@ -1601,7 +1559,7 @@ async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) """Test that when chat middleware doesn't call next(), no execution happens.""" class NoNextChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1629,7 +1587,7 @@ async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_cl """Test that when chat middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1670,14 +1628,14 @@ async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second") - await call_next(context) + await call_next() pipeline = ChatMiddlewarePipeline(FirstChatMiddleware(), SecondChatMiddleware()) messages = [Message(role="user", text="test")] diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index d17e99a85e..c5744fdca5 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -43,11 +43,9 @@ async def test_agent_middleware_response_override_non_streaming(self, mock_agent override_response = AgentResponse(messages=[Message(role="assistant", text="overridden response")]) class ResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response - await call_next(context) + await call_next() context.result = override_response middleware = ResponseOverrideMiddleware() @@ -79,11 +77,9 @@ async def override_stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response stream - await call_next(context) + await call_next() context.result = ResponseStream(override_stream()) middleware = StreamResponseOverrideMiddleware() @@ -115,10 +111,10 @@ class ResultOverrideMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Execute the pipeline first, then override the result - await call_next(context) + await call_next() context.result = override_result middleware = ResultOverrideMiddleware() @@ -145,11 +141,9 @@ async def test_chat_agent_middleware_response_override(self) -> None: mock_chat_client = MockChatClient() class ChatAgentResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Always call next() first to allow execution - await call_next(context) + await call_next() # Then conditionally override based on content if any("special" in msg.text for msg in context.messages if msg.text): context.result = AgentResponse( @@ -184,15 +178,13 @@ async def custom_stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Check if we want to override BEFORE calling next to avoid creating unused streams if any("custom stream" in msg.text for msg in context.messages if msg.text): context.result = ResponseStream(custom_stream()) return # Don't call next() - we're overriding the entire result # Normal case - let the agent handle it - await call_next(context) + await call_next() # Create Agent with override middleware middleware = ChatAgentStreamOverrideMiddleware() @@ -223,12 +215,10 @@ async def test_agent_middleware_conditional_no_next(self, mock_agent: SupportsAg """Test that when agent middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Only call next() if message contains "execute" if any("execute" in msg.text for msg in context.messages if msg.text): - await call_next(context) + await call_next() # Otherwise, don't call next() - no execution should happen middleware = ConditionalNoNextMiddleware() @@ -269,13 +259,13 @@ class ConditionalNoNextFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Only call next() if argument name contains "execute" args = context.arguments assert isinstance(args, FunctionTestArgs) if "execute" in args.name: - await call_next(context) + await call_next() # Otherwise, don't call next() - no execution should happen middleware = ConditionalNoNextFunctionMiddleware() @@ -318,14 +308,12 @@ async def test_agent_middleware_response_observability(self, mock_agent: Support observed_responses: list[AgentResponse] = [] class ObservabilityMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Context should be empty before next() assert context.result is None # Call next to execute - await call_next(context) + await call_next() # Context should now contain the response for observability assert context.result is not None @@ -355,13 +343,13 @@ class ObservabilityMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Context should be empty before next() assert context.result is None # Call next to execute - await call_next(context) + await call_next() # Context should now contain the result for observability assert context.result is not None @@ -386,11 +374,9 @@ async def test_agent_middleware_post_execution_override(self, mock_agent: Suppor """Test that middleware can override response after observing execution.""" class PostExecutionOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Call next to execute first - await call_next(context) + await call_next() # Now observe and conditionally override assert context.result is not None @@ -423,10 +409,10 @@ class PostExecutionOverrideMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Call next to execute first - await call_next(context) + await call_next() # Now observe and conditionally override assert context.result is not None diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 17f0faf4f0..597ca12dbd 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -44,11 +44,9 @@ class TrackingAgentMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Create Agent with middleware @@ -76,9 +74,9 @@ class TrackingFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: - await call_next(context) + await call_next() middleware = TrackingFunctionMiddleware() Agent(client=client, middleware=[middleware]) @@ -96,10 +94,10 @@ def __init__(self, name: str): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = TrackingFunctionMiddleware("function_middleware") @@ -122,13 +120,11 @@ async def test_agent_middleware_with_pre_termination(self, client: "MockChatClie execution_order: list[str] = [] class PreTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") raise MiddlewareTermination # Code after raise is unreachable - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with terminating middleware @@ -153,11 +149,9 @@ async def test_agent_middleware_with_post_termination(self, client: "MockChatCli execution_order: list[str] = [] class PostTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -193,12 +187,12 @@ class PreTerminationFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("middleware_before") context.terminate = True # We call next() but since terminate=True, subsequent middleware and handler should not execute - await call_next(context) + await call_next() execution_order.append("middleware_after") Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[]) @@ -211,10 +205,10 @@ class PostTerminationFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -224,11 +218,9 @@ async def test_function_based_agent_middleware_with_chat_agent(self, client: "Mo """Test function-based agent middleware with Agent.""" execution_order: list[str] = [] - async def tracking_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def tracking_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_function_before") - await call_next(context) + await call_next() execution_order.append("agent_function_after") # Create Agent with function middleware @@ -252,9 +244,9 @@ async def test_function_based_function_middleware_with_chat_agent(self, client: """Test function-based function middleware with Agent.""" async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: - await call_next(context) + await call_next() Agent(client=client, middleware=[tracking_function_middleware]) @@ -265,10 +257,10 @@ async def test_function_based_function_middleware_with_supported_client( execution_order: list[str] = [] async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent(client=chat_client_base, middleware=[tracking_function_middleware]) @@ -290,12 +282,10 @@ async def test_agent_middleware_with_streaming(self, client: "MockChatClient") - streaming_flags: list[bool] = [] class StreamingTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with middleware @@ -334,11 +324,9 @@ async def test_non_streaming_vs_streaming_flag_validation(self, client: "MockCha streaming_flags: list[bool] = [] class FlagTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() # Create Agent with middleware middleware = FlagTrackingMiddleware() @@ -368,11 +356,9 @@ class OrderedMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Create multiple middleware @@ -400,35 +386,31 @@ async def test_mixed_middleware_types_with_chat_agent(self, chat_client_base: "M execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_agent_before") - await call_next(context) + await call_next() execution_order.append("class_agent_after") - async def function_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_agent_before") - await call_next(context) + await call_next() execution_order.append("function_agent_after") class ClassFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("class_function_before") - await call_next(context) + await call_next() execution_order.append("class_function_after") async def function_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent( @@ -447,25 +429,21 @@ async def test_mixed_middleware_types_with_supported_client(self, chat_client_ba execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_agent_before") - await call_next(context) + await call_next() execution_order.append("class_agent_after") - async def function_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_agent_before") - await call_next(context) + await call_next() execution_order.append("function_agent_after") async def function_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent( @@ -521,10 +499,10 @@ def __init__(self, name: str): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Set up mock to return a function call first, then a regular response @@ -583,10 +561,10 @@ async def test_function_based_function_middleware_with_tool_calls( execution_order: list[str] = [] async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Set up mock to return a function call first, then a regular response @@ -647,20 +625,20 @@ class TrackingAgentMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("agent_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_middleware_after") class TrackingFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Set up mock to return a function call first, then a regular response @@ -728,7 +706,7 @@ async def test_function_middleware_can_access_and_override_custom_kwargs( @function_middleware async def kwargs_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: nonlocal middleware_called middleware_called = True @@ -748,7 +726,7 @@ async def kwargs_middleware( modified_kwargs["new_param"] = context.kwargs.get("new_param") modified_kwargs["custom_param"] = context.kwargs.get("custom_param") - await call_next(context) + await call_next() chat_client_base.run_responses = [ ChatResponse( @@ -801,9 +779,9 @@ def __init__(self, name: str, execution_log: list[str]): self.name = name self.execution_log = execution_log - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") - await call_next(context) + await call_next() self.execution_log.append(f"{self.name}_end") async def test_middleware_dynamic_rebuild_non_streaming(self, client: "MockChatClient") -> None: @@ -924,9 +902,9 @@ def __init__(self, name: str, execution_log: list[str]): self.name = name self.execution_log = execution_log - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") - await call_next(context) + await call_next() self.execution_log.append(f"{self.name}_end") async def test_run_level_middleware_isolation(self, client: "MockChatClient") -> None: @@ -976,29 +954,25 @@ class MetadataAgentMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Set metadata to pass information to run middleware context.metadata[f"{self.name}_key"] = f"{self.name}_value" - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") class MetadataRunMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Read metadata set by agent middleware for key, value in context.metadata.items(): metadata_log.append(f"{self.name}_reads_{key}:{value}") # Set run-level metadata context.metadata[f"{self.name}_key"] = f"{self.name}_value" - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") # Create agent with agent-level middleware @@ -1049,12 +1023,10 @@ class StreamingTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") # Create agent without agent-level middleware @@ -1093,48 +1065,44 @@ async def test_agent_and_run_level_both_agent_and_function_middleware( # Agent-level middleware class AgentLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append("agent_level_agent_start") context.metadata["agent_level_agent"] = "processed" - await call_next(context) + await call_next() execution_log.append("agent_level_agent_end") class AgentLevelFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_log.append("agent_level_function_start") context.metadata["agent_level_function"] = "processed" - await call_next(context) + await call_next() execution_log.append("agent_level_function_end") # Run-level middleware class RunLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append("run_level_agent_start") # Verify agent-level middleware metadata is available assert "agent_level_agent" in context.metadata context.metadata["run_level_agent"] = "processed" - await call_next(context) + await call_next() execution_log.append("run_level_agent_end") class RunLevelFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_log.append("run_level_function_start") # Verify agent-level function middleware metadata is available assert "agent_level_function" in context.metadata context.metadata["run_level_function"] = "processed" - await call_next(context) + await call_next() execution_log.append("run_level_function_end") # Create tool function for testing function middleware @@ -1217,18 +1185,16 @@ async def test_decorator_and_type_match(self, chat_client_base: "MockBaseChatCli execution_order: list[str] = [] @agent_middleware - async def matching_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def matching_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("decorator_type_match_agent") - await call_next(context) + await call_next() @function_middleware async def matching_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("decorator_type_match_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1282,7 +1248,7 @@ async def mismatched_middleware( context: FunctionInvocationContext, # Wrong type for @agent_middleware call_next: Any, ) -> None: - await call_next(context) + await call_next() agent = Agent(client=client, middleware=[mismatched_middleware]) await agent.run([Message(role="user", text="test")]) @@ -1294,12 +1260,12 @@ async def test_only_decorator_specified(self, chat_client_base: "MockBaseChatCli @agent_middleware async def decorator_only_agent(context: Any, call_next: Any) -> None: # No type annotation execution_order.append("decorator_only_agent") - await call_next(context) + await call_next() @function_middleware async def decorator_only_function(context: Any, call_next: Any) -> None: # No type annotation execution_order.append("decorator_only_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1346,16 +1312,16 @@ async def test_only_type_specified(self, chat_client_base: "MockBaseChatClient") execution_order: list[str] = [] # No decorator - async def type_only_agent(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def type_only_agent(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("type_only_agent") - await call_next(context) + await call_next() # No decorator async def type_only_function( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("type_only_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1399,7 +1365,7 @@ async def test_neither_decorator_nor_type(self, client: Any) -> None: """Neither decorator nor parameter type specified - should throw exception.""" async def no_info_middleware(context: Any, call_next: Any) -> None: # No decorator, no type - await call_next(context) + await call_next() # Should raise MiddlewareException with pytest.raises(MiddlewareException, match="Cannot determine middleware type"): @@ -1447,9 +1413,7 @@ async def test_agent_context_thread_behavior_across_multiple_runs(self, client: thread_states: list[dict[str, Any]] = [] class ThreadTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture state before next() call thread_messages = [] if context.thread and context.thread.message_store: @@ -1464,7 +1428,7 @@ async def process( } thread_states.append(before_state) - await call_next(context) + await call_next() # Capture state after next() call thread_messages_after = [] @@ -1560,9 +1524,9 @@ async def test_class_based_chat_middleware_with_chat_agent(self) -> None: execution_order: list[str] = [] class TrackingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Create Agent with chat middleware @@ -1588,11 +1552,9 @@ async def test_function_based_chat_middleware_with_chat_agent(self) -> None: """Test function-based chat middleware with Agent.""" execution_order: list[str] = [] - async def tracking_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def tracking_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Create Agent with function-based chat middleware @@ -1617,9 +1579,7 @@ async def test_chat_middleware_can_modify_messages(self) -> None: """Test that chat middleware can modify messages before sending to model.""" @chat_middleware - async def message_modifier_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Modify the first message by adding a prefix if context.messages: for idx, msg in enumerate(context.messages): @@ -1628,7 +1588,7 @@ async def message_modifier_middleware( original_text = msg.text or "" context.messages[idx] = Message(role=msg.role, text=f"MODIFIED: {original_text}") break - await call_next(context) + await call_next() # Create Agent with message-modifying middleware client = MockBaseChatClient() @@ -1646,9 +1606,7 @@ async def test_chat_middleware_can_override_response(self) -> None: """Test that chat middleware can override the response.""" @chat_middleware - async def response_override_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Override the response without calling next() context.result = ChatResponse( messages=[Message(role="assistant", text="MiddlewareTypes overridden response")], @@ -1675,15 +1633,15 @@ async def test_multiple_chat_middleware_execution_order(self) -> None: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Create Agent with multiple chat middleware @@ -1709,10 +1667,10 @@ async def test_chat_middleware_with_streaming(self) -> None: streaming_flags: list[bool] = [] class StreamingTrackingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("streaming_chat_before") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_order.append("streaming_chat_after") # Create Agent with chat middleware @@ -1749,13 +1707,13 @@ async def test_chat_middleware_termination_before_execution(self) -> None: execution_order: list[str] = [] class PreTerminationChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") # Set a custom response since we're terminating context.result = ChatResponse(messages=[Message(role="assistant", text="Terminated by middleware")]) raise MiddlewareTermination # We call next() but since terminate=True, execution should stop - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with terminating middleware @@ -1777,9 +1735,9 @@ async def test_chat_middleware_termination_after_execution(self) -> None: execution_order: list[str] = [] class PostTerminationChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -1804,21 +1762,21 @@ async def test_combined_middleware(self) -> None: """Test Agent with combined middleware types.""" execution_order: list[str] = [] - async def agent_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_middleware_after") - async def chat_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") async def function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Create Agent with function middleware and tools @@ -1842,9 +1800,7 @@ async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> N modified_kwargs: dict[str, Any] = {} @agent_middleware - async def kwargs_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -1856,7 +1812,7 @@ async def kwargs_middleware( # Store modified kwargs for verification modified_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Create Agent with agent middleware client = MockBaseChatClient() @@ -1895,10 +1851,10 @@ async def kwargs_middleware( # class TrackingMiddleware(AgentMiddleware): # async def process( -# self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] +# self, context: AgentContext, call_next: Callable[[], Awaitable[None]] # ) -> None: # execution_order.append("before") -# await call_next(context) +# await call_next() # execution_order.append("after") # @use_agent_middleware diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 3c9d0246c7..62a168ccb0 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -32,10 +32,10 @@ class LoggingChatMiddleware(ChatMiddleware): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Add middleware to chat client @@ -58,11 +58,9 @@ async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseC execution_order: list[str] = [] @chat_middleware - async def logging_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def logging_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Add middleware to chat client @@ -84,14 +82,12 @@ async def test_chat_middleware_can_modify_messages(self, chat_client_base: "Mock """Test that chat middleware can modify messages before sending to model.""" @chat_middleware - async def message_modifier_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Modify the first message by adding a prefix if context.messages and len(context.messages) > 0: original_text = context.messages[0].text or "" context.messages[0] = Message(role=context.messages[0].role, text=f"MODIFIED: {original_text}") - await call_next(context) + await call_next() # Add middleware to chat client chat_client_base.chat_middleware = [message_modifier_middleware] @@ -110,9 +106,7 @@ async def test_chat_middleware_can_override_response(self, chat_client_base: "Mo """Test that chat middleware can override the response.""" @chat_middleware - async def response_override_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Override the response without calling next() context.result = ChatResponse( messages=[Message(role="assistant", text="MiddlewareTypes overridden response")], @@ -138,15 +132,15 @@ async def test_multiple_chat_middleware_execution_order(self, chat_client_base: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Add middleware to chat client (order should be preserved) @@ -173,11 +167,9 @@ async def test_chat_agent_with_chat_middleware(self) -> None: execution_order: list[str] = [] @chat_middleware - async def agent_level_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def agent_level_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_chat_middleware_after") client = MockBaseChatClient() @@ -205,15 +197,15 @@ async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Create Agent with multiple chat middleware @@ -240,9 +232,7 @@ async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseC execution_order: list[str] = [] @chat_middleware - async def streaming_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def streaming_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("streaming_before") # Verify it's a streaming context assert context.stream is True @@ -254,7 +244,7 @@ def upper_case_update(update: ChatResponseUpdate) -> ChatResponseUpdate: return update context.stream_transform_hooks.append(upper_case_update) - await call_next(context) + await call_next() execution_order.append("streaming_after") # Add middleware to chat client @@ -278,11 +268,9 @@ async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseC execution_count = {"count": 0} @chat_middleware - async def counting_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def counting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_count["count"] += 1 - await call_next(context) + await call_next() # First call with run-level middleware messages = [Message(role="user", text="first message")] @@ -310,7 +298,7 @@ async def test_chat_client_middleware_can_access_and_override_custom_kwargs( modified_kwargs: dict[str, Any] = {} @chat_middleware - async def kwargs_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -322,7 +310,7 @@ async def kwargs_middleware(context: ChatContext, call_next: Callable[[ChatConte # Store modified kwargs for verification modified_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Add middleware to chat client chat_client_base.chat_middleware = [kwargs_middleware] @@ -355,11 +343,11 @@ async def test_function_middleware_registration_on_chat_client( @function_middleware async def test_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: nonlocal execution_order execution_order.append(f"function_middleware_before_{context.function.name}") - await call_next(context) + await call_next() execution_order.append(f"function_middleware_after_{context.function.name}") # Define a simple tool function @@ -421,10 +409,10 @@ async def test_run_level_function_middleware(self, chat_client_base: "MockBaseCh @function_middleware async def run_level_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("run_level_function_middleware_before") - await call_next(context) + await call_next() execution_order.append("run_level_function_middleware_after") # Define a simple tool function diff --git a/python/packages/ollama/tests/test_ollama_chat_client.py b/python/packages/ollama/tests/test_ollama_chat_client.py index 807d5b8eb8..6e4f82233f 100644 --- a/python/packages/ollama/tests/test_ollama_chat_client.py +++ b/python/packages/ollama/tests/test_ollama_chat_client.py @@ -208,7 +208,7 @@ def test_serialize(ollama_unit_test_env: dict[str, str]) -> None: def test_chat_middleware(ollama_unit_test_env: dict[str, str]) -> None: @chat_middleware async def sample_middleware(context, call_next): - await call_next(context) + await call_next() ollama_chat_client = OllamaChatClient(middleware=[sample_middleware]) assert len(ollama_chat_client.middleware) == 1 diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 37f499d763..e574528395 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -129,11 +129,11 @@ def __init__(self, handoffs: Sequence[HandoffConfiguration]) -> None: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Intercept matching handoff tool calls and inject synthetic results.""" if context.function.name not in self._handoff_functions: - await call_next(context) + await call_next() return from agent_framework._middleware import MiddlewareTermination diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 10e0443b0b..2da8de84ee 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -65,7 +65,7 @@ def _get_agent_session_id(context: AgentContext) -> str | None: async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: @@ -92,7 +92,7 @@ async def process( if not self._settings.ignore_exceptions: raise - await call_next(context) + await call_next() try: # Post (response) check only if we have a normal AgentResponse @@ -162,7 +162,7 @@ def __init__( async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: @@ -187,7 +187,7 @@ async def process( if not self._settings.ignore_exceptions: raise - await call_next(context) + await call_next() try: # Post (response) evaluation only if non-streaming and we have messages result shape diff --git a/python/packages/purview/tests/purview/test_chat_middleware.py b/python/packages/purview/tests/purview/test_chat_middleware.py index 677e3e277b..bc9be01e1f 100644 --- a/python/packages/purview/tests/purview/test_chat_middleware.py +++ b/python/packages/purview/tests/purview/test_chat_middleware.py @@ -49,7 +49,7 @@ async def test_allows_clean_prompt( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: next_called = False - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True @@ -57,7 +57,7 @@ class Result: def __init__(self): self.messages = [Message(role="assistant", text="Hi there")] - ctx.result = Result() + chat_context.result = Result() await middleware.process(chat_context, mock_next) assert next_called @@ -67,7 +67,7 @@ def __init__(self): async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None: with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): - async def mock_next(ctx: ChatContext) -> None: # should not run + async def mock_next() -> None: # should not run raise AssertionError("next should not be called when prompt blocked") with pytest.raises(MiddlewareTermination): @@ -88,12 +88,12 @@ async def side_effect(messages, activity, session_id=None, user_id=None): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: class Result: def __init__(self): self.messages = [Message(role="assistant", text="Sensitive output")] # pragma: no cover - ctx.result = Result() + chat_context.result = Result() await middleware.process(chat_context, mock_next) assert call_state["count"] == 2 @@ -114,8 +114,8 @@ async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMid ) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: - ctx.result = MagicMock() + async def mock_next() -> None: + streaming_context.result = MagicMock() await middleware.process(streaming_context, mock_next) assert mock_proc.call_count == 1 @@ -138,10 +138,10 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + chat_context.result = result await middleware.process(chat_context, mock_next) @@ -162,10 +162,10 @@ async def mock_process_messages(messages, activity, session_id=None, user_id=Non with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + chat_context.result = result await middleware.process(chat_context, mock_next) @@ -194,7 +194,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") # Should raise the exception @@ -224,10 +224,10 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="OK")] - ctx.result = result + context.result = result with pytest.raises(PurviewPaymentRequiredError): await middleware.process(context, mock_next) @@ -249,7 +249,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] context.result = result @@ -265,9 +265,9 @@ async def test_chat_middleware_handles_result_without_messages_attribute( """Test middleware handles result that doesn't have messages attribute.""" with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: # Set result to something without messages attribute - ctx.result = "Some string result" + chat_context.result = "Some string result" await middleware.process(chat_context, mock_next) @@ -289,7 +289,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] context.result = result @@ -313,7 +313,7 @@ async def test_chat_middleware_raises_on_pre_check_exception_when_ignore_excepti with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")): - async def mock_next(_: ChatContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") with pytest.raises(ValueError, match="boom"): @@ -342,10 +342,10 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="OK")] - ctx.result = result + context.result = result with pytest.raises(ValueError, match="post"): await middleware.process(context, mock_next) @@ -361,10 +361,10 @@ async def test_chat_middleware_uses_conversation_id_from_options( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Hi")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) @@ -382,10 +382,10 @@ async def test_chat_middleware_passes_none_session_id_when_options_missing( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Hi")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) @@ -401,10 +401,10 @@ async def test_chat_middleware_session_id_used_in_post_check(self, middleware: P with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) diff --git a/python/packages/purview/tests/purview/test_middleware.py b/python/packages/purview/tests/purview/test_middleware.py index ff77331155..98dafab1e1 100644 --- a/python/packages/purview/tests/purview/test_middleware.py +++ b/python/packages/purview/tests/purview/test_middleware.py @@ -55,10 +55,10 @@ async def test_middleware_allows_clean_prompt( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): next_called = False - async def mock_next(ctx: AgentContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True - ctx.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")]) + context.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")]) await middleware.process(context, mock_next) @@ -74,7 +74,7 @@ async def test_middleware_blocks_prompt_on_policy_violation( with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): next_called = False - async def mock_next(ctx: AgentContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True @@ -101,8 +101,8 @@ async def mock_process_messages(messages, activity, session_id=None, user_id=Non with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse( + async def mock_next() -> None: + context.result = AgentResponse( messages=[Message(role="assistant", text="Here's some sensitive information")] ) @@ -125,8 +125,8 @@ async def test_middleware_handles_result_without_messages( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = "Some non-standard result" + async def mock_next() -> None: + context.result = "Some non-standard result" await middleware.process(context, mock_next) @@ -142,8 +142,8 @@ async def test_middleware_processor_receives_correct_activity( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -160,8 +160,8 @@ async def test_middleware_streaming_skips_post_check( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="streaming")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="streaming")]) await middleware.process(context, mock_next) @@ -181,7 +181,7 @@ async def test_middleware_payment_required_in_pre_check_raises_by_default( side_effect=PurviewPaymentRequiredError("Payment required"), ): - async def mock_next(_: AgentContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") with pytest.raises(PurviewPaymentRequiredError): @@ -206,8 +206,8 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) with pytest.raises(PurviewPaymentRequiredError): await middleware.process(context, mock_next) @@ -231,8 +231,8 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) with pytest.raises(ValueError, match="Post-check blew up"): await middleware.process(context, mock_next) @@ -250,8 +250,8 @@ async def test_middleware_handles_pre_check_exception( middleware._processor, "process_messages", side_effect=Exception("Pre-check error") ) as mock_process: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -280,8 +280,8 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -306,8 +306,8 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx): - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next(): + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) # Should not raise, just log await middleware.process(context, mock_next) @@ -330,7 +330,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx): + async def mock_next(): pass # Should raise the exception @@ -346,8 +346,8 @@ async def test_middleware_uses_thread_service_thread_id_as_session_id( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -364,8 +364,8 @@ async def test_middleware_uses_message_conversation_id_as_session_id( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -383,8 +383,8 @@ async def test_middleware_thread_id_takes_precedence_over_message_conversation_i with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -399,8 +399,8 @@ async def test_middleware_passes_none_session_id_when_not_available( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -416,8 +416,8 @@ async def test_middleware_session_id_used_in_post_check( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) diff --git a/python/samples/concepts/tools/README.md b/python/samples/concepts/tools/README.md index 91c481842d..2af617cf4c 100644 --- a/python/samples/concepts/tools/README.md +++ b/python/samples/concepts/tools/README.md @@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware): if self.should_terminate(context): context.result = "terminated by middleware" raise MiddlewareTermination # Exit function invocation loop - await call_next(context) + await call_next() ``` ## Arguments Added/Altered at Each Layer @@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware): return # Upstream post-processing still runs # Option B: Call call_next, then return normally - await call_next(context) + await call_next() self.cache[context.function.name] = context.result return # Normal completion ``` @@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware): if context.function.name in self.blocked_functions: context.result = "Function blocked by policy" raise MiddlewareTermination("Blocked") # Skips ALL post-processing - await call_next(context) + await call_next() ``` ### 3. Raise Any Other Exception @@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware): async def process(self, context: FunctionInvocationContext, call_next): if not self.is_valid(context.arguments): raise ValueError("Invalid arguments") # Bubbles up to user - await call_next(context) + await call_next() ``` ## `return` vs `raise MiddlewareTermination` @@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**: class MiddlewareA(AgentMiddleware): async def process(self, context, call_next): print("A: before") - await call_next(context) + await call_next() print("A: after") # Does this run? class MiddlewareB(AgentMiddleware): @@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`: ## Calling `call_next()` or Not -The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute: +The decision to call `call_next()` determines whether downstream middleware and the actual operation execute: ### Without calling `call_next()` - Skip downstream @@ -430,7 +430,7 @@ async def process(self, context, call_next): ```python async def process(self, context, call_next): # Pre-processing - await call_next(context) # Execute downstream + actual operation + await call_next() # Execute downstream + actual operation # Post-processing (context.result now contains real result) return ``` @@ -450,7 +450,7 @@ async def process(self, context, call_next): | `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No | | `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) | -> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern. +> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern. ## Streaming vs Non-Streaming diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py index f03fc4beb1..2d873f2930 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py @@ -20,13 +20,13 @@ async def logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that logs tool invocations to show the delegation flow.""" print(f"[Calling tool: {context.function.name}]") print(f"[Request: {context.arguments}]") - await call_next(context) + await call_next() print(f"[Response: {context.result}]") diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_basic.py b/python/samples/getting_started/agents/openai/openai_responses_client_basic.py index b6ab9fb42c..7dc8fdce2f 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_basic.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_basic.py @@ -29,7 +29,7 @@ @chat_middleware async def security_and_override_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function-based middleware that implements security filtering and response override.""" print("[SecurityMiddleware] Processing input...") @@ -60,7 +60,7 @@ async def security_and_override_middleware( raise MiddlewareTermination(result=context.result) # Continue to next middleware or AI execution - await call_next(context) + await call_next() print("[SecurityMiddleware] Response generated.") print(type(context.result)) diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py b/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py index d37d5a9b4a..774231d0d6 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py @@ -19,13 +19,13 @@ async def logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that logs tool invocations to show the delegation flow.""" print(f"[Calling tool: {context.function.name}]") print(f"[Request: {context.arguments}]") - await call_next(context) + await call_next() print(f"[Response: {context.result}]") diff --git a/python/samples/getting_started/devui/weather_agent_azure/agent.py b/python/samples/getting_started/devui/weather_agent_azure/agent.py index dca5b69bbc..a754d32ead 100644 --- a/python/samples/getting_started/devui/weather_agent_azure/agent.py +++ b/python/samples/getting_started/devui/weather_agent_azure/agent.py @@ -38,7 +38,7 @@ def cleanup_resources(): @chat_middleware async def security_filter_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Chat middleware that blocks requests containing sensitive information.""" blocked_terms = ["password", "secret", "api_key", "token"] @@ -80,13 +80,13 @@ async def blocked_stream(msg: str = error_message) -> AsyncIterable[ChatResponse raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() @function_middleware async def atlantis_location_filter_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that blocks weather requests for Atlantis.""" # Check if location parameter is "atlantis" @@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware( ) raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py. diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py index 70408472ad..1f80c7742f 100644 --- a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -68,7 +68,7 @@ def get_weather( class SecurityAgentMiddleware(AgentMiddleware): """Agent-level security middleware that validates all requests.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: print("[SecurityMiddleware] Checking security for all requests...") # Check for security violations in the last user message @@ -81,18 +81,18 @@ async def process(self, context: AgentContext, call_next: Callable[[AgentContext print("[SecurityMiddleware] Security check passed.") context.metadata["security_validated"] = True - await call_next(context) + await call_next() async def performance_monitor_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Agent-level performance monitoring for all runs.""" print("[PerformanceMonitor] Starting performance monitoring...") start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time @@ -104,7 +104,7 @@ async def performance_monitor_middleware( class HighPriorityMiddleware(AgentMiddleware): """Run-level middleware for high priority requests.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: print("[HighPriority] Processing high priority request with expedited handling...") # Read metadata set by agent-level middleware @@ -115,13 +115,13 @@ async def process(self, context: AgentContext, call_next: Callable[[AgentContext context.metadata["priority"] = "high" context.metadata["expedited"] = True - await call_next(context) + await call_next() print("[HighPriority] High priority processing completed") async def debugging_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Run-level debugging middleware for troubleshooting specific runs.""" print("[Debug] Debug mode enabled for this run") @@ -134,7 +134,7 @@ async def debugging_middleware( context.metadata["debug_enabled"] = True - await call_next(context) + await call_next() print("[Debug] Debug information collected") @@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware): def __init__(self) -> None: self.cache: dict[str, AgentResponse] = {} - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Create a simple cache key from the last message last_message = context.messages[-1] if context.messages else None cache_key: str = last_message.text if last_message and last_message.text else "no_message" @@ -158,7 +158,7 @@ async def process(self, context: AgentContext, call_next: Callable[[AgentContext print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'") context.metadata["cache_key"] = cache_key - await call_next(context) + await call_next() # Cache the result if we have one if context.result: @@ -168,14 +168,14 @@ async def process(self, context: AgentContext, call_next: Callable[[AgentContext async def function_logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that logs all function calls.""" function_name = context.function.name args = context.arguments print(f"[FunctionLog] Calling function: {function_name} with args: {args}") - await call_next(context) + await call_next() print(f"[FunctionLog] Function {function_name} completed") @@ -275,7 +275,7 @@ async def main() -> None: query = "What's the secret weather password for Berlin?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}") + print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}") print() # Run 7: Normal query again (no run-level middleware interference) diff --git a/python/samples/getting_started/middleware/chat_middleware.py b/python/samples/getting_started/middleware/chat_middleware.py index 424db96457..f0c9ef153e 100644 --- a/python/samples/getting_started/middleware/chat_middleware.py +++ b/python/samples/getting_started/middleware/chat_middleware.py @@ -57,7 +57,7 @@ def __init__(self, replacement: str | None = None): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Observe and modify input messages before they are sent to AI.""" print("[InputObserverMiddleware] Observing input messages:") @@ -91,7 +91,7 @@ async def process( context.messages[:] = modified_messages # Continue to next middleware or AI execution - await call_next(context) + await call_next() # Observe that processing is complete print("[InputObserverMiddleware] Processing completed") @@ -100,7 +100,7 @@ async def process( @chat_middleware async def security_and_override_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function-based middleware that implements security filtering and response override.""" print("[SecurityMiddleware] Processing input...") @@ -131,7 +131,7 @@ async def security_and_override_middleware( raise MiddlewareTermination # Continue to next middleware or AI execution - await call_next(context) + await call_next() async def class_based_chat_middleware() -> None: diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py index 208dddc96d..e3cb884c69 100644 --- a/python/samples/getting_started/middleware/class_based_middleware.py +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Check for potential security violations in the query # Look at the last user message @@ -67,7 +67,7 @@ async def process( return print("[SecurityAgentMiddleware] Security check passed.") - await call_next(context) + await call_next() class LoggingFunctionMiddleware(FunctionMiddleware): @@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: function_name = context.function.name print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.") start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time diff --git a/python/samples/getting_started/middleware/decorator_middleware.py b/python/samples/getting_started/middleware/decorator_middleware.py index 3f5e57e48e..e432473a30 100644 --- a/python/samples/getting_started/middleware/decorator_middleware.py +++ b/python/samples/getting_started/middleware/decorator_middleware.py @@ -53,7 +53,7 @@ def get_current_time() -> str: async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality """Agent middleware that runs before and after agent execution.""" print("[Agent MiddlewareTypes] Before agent execution") - await call_next(context) + await call_next() print("[Agent MiddlewareTypes] After agent execution") @@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality """Function middleware that runs before and after function calls.""" print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore - await call_next(context) + await call_next() print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore diff --git a/python/samples/getting_started/middleware/exception_handling_with_middleware.py b/python/samples/getting_started/middleware/exception_handling_with_middleware.py index b929af4c94..1f7ed59542 100644 --- a/python/samples/getting_started/middleware/exception_handling_with_middleware.py +++ b/python/samples/getting_started/middleware/exception_handling_with_middleware.py @@ -35,13 +35,13 @@ def unstable_data_service( async def exception_handling_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: function_name = context.function.name try: print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}") - await call_next(context) + await call_next() print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.") except TimeoutError as e: print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}") diff --git a/python/samples/getting_started/middleware/function_based_middleware.py b/python/samples/getting_started/middleware/function_based_middleware.py index d9b9062003..38272a4cd1 100644 --- a/python/samples/getting_started/middleware/function_based_middleware.py +++ b/python/samples/getting_started/middleware/function_based_middleware.py @@ -43,7 +43,7 @@ def get_weather( async def security_agent_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Agent middleware that checks for security violations.""" # Check for potential security violations in the query @@ -57,12 +57,12 @@ async def security_agent_middleware( return print("[SecurityAgentMiddleware] Security check passed.") - await call_next(context) + await call_next() async def logging_function_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that logs function calls.""" function_name = context.function.name @@ -70,7 +70,7 @@ async def logging_function_middleware( start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time @@ -105,7 +105,7 @@ async def main() -> None: query = "What's the secret weather password?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response'}\n") + print(f"Agent: {result.text if result and result.text else 'No response'}\n") if __name__ == "__main__": diff --git a/python/samples/getting_started/middleware/middleware_termination.py b/python/samples/getting_started/middleware/middleware_termination.py index 9f48e662c5..ce2db3e376 100644 --- a/python/samples/getting_started/middleware/middleware_termination.py +++ b/python/samples/getting_started/middleware/middleware_termination.py @@ -49,7 +49,7 @@ def __init__(self, blocked_words: list[str]): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Check if the user message contains any blocked words last_message = context.messages[-1] if context.messages else None @@ -75,7 +75,7 @@ async def process( # Terminate to prevent further processing raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() class PostTerminationMiddleware(AgentMiddleware): @@ -88,7 +88,7 @@ def __init__(self, max_responses: int = 1): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})") @@ -101,7 +101,7 @@ async def process( raise MiddlewareTermination # Allow the agent to process normally - await call_next(context) + await call_next() # Increment response count after processing self.response_count += 1 @@ -158,14 +158,14 @@ async def post_termination_middleware() -> None: query = "What about the weather in London?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response (terminated)'}") + print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}") # Third run (should also be terminated) print("\n3. Third run (should also be terminated):") query = "And New York?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response (terminated)'}") + print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}") async def main() -> None: diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 2239136c3c..d05ec1b4f3 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -49,11 +49,11 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: +async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: """Chat middleware that overrides weather results for both streaming and non-streaming cases.""" # Let the original agent execution complete first - await call_next(context) + await call_next() # Check if there's a result to override (agent called weather function) if context.result is not None: @@ -84,9 +84,9 @@ def _update_hook(update: ChatResponseUpdate) -> ChatResponseUpdate: context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)]) -async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: +async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: """Chat middleware that simulates result validation for both streaming and non-streaming cases.""" - await call_next(context) + await call_next() validation_note = "Validation: weather data verified." @@ -104,9 +104,9 @@ def _append_validation_note(response: ChatResponse) -> ChatResponse: context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note)) -async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: +async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: """Agent middleware that validates chat middleware effects and cleans the result.""" - await call_next(context) + await call_next() if context.result is None: return diff --git a/python/samples/getting_started/middleware/runtime_context_delegation.py b/python/samples/getting_started/middleware/runtime_context_delegation.py index 700b1da6f5..d839960da7 100644 --- a/python/samples/getting_started/middleware/runtime_context_delegation.py +++ b/python/samples/getting_started/middleware/runtime_context_delegation.py @@ -54,7 +54,7 @@ def __init__(self) -> None: async def inject_context_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that extracts runtime context from kwargs and stores in container. @@ -74,7 +74,7 @@ async def inject_context_middleware( print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}") # Continue to tool execution - await call_next(context) + await call_next() # Create a container instance that will be shared via closure @@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None: @function_middleware async def email_kwargs_tracker( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: email_agent_kwargs.update(context.kwargs) print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}") - await call_next(context) + await call_next() @function_middleware async def sms_kwargs_tracker( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: sms_agent_kwargs.update(context.kwargs) print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}") - await call_next(context) + await call_next() client = OpenAIChatClient(model_id="gpt-4o-mini") @@ -359,7 +359,7 @@ def __init__(self) -> None: self.validated_tokens: list[str] = [] async def validate_and_track( - self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: """Validate API token and track usage.""" api_token = context.kwargs.get("api_token") @@ -375,7 +375,7 @@ async def validate_and_track( else: print("[AuthMiddleware] No API token provided") - await call_next(context) + await call_next() @tool(approval_mode="never_require") diff --git a/python/samples/getting_started/middleware/shared_state_middleware.py b/python/samples/getting_started/middleware/shared_state_middleware.py index a377d7dfd3..a3aae59ccd 100644 --- a/python/samples/getting_started/middleware/shared_state_middleware.py +++ b/python/samples/getting_started/middleware/shared_state_middleware.py @@ -57,7 +57,7 @@ def __init__(self) -> None: async def call_counter_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """First middleware: increments call count in shared state.""" # Increment the shared call count @@ -66,18 +66,18 @@ async def call_counter_middleware( print(f"[CallCounter] This is function call #{self.call_count}") # Call the next middleware/function - await call_next(context) + await call_next() async def result_enhancer_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Second middleware: uses shared call count to enhance function results.""" print(f"[ResultEnhancer] Current total calls so far: {self.call_count}") # Call the next middleware/function - await call_next(context) + await call_next() # After function execution, enhance the result using shared state if context.result: diff --git a/python/samples/getting_started/middleware/thread_behavior_middleware.py b/python/samples/getting_started/middleware/thread_behavior_middleware.py index e3306eef7b..680fd01d50 100644 --- a/python/samples/getting_started/middleware/thread_behavior_middleware.py +++ b/python/samples/getting_started/middleware/thread_behavior_middleware.py @@ -46,7 +46,7 @@ def get_weather( async def thread_tracking_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that tracks and logs thread behavior across runs.""" thread_messages = [] @@ -57,7 +57,7 @@ async def thread_tracking_middleware( print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}") # Call call_next to execute the agent - await call_next(context) + await call_next() # Check thread state after agent execution updated_thread_messages = []