Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/packages/core/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ from agent_framework import Agent, AgentMiddleware, AgentContext
class LoggingMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next) -> None:
print(f"Input: {context.messages}")
await call_next(context)
await call_next()
print(f"Output: {context.result}")

agent = Agent(..., middleware=[LoggingMiddleware()])
Expand Down
84 changes: 41 additions & 43 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def process(self, context: AgentContext, call_next):
context.metadata["start_time"] = time.time()

# Continue execution
await call_next(context)
await call_next()

# Access result after execution
print(f"Result: {context.result}")
Expand Down Expand Up @@ -229,7 +229,7 @@ async def process(self, context: FunctionInvocationContext, call_next):
raise MiddlewareTermination("Validation failed")

# Continue execution
await call_next(context)
await call_next()
"""

def __init__(
Expand Down Expand Up @@ -293,7 +293,7 @@ async def process(self, context: ChatContext, call_next):
context.metadata["input_tokens"] = self.count_tokens(context.messages)

# Continue execution
await call_next(context)
await call_next()

# Access result and count output tokens
if context.result:
Expand Down Expand Up @@ -365,7 +365,7 @@ def __init__(self, max_retries: int = 3):

async def process(self, context: AgentContext, call_next):
for attempt in range(self.max_retries):
await call_next(context)
await call_next()
if context.result and not context.result.is_error:
break
print(f"Retry {attempt + 1}/{self.max_retries}")
Expand All @@ -379,7 +379,7 @@ async def process(self, context: AgentContext, call_next):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process an agent invocation.

Expand Down Expand Up @@ -431,7 +431,7 @@ async def process(self, context: FunctionInvocationContext, call_next):
raise MiddlewareTermination()

# Execute function
await call_next(context)
await call_next()

# Cache result
if context.result:
Expand All @@ -446,7 +446,7 @@ async def process(self, context: FunctionInvocationContext, call_next):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process a function invocation.

Expand Down Expand Up @@ -493,7 +493,7 @@ async def process(self, context: ChatContext, call_next):
context.messages.insert(0, Message(role="system", text=self.system_prompt))

# Continue execution
await call_next(context)
await call_next()


# Use with an agent
Expand All @@ -508,7 +508,7 @@ async def process(self, context: ChatContext, call_next):
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process a chat client request.

Expand All @@ -531,15 +531,13 @@ async def process(


# Pure function type definitions for convenience
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable

FunctionMiddlewareCallable = Callable[
[FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
]
FunctionMiddlewareCallable = Callable[[FunctionInvocationContext, Callable[[], Awaitable[None]]], Awaitable[None]]
FunctionMiddlewareTypes: TypeAlias = FunctionMiddleware | FunctionMiddlewareCallable

ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]]
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[], Awaitable[None]]], Awaitable[None]]
ChatMiddlewareTypes: TypeAlias = ChatMiddleware | ChatMiddlewareCallable

ChatAndFunctionMiddlewareTypes: TypeAlias = (
Expand Down Expand Up @@ -578,7 +576,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
@agent_middleware
async def logging_middleware(context: AgentContext, call_next):
print(f"Before: {context.agent.name}")
await call_next(context)
await call_next()
print(f"After: {context.result}")


Expand Down Expand Up @@ -611,7 +609,7 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC
@function_middleware
async def logging_middleware(context: FunctionInvocationContext, call_next):
print(f"Calling: {context.function.name}")
await call_next(context)
await call_next()
print(f"Result: {context.result}")


Expand Down Expand Up @@ -644,7 +642,7 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable:
@chat_middleware
async def logging_middleware(context: ChatContext, call_next):
print(f"Messages: {len(context.messages)}")
await call_next(context)
await call_next()
print(f"Response: {context.result}")


Expand All @@ -666,10 +664,10 @@ class MiddlewareWrapper(Generic[ContextT]):
ContextT: The type of context object this middleware operates on.
"""

def __init__(self, func: Callable[[ContextT, Callable[[ContextT], Awaitable[None]]], Awaitable[None]]) -> None:
def __init__(self, func: Callable[[ContextT, Callable[[], Awaitable[None]]], Awaitable[None]]) -> None:
self.func = func

async def process(self, context: ContextT, call_next: Callable[[ContextT], Awaitable[None]]) -> None:
async def process(self, context: ContextT, call_next: Callable[[], Awaitable[None]]) -> None:
await self.func(context, call_next)


Expand Down Expand Up @@ -772,25 +770,25 @@ async def execute(
context.result = await context.result
return context.result

def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):

async def final_wrapper(c: AgentContext) -> None:
c.result = final_handler(c) # type: ignore[assignment]
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context) # type: ignore[assignment]
if inspect.isawaitable(context.result):
context.result = await context.result

return final_wrapper

async def current_handler(c: AgentContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))

return current_handler

first_handler = create_next_handler(0)
with contextlib.suppress(MiddlewareTermination):
await first_handler(context)
await first_handler()

if context.result and isinstance(context.result, ResponseStream):
for hook in context.stream_transform_hooks:
Expand Down Expand Up @@ -847,25 +845,25 @@ async def execute(
if not self._middleware:
return await final_handler(context)

def create_next_handler(index: int) -> Callable[[FunctionInvocationContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):

async def final_wrapper(c: FunctionInvocationContext) -> None:
c.result = final_handler(c)
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context)
if inspect.isawaitable(context.result):
context.result = await context.result

return final_wrapper

async def current_handler(c: FunctionInvocationContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))

return current_handler

first_handler = create_next_handler(0)
# Don't suppress MiddlewareTermination - let it propagate to signal loop termination
await first_handler(context)
await first_handler()

return context.result

Expand Down Expand Up @@ -922,25 +920,25 @@ async def execute(
raise ValueError("Streaming agent middleware requires a ResponseStream result.")
return context.result

def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):

async def final_wrapper(c: ChatContext) -> None:
c.result = final_handler(c) # type: ignore[assignment]
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context) # type: ignore[assignment]
if inspect.isawaitable(context.result):
context.result = await context.result

return final_wrapper

async def current_handler(c: ChatContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))

return current_handler

first_handler = create_next_handler(0)
with contextlib.suppress(MiddlewareTermination):
await first_handler(context)
await first_handler()

if context.result and isinstance(context.result, ResponseStream):
for hook in context.stream_transform_hooks:
Expand Down
42 changes: 14 additions & 28 deletions python/packages/core/tests/core/test_as_tool_kwargs_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) ->
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs passed to the sub-agent
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock response
client.responses = [
Expand Down Expand Up @@ -62,11 +60,9 @@ async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: Moc
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock response
client.responses = [
Expand Down Expand Up @@ -99,12 +95,10 @@ async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockCha
captured_kwargs_list: list[dict[str, Any]] = []

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs at each level
captured_kwargs_list.append(dict(context.kwargs))
await call_next(context)
await call_next()

# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
client.responses = [
Expand Down Expand Up @@ -162,11 +156,9 @@ async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClie
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock streaming responses
from agent_framework import ChatResponseUpdate
Expand Down Expand Up @@ -224,11 +216,9 @@ async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) ->
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock response
client.responses = [
Expand Down Expand Up @@ -266,16 +256,14 @@ async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClie
call_count = 0

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
first_call_kwargs.update(context.kwargs)
elif call_count == 2:
second_call_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock responses for both calls
client.responses = [
Expand Down Expand Up @@ -318,11 +306,9 @@ async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, clie
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()

# Setup mock response
client.responses = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2298,9 +2298,7 @@ def sometimes_fails(arg1: str) -> str:
class TerminateLoopMiddleware(FunctionMiddleware):
"""Middleware that raises MiddlewareTermination to exit the function calling loop."""

async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
raise MiddlewareTermination
Expand Down Expand Up @@ -2355,14 +2353,12 @@ def ai_func(arg1: str) -> str:
class SelectiveTerminateMiddleware(FunctionMiddleware):
"""Only terminates for terminating_function."""

async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
if context.function.name == "terminating_function":
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
raise MiddlewareTermination
await next_handler(context)
await next_handler()


async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: SupportsChatGetResponse):
Expand Down
Loading
Loading