diff --git a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/CustomAgentExecutors.csproj b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/CustomAgentExecutors.csproj index d0c0656ade..2ab222887c 100644 --- a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/CustomAgentExecutors.csproj +++ b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/CustomAgentExecutors.csproj @@ -16,6 +16,9 @@ + diff --git a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs index 1017111082..242c02e7cd 100644 --- a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs @@ -109,7 +109,7 @@ internal sealed class SloganGeneratedEvent(SloganResult sloganResult) : Workflow /// 1. HandleAsync(string message): Handles the initial task to create a slogan. /// 2. HandleAsync(Feedback message): Handles feedback to improve the slogan. /// -internal sealed class SloganWriterExecutor : Executor +internal sealed partial class SloganWriterExecutor : Executor { private readonly AIAgent _agent; private AgentSession? _session; @@ -133,10 +133,7 @@ public SloganWriterExecutor(string id, IChatClient chatClient) : base(id) this._agent = new ChatClientAgent(chatClient, agentOptions); } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.HandleAsync) - .AddHandler(this.HandleAsync); - + [MessageHandler] public async ValueTask HandleAsync(string message, IWorkflowContext context, CancellationToken cancellationToken = default) { this._session ??= await this._agent.CreateSessionAsync(cancellationToken); @@ -149,6 +146,7 @@ public async ValueTask HandleAsync(string message, IWorkflowContex return sloganResult; } + [MessageHandler] public async ValueTask HandleAsync(FeedbackResult message, IWorkflowContext context, CancellationToken cancellationToken = default) { var feedbackMessage = $""" diff --git a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowAsAnAgentObservability.csproj b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowAsAnAgentObservability.csproj index 400142fc4b..6a2d02be9b 100644 --- a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowAsAnAgentObservability.csproj +++ b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowAsAnAgentObservability.csproj @@ -23,6 +23,9 @@ + diff --git a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowHelper.cs b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowHelper.cs index 8069a3e88e..04eb68a325 100644 --- a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowHelper.cs +++ b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/WorkflowHelper.cs @@ -6,7 +6,7 @@ namespace WorkflowAsAnAgentObservabilitySample; -internal static class WorkflowHelper +internal static partial class WorkflowHelper { /// /// Creates a workflow that uses two language agents to process input concurrently. @@ -50,21 +50,16 @@ private static AIAgent GetLanguageAgent(string targetLanguage, IChatClient chatC /// /// Executor that starts the concurrent processing by sending messages to the agents. /// - private sealed class ConcurrentStartExecutor() : Executor("ConcurrentStartExecutor") + private sealed partial class ConcurrentStartExecutor() : Executor("ConcurrentStartExecutor") { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder - .AddHandler>(this.RouteMessages) - .AddHandler(this.RouteTurnTokenAsync); - } - - private ValueTask RouteMessages(List messages, IWorkflowContext context, CancellationToken cancellationToken) + [MessageHandler] + internal ValueTask RouteMessages(List messages, IWorkflowContext context, CancellationToken cancellationToken) { return context.SendMessageAsync(messages, cancellationToken: cancellationToken); } - private ValueTask RouteTurnTokenAsync(TurnToken token, IWorkflowContext context, CancellationToken cancellationToken) + [MessageHandler] + internal ValueTask RouteTurnTokenAsync(TurnToken token, IWorkflowContext context, CancellationToken cancellationToken) { return context.SendMessageAsync(token, cancellationToken: cancellationToken); } @@ -73,7 +68,8 @@ private ValueTask RouteTurnTokenAsync(TurnToken token, IWorkflowContext context, /// /// Executor that aggregates the results from the concurrent agents. /// - private sealed class ConcurrentAggregationExecutor() : Executor>("ConcurrentAggregationExecutor") + [YieldsOutput(typeof(List))] + private sealed partial class ConcurrentAggregationExecutor() : Executor>("ConcurrentAggregationExecutor") { private readonly List _messages = []; diff --git a/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/08_WriterCriticWorkflow.csproj b/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/08_WriterCriticWorkflow.csproj index d7804cef4e..b9139c05ba 100644 --- a/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/08_WriterCriticWorkflow.csproj +++ b/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/08_WriterCriticWorkflow.csproj @@ -1,4 +1,4 @@ - + Exe @@ -11,6 +11,10 @@ + + diff --git a/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/Program.cs b/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/Program.cs index 38bb80dddc..7df3b0c9f5 100644 --- a/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/Program.cs @@ -196,7 +196,7 @@ internal sealed class CriticDecision /// Executor that creates or revises content based on user requests or critic feedback. /// This executor demonstrates multiple message handlers for different input types. /// -internal sealed class WriterExecutor : Executor +internal sealed partial class WriterExecutor : Executor { private readonly AIAgent _agent; @@ -213,15 +213,11 @@ Maintain the same topic and length requirements. ); } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder - .AddHandler(this.HandleInitialRequestAsync) - .AddHandler(this.HandleRevisionRequestAsync); - /// /// Handles the initial writing request from the user. /// - private async ValueTask HandleInitialRequestAsync( + [MessageHandler] + public async ValueTask HandleInitialRequestAsync( string message, IWorkflowContext context, CancellationToken cancellationToken = default) @@ -232,7 +228,8 @@ private async ValueTask HandleInitialRequestAsync( /// /// Handles revision requests from the critic with feedback. /// - private async ValueTask HandleRevisionRequestAsync( + [MessageHandler] + public async ValueTask HandleRevisionRequestAsync( CriticDecision decision, IWorkflowContext context, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeActionExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeActionExecutor.cs index 803d060815..2a25ee3476 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeActionExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeActionExecutor.cs @@ -60,6 +60,7 @@ public ValueTask ResetAsync() } /// + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(ActionExecutorResult message, IWorkflowContext context, CancellationToken cancellationToken = default) { if (this.Model.Disabled) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeWorkflowExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeWorkflowExecutor.cs index 7436e64446..9c6f7f3e6f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeWorkflowExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DeclarativeWorkflowExecutor.cs @@ -4,6 +4,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Declarative.Extensions; +using Microsoft.Agents.AI.Workflows.Declarative.Kit; using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Extensions.AI; @@ -25,6 +26,7 @@ public ValueTask ResetAsync() return default; } + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken = default) { // No state to restore if we're starting from the beginning. diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DelegateActionExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DelegateActionExecutor.cs index 1d9a2c7552..aa0c5759e3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DelegateActionExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Interpreter/DelegateActionExecutor.cs @@ -40,6 +40,7 @@ public ValueTask ResetAsync() return default; } + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(TMessage message, IWorkflowContext context, CancellationToken cancellationToken = default) { if (this._action is not null) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/ActionExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/ActionExecutor.cs index db348f29dc..cf636effaf 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/ActionExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/ActionExecutor.cs @@ -73,6 +73,7 @@ public ValueTask ResetAsync() } /// + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(TMessage message, IWorkflowContext context, CancellationToken cancellationToken) { object? result = await this.ExecuteAsync(new DeclarativeWorkflowContext(context, this._session.State), message, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/RootExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/RootExecutor.cs index 641ecc78a0..ff643510df 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/RootExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/RootExecutor.cs @@ -54,6 +54,7 @@ public ValueTask ResetAsync() } /// + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken) { DeclarativeWorkflowContext declarativeContext = new(context, this._state); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Analysis/SemanticAnalyzer.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Analysis/SemanticAnalyzer.cs index 62c9817252..b62377a971 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Analysis/SemanticAnalyzer.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Analysis/SemanticAnalyzer.cs @@ -68,7 +68,7 @@ public static MethodAnalysisResult AnalyzeHandlerMethod( string classKey = GetClassKey(classSymbol); bool isPartialClass = IsPartialClass(classSymbol, cancellationToken); bool derivesFromExecutor = DerivesFromExecutor(classSymbol); - bool hasManualConfigureRoutes = HasConfigureRoutesDefined(classSymbol); + bool configureProtocol = HasConfigureProtocolDefined(classSymbol); // Extract class metadata string? @namespace = classSymbol.ContainingNamespace?.IsGlobalNamespace == true @@ -78,7 +78,7 @@ public static MethodAnalysisResult AnalyzeHandlerMethod( string? genericParameters = GetGenericParameters(classSymbol); bool isNested = classSymbol.ContainingType != null; string containingTypeChain = GetContainingTypeChain(classSymbol); - bool baseHasConfigureRoutes = BaseHasConfigureRoutes(classSymbol); + bool baseHasConfigureProtocol = BaseHasConfigureProtocol(classSymbol); ImmutableEquatableArray classSendTypes = GetClassLevelTypes(classSymbol, SendsMessageAttributeName); ImmutableEquatableArray classYieldTypes = GetClassLevelTypes(classSymbol, YieldsOutputAttributeName); @@ -96,8 +96,8 @@ public static MethodAnalysisResult AnalyzeHandlerMethod( return new MethodAnalysisResult( classKey, @namespace, className, genericParameters, isNested, containingTypeChain, - baseHasConfigureRoutes, classSendTypes, classYieldTypes, - isPartialClass, derivesFromExecutor, hasManualConfigureRoutes, + baseHasConfigureProtocol, classSendTypes, classYieldTypes, + isPartialClass, derivesFromExecutor, configureProtocol, classLocation, handler, Diagnostics: new ImmutableEquatableArray(methodDiagnostics.ToImmutable())); @@ -152,7 +152,7 @@ public static AnalysisResult CombineHandlerMethodResults(IEnumerable(handlers), first.ClassSendTypes, first.ClassYieldTypes); @@ -211,7 +211,7 @@ public static ImmutableArray AnalyzeClassProtocolAttribute( string classKey = GetClassKey(classSymbol); bool isPartialClass = IsPartialClass(classSymbol, cancellationToken); bool derivesFromExecutor = DerivesFromExecutor(classSymbol); - bool hasManualConfigureRoutes = HasConfigureRoutesDefined(classSymbol); + bool hasManualConfigureProtocol = HasConfigureProtocolDefined(classSymbol); string? @namespace = classSymbol.ContainingNamespace?.IsGlobalNamespace == true ? null @@ -240,7 +240,7 @@ public static ImmutableArray AnalyzeClassProtocolAttribute( containingTypeChain, isPartialClass, derivesFromExecutor, - hasManualConfigureRoutes, + hasManualConfigureProtocol, classLocation, typeName, attributeKind)); @@ -251,12 +251,16 @@ public static ImmutableArray AnalyzeClassProtocolAttribute( } /// - /// Combines ClassProtocolInfo results into an AnalysisResult for classes that only have protocol attributes - /// (no [MessageHandler] methods). This generates only ConfigureSentTypes/ConfigureYieldTypes overrides. + /// Combines ClassProtocolInfo results into an AnalysisResult for classes that only have IO attributes + /// (no [MessageHandler] methods). This generates only .SendsMessage/.YieldsMessage calls in the protocol + /// configuration. /// + /// + /// This is likely to be seen combined with the basic one-method Executor%lt;TIn> or Executor<TIn, TOut> + /// /// The protocol info entries for the class. /// The combined analysis result. - public static AnalysisResult CombineProtocolOnlyResults(IEnumerable protocolInfos) + public static AnalysisResult CombineOutputOnlyResults(IEnumerable protocolInfos) { List protocols = protocolInfos.ToList(); if (protocols.Count == 0) @@ -317,7 +321,7 @@ public static AnalysisResult CombineProtocolOnlyResults(IEnumerable.Empty, ClassSendTypes: new ImmutableEquatableArray(sendTypes.ToImmutable()), ClassYieldTypes: new ImmutableEquatableArray(yieldTypes.ToImmutable())); @@ -394,12 +398,12 @@ private static bool DerivesFromExecutor(INamedTypeSymbol classSymbol) } /// - /// Checks if this class directly defines ConfigureRoutes (not inherited). + /// Checks if this class directly defines ConfigureProtocol (not inherited). /// If so, we skip generation to avoid conflicting with user's manual implementation. /// - private static bool HasConfigureRoutesDefined(INamedTypeSymbol classSymbol) + private static bool HasConfigureProtocolDefined(INamedTypeSymbol classSymbol) { - foreach (var member in classSymbol.GetMembers("ConfigureRoutes")) + foreach (var member in classSymbol.GetMembers("ConfigureProtocol")) { if (member is IMethodSymbol method && !method.IsAbstract && SymbolEqualityComparer.Default.Equals(method.ContainingType, classSymbol)) @@ -412,22 +416,22 @@ private static bool HasConfigureRoutesDefined(INamedTypeSymbol classSymbol) } /// - /// Checks if any base class (between this class and Executor) defines ConfigureRoutes. - /// If so, generated code should call base.ConfigureRoutes() to preserve inherited handlers. + /// Checks if any base class (between this class and Executor) defines ConfigureProtocol. + /// If so, generated code should call base.ConfigureProtocol() to preserve inherited handlers. /// - private static bool BaseHasConfigureRoutes(INamedTypeSymbol classSymbol) + private static bool BaseHasConfigureProtocol(INamedTypeSymbol classSymbol) { INamedTypeSymbol? baseType = classSymbol.BaseType; while (baseType != null) { string fullName = baseType.OriginalDefinition.ToDisplayString(); - // Stop at Executor - its ConfigureRoutes is abstract/empty + // Stop at Executor - its ConfigureProtocol is abstract/empty if (fullName == ExecutorTypeName) { return false; } - foreach (var member in baseType.GetMembers("ConfigureRoutes")) + foreach (var member in baseType.GetMembers("ConfigureProtocol")) { if (member is IMethodSymbol method && !method.IsAbstract) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Diagnostics/DiagnosticDescriptors.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Diagnostics/DiagnosticDescriptors.cs index 4afc7a1697..2b2bd8fd04 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Diagnostics/DiagnosticDescriptors.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Diagnostics/DiagnosticDescriptors.cs @@ -86,10 +86,10 @@ private static DiagnosticDescriptor Register(DiagnosticDescriptor descriptor) /// /// MAFGENWF006: ConfigureRoutes already defined. /// - public static readonly DiagnosticDescriptor ConfigureRoutesAlreadyDefined = Register(new( + public static readonly DiagnosticDescriptor ConfigureProtocolAlreadyDefined = Register(new( id: "MAFGENWF006", - title: "ConfigureRoutes already defined", - messageFormat: "Class '{0}' already defines ConfigureRoutes; [MessageHandler] methods will be ignored", + title: "ConfigureProtocol already defined", + messageFormat: "Class '{0}' already defines ConfigureProtocol; [MessageHandler] methods will be ignored", category: Category, defaultSeverity: DiagnosticSeverity.Info, isEnabledByDefault: true)); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/ExecutorRouteGenerator.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/ExecutorRouteGenerator.cs index 181e799ae2..e323804e59 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/ExecutorRouteGenerator.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/ExecutorRouteGenerator.cs @@ -120,7 +120,7 @@ private static IEnumerable CombineAllResults( { if (!processedClasses.Contains(kvp.Key)) { - yield return SemanticAnalyzer.CombineProtocolOnlyResults(kvp.Value); + yield return SemanticAnalyzer.CombineOutputOnlyResults(kvp.Value); } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Generation/SourceBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Generation/SourceBuilder.cs index 0779a56045..9a74c88447 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Generation/SourceBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Generation/SourceBuilder.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Text; using Microsoft.Agents.AI.Workflows.Generators.Models; @@ -16,6 +17,8 @@ namespace Microsoft.Agents.AI.Workflows.Generators.Generation; /// internal static class SourceBuilder { + internal const string IndentUnit = " "; + /// /// Generates the complete source file for an executor's generated partial class. /// @@ -53,7 +56,8 @@ public static string Generate(ExecutorInfo info) { sb.AppendLine($"{indent}partial class {containingType}"); sb.AppendLine($"{indent}{{"); - indent += " "; + + indent += IndentUnit; } } @@ -61,30 +65,49 @@ public static string Generate(ExecutorInfo info) sb.AppendLine($"{indent}partial class {info.ClassName}{info.GenericParameters}"); sb.AppendLine($"{indent}{{"); - string memberIndent = indent + " "; - bool hasContent = false; + string memberIndent = indent + IndentUnit; - // Only generate ConfigureRoutes if there are handlers - if (info.Handlers.Count > 0) + // ConfigureProtocol + sb.AppendLine($"{memberIndent}protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder)"); + sb.AppendLine($"{memberIndent}{{"); + + string bodyIndent = memberIndent + IndentUnit; + + if (info.BaseHasConfigureProtocol) { - GenerateConfigureRoutes(sb, info, memberIndent); - hasContent = true; + sb.Append($"{bodyIndent}return base.ConfigureProtocol(protocolBuilder)"); + bodyIndent += " "; + } + else + { + sb.Append($"{bodyIndent}return protocolBuilder"); } // Only generate protocol overrides if [SendsMessage] or [YieldsOutput] attributes are present. // Without these attributes, we rely on the base class defaults. - if (info.ShouldGenerateProtocolOverrides) + if (info.ShouldGenerateSentMessageRegistrations) { - if (hasContent) - { - sb.AppendLine(); - } + GenerateConfigureSentTypes(sb, info, bodyIndent); + } - GenerateConfigureSentTypes(sb, info, memberIndent); - sb.AppendLine(); - GenerateConfigureYieldTypes(sb, info, memberIndent); + if (info.ShouldGenerateYieldedOutputRegistrations) + { + GenerateConfigureYieldTypes(sb, info, bodyIndent); + } + + // Only generate ConfigureRoutes if there are handlers + if (info.Handlers.Count > 0) + { + GenerateConfigureRoutes(sb, info, bodyIndent); + } + else + { + sb.AppendLine(";"); } + // Close ConfigureProtocol + sb.AppendLine($"{memberIndent}}}"); + // Close class sb.AppendLine($"{indent}}}"); @@ -107,24 +130,19 @@ public static string Generate(ExecutorInfo info) /// private static void GenerateConfigureRoutes(StringBuilder sb, ExecutorInfo info, string indent) { - sb.AppendLine($"{indent}protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)"); - sb.AppendLine($"{indent}{{"); + sb.AppendLine(".ConfigureRoutes(ConfigureRoutes);"); - string bodyIndent = indent + " "; + sb.AppendLine($"{indent}void ConfigureRoutes(RouteBuilder routeBuilder)"); + sb.AppendLine($"{indent}{{"); - // If a base class has its own ConfigureRoutes, chain to it first to preserve inherited handlers. - if (info.BaseHasConfigureRoutes) - { - sb.AppendLine($"{bodyIndent}routeBuilder = base.ConfigureRoutes(routeBuilder);"); - sb.AppendLine(); - } + string bodyIndent = indent + IndentUnit; // Generate handler registrations using fluent AddHandler calls. // RouteBuilder.AddHandler registers a void handler; AddHandler registers one with a return value. if (info.Handlers.Count == 1) { HandlerInfo handler = info.Handlers[0]; - sb.AppendLine($"{bodyIndent}return routeBuilder"); + sb.AppendLine($"{bodyIndent}routeBuilder"); sb.Append($"{bodyIndent} .AddHandler"); AppendHandlerGenericArgs(sb, handler); sb.AppendLine($"(this.{handler.MethodName});"); @@ -132,7 +150,7 @@ private static void GenerateConfigureRoutes(StringBuilder sb, ExecutorInfo info, else { // Multiple handlers: chain fluent calls, semicolon only on the last one. - sb.AppendLine($"{bodyIndent}return routeBuilder"); + sb.AppendLine($"{bodyIndent}routeBuilder"); for (int i = 0; i < info.Handlers.Count; i++) { @@ -178,28 +196,24 @@ private static void AppendHandlerGenericArgs(StringBuilder sb, HandlerInfo handl /// private static void GenerateConfigureSentTypes(StringBuilder sb, ExecutorInfo info, string indent) { - sb.AppendLine($"{indent}protected override ISet ConfigureSentTypes()"); - sb.AppendLine($"{indent}{{"); - - string bodyIndent = indent + " "; - - sb.AppendLine($"{bodyIndent}var types = base.ConfigureSentTypes();"); + // Track types to avoid emitting duplicate Add calls (the set handles runtime dedup, + // but cleaner generated code is easier to read). + var addedTypes = new HashSet(); - foreach (var type in info.ClassSendTypes) + foreach (var type in info.ClassSendTypes.Where(type => addedTypes.Add(type))) { - sb.AppendLine($"{bodyIndent}types.Add(typeof({type}));"); + sb.AppendLine($".SendsMessage<{type}>()"); + sb.Append(indent); } foreach (var handler in info.Handlers) { - foreach (var type in handler.SendTypes) + foreach (var type in handler.SendTypes.Where(type => addedTypes.Add(type))) { - sb.AppendLine($"{bodyIndent}types.Add(typeof({type}));"); + sb.AppendLine($".SendsMessage<{type}>()"); + sb.Append(indent); } } - - sb.AppendLine($"{bodyIndent}return types;"); - sb.AppendLine($"{indent}}}"); } /// @@ -211,43 +225,23 @@ private static void GenerateConfigureSentTypes(StringBuilder sb, ExecutorInfo in /// private static void GenerateConfigureYieldTypes(StringBuilder sb, ExecutorInfo info, string indent) { - sb.AppendLine($"{indent}protected override ISet ConfigureYieldTypes()"); - sb.AppendLine($"{indent}{{"); - - string bodyIndent = indent + " "; - - sb.AppendLine($"{bodyIndent}var types = base.ConfigureYieldTypes();"); - // Track types to avoid emitting duplicate Add calls (the set handles runtime dedup, // but cleaner generated code is easier to read). var addedTypes = new HashSet(); - foreach (var type in info.ClassYieldTypes) + foreach (var type in info.ClassYieldTypes.Where(type => addedTypes.Add(type))) { - if (addedTypes.Add(type)) - { - sb.AppendLine($"{bodyIndent}types.Add(typeof({type}));"); - } + sb.AppendLine($".YieldsOutput<{type}>()"); + sb.Append(indent); } foreach (var handler in info.Handlers) { - foreach (var type in handler.YieldTypes) - { - if (addedTypes.Add(type)) - { - sb.AppendLine($"{bodyIndent}types.Add(typeof({type}));"); - } - } - - // Handler return types (ValueTask) are implicitly yielded. - if (handler.HasOutput && handler.OutputTypeName != null && addedTypes.Add(handler.OutputTypeName)) + foreach (var type in handler.YieldTypes.Where(type => addedTypes.Add(type))) { - sb.AppendLine($"{bodyIndent}types.Add(typeof({handler.OutputTypeName}));"); + sb.AppendLine($".YieldsOutput<{type}>()"); + sb.Append(indent); } } - - sb.AppendLine($"{bodyIndent}return types;"); - sb.AppendLine($"{indent}}}"); } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/ExecutorInfo.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/ExecutorInfo.cs index 507927d875..3da71d2802 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/ExecutorInfo.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/ExecutorInfo.cs @@ -11,7 +11,7 @@ namespace Microsoft.Agents.AI.Workflows.Generators.Models; /// The generic type parameters of the class (e.g., "<T, U>"), or null if not generic. /// Whether the class is nested inside another class. /// The chain of containing types for nested classes (e.g., "OuterClass.InnerClass"). Empty string if not nested. -/// Whether the base class has a ConfigureRoutes method that should be called. +/// Whether the base class has a ConfigureRoutes method that should be called. /// The list of handler methods to register. /// The types declared via class-level [SendsMessage] attributes. /// The types declared via class-level [YieldsOutput] attributes. @@ -21,19 +21,20 @@ internal sealed record ExecutorInfo( string? GenericParameters, bool IsNested, string ContainingTypeChain, - bool BaseHasConfigureRoutes, + bool BaseHasConfigureProtocol, ImmutableEquatableArray Handlers, ImmutableEquatableArray ClassSendTypes, ImmutableEquatableArray ClassYieldTypes) { /// - /// Gets whether any protocol type overrides should be generated. + /// Gets whether any "Sent" message type registrations should be generated. /// - public bool ShouldGenerateProtocolOverrides => - !this.ClassSendTypes.IsEmpty || - !this.ClassYieldTypes.IsEmpty || - this.HasHandlerWithSendTypes || - this.HasHandlerWithYieldTypes; + public bool ShouldGenerateSentMessageRegistrations => !this.ClassSendTypes.IsEmpty || this.HasHandlerWithSendTypes; + + /// + /// Gets whether any "Yielded" output type registrations should be generated. + /// + public bool ShouldGenerateYieldedOutputRegistrations => !this.ClassYieldTypes.IsEmpty || this.HasHandlerWithYieldTypes; /// /// Gets whether any handler has explicit Send types. diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/MethodAnalysisResult.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/MethodAnalysisResult.cs index f9493c5d93..fb3fafc6c2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/MethodAnalysisResult.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Generators/Models/MethodAnalysisResult.cs @@ -22,7 +22,7 @@ internal sealed record MethodAnalysisResult( string? GenericParameters, bool IsNested, string ContainingTypeChain, - bool BaseHasConfigureRoutes, + bool BaseHasConfigureProtocol, ImmutableEquatableArray ClassSendTypes, ImmutableEquatableArray ClassYieldTypes, diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/SendsMessageAttribute.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/SendsMessageAttribute.cs index 3b5620fc37..93829be21e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/SendsMessageAttribute.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/SendsMessageAttribute.cs @@ -29,7 +29,7 @@ namespace Microsoft.Agents.AI.Workflows; /// } /// /// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public sealed class SendsMessageAttribute : Attribute { /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/YieldsOutputAttribute.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/YieldsOutputAttribute.cs index 5aad434b1d..11093645b2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/YieldsOutputAttribute.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Attributes/YieldsOutputAttribute.cs @@ -29,7 +29,7 @@ namespace Microsoft.Agents.AI.Workflows; /// } /// /// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public sealed class YieldsOutputAttribute : Attribute { /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatForwardingExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatForwardingExecutor.cs index 5bb2f5e237..93925dec32 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatForwardingExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatForwardingExecutor.cs @@ -34,19 +34,29 @@ public sealed class ChatForwardingExecutor(string id, ChatForwardingExecutorOpti private readonly ChatRole? _stringMessageChatRole = options?.StringMessageChatRole; /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - if (this._stringMessageChatRole.HasValue) + return protocolBuilder.ConfigureRoutes(ConfigureRoutes) + .SendsMessage() + .SendsMessage>() + .SendsMessage() + .SendsMessage(); + + void ConfigureRoutes(RouteBuilder routeBuilder) { - routeBuilder = routeBuilder.AddHandler( - (message, context) => context.SendMessageAsync(new ChatMessage(ChatRole.User, message))); - } + if (this._stringMessageChatRole.HasValue) + { + routeBuilder = routeBuilder.AddHandler( + (message, context) => context.SendMessageAsync(new ChatMessage(ChatRole.User, message))); + } - return routeBuilder.AddHandler(ForwardMessageAsync) - .AddHandler>(ForwardMessagesAsync) - .AddHandler(ForwardMessagesAsync) - .AddHandler>(ForwardMessagesAsync) - .AddHandler(ForwardTurnTokenAsync); + routeBuilder.AddHandler(ForwardMessageAsync) + .AddHandler>(ForwardMessagesAsync) + // remove this once we internalize the typecheck logic + .AddHandler(ForwardMessagesAsync) + //.AddHandler>(ForwardMessagesAsync) + .AddHandler(ForwardTurnTokenAsync); + } } private static ValueTask ForwardMessageAsync(ChatMessage message, IWorkflowContext context, CancellationToken cancellationToken) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocol.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocol.cs index 5a328bc8c8..fc9d59ad25 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocol.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocol.cs @@ -26,7 +26,7 @@ public static class ChatProtocolExtensions /// langword="false"/>. public static bool IsChatProtocol(this ProtocolDescriptor descriptor, bool allowCatchAll = false) { - bool foundListChatMessageInput = false; + bool foundIEnumerableChatMessageInput = false; bool foundTurnTokenInput = false; if (allowCatchAll && descriptor.AcceptsAll) @@ -40,9 +40,9 @@ public static bool IsChatProtocol(this ProtocolDescriptor descriptor, bool allow // output type. foreach (Type inputType in descriptor.Accepts) { - if (inputType == typeof(List)) + if (inputType == typeof(IEnumerable)) { - foundListChatMessageInput = true; + foundIEnumerableChatMessageInput = true; } else if (inputType == typeof(TurnToken)) { @@ -50,7 +50,7 @@ public static bool IsChatProtocol(this ProtocolDescriptor descriptor, bool allow } } - return foundListChatMessageInput && foundTurnTokenInput; + return foundIEnumerableChatMessageInput && foundTurnTokenInput; } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs index 8a3a8dd564..18541464c1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs @@ -67,19 +67,26 @@ protected ChatProtocolExecutor(string id, ChatProtocolExecutorOptions? options = protected bool AutoSendTurnToken => this._options.AutoSendTurnToken; /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - if (this.SupportsStringMessage) + return protocolBuilder.ConfigureRoutes(ConfigureRoutes) + .SendsMessage>() + .SendsMessage(); + + void ConfigureRoutes(RouteBuilder routeBuilder) { - routeBuilder = routeBuilder.AddHandler( - (message, context) => this.AddMessageAsync(new(this.StringMessageChatRole.Value, message), context)); - } + if (this.SupportsStringMessage) + { + routeBuilder = routeBuilder.AddHandler( + (message, context) => this.AddMessageAsync(new(this.StringMessageChatRole.Value, message), context)); + } - return routeBuilder.AddHandler(this.AddMessageAsync) - .AddHandler>(this.AddMessagesAsync) - .AddHandler(this.AddMessagesAsync) - .AddHandler>(this.AddMessagesAsync) - .AddHandler(this.TakeTurnAsync); + routeBuilder.AddHandler(this.AddMessageAsync) + .AddHandler>(this.AddMessagesAsync) + .AddHandler(this.AddMessagesAsync) + //.AddHandler>(this.AddMessagesAsync) + .AddHandler(this.TakeTurnAsync); + } } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/PortableMessageEnvelope.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/PortableMessageEnvelope.cs index 96fb7c88a2..dcf8680009 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/PortableMessageEnvelope.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/PortableMessageEnvelope.cs @@ -25,6 +25,7 @@ public PortableMessageEnvelope(MessageEnvelope envelope) { this.MessageType = envelope.MessageType; this.Message = new PortableValue(envelope.Message); + this.Source = envelope.Source; this.TargetId = envelope.TargetId; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/DirectEdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/DirectEdgeRunner.cs index db643ab441..568c8d4b23 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/DirectEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/DirectEdgeRunner.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Observability; @@ -9,10 +10,7 @@ namespace Microsoft.Agents.AI.Workflows.Execution; internal sealed class DirectEdgeRunner(IRunnerContext runContext, DirectEdgeData edgeData) : EdgeRunner(runContext, edgeData) { - private async ValueTask FindRouterAsync(IStepTracer? tracer) => await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, tracer) - .ConfigureAwait(false); - - protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer) + protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken) { using var activity = this.StartActivity(); activity? @@ -35,8 +33,11 @@ private async ValueTask FindRouterAsync(IStepTracer? tracer) => await return null; } - Executor target = await this.FindRouterAsync(stepTracer).ConfigureAwait(false); - if (target.CanHandle(envelope.MessageType)) + Type? messageType = await this.GetMessageRuntimeTypeAsync(envelope, stepTracer, cancellationToken) + .ConfigureAwait(false); + + Executor target = await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, stepTracer, cancellationToken).ConfigureAwait(false); + if (CanHandle(target, messageType)) { activity?.SetEdgeRunnerDeliveryStatus(EdgeRunnerDeliveryStatus.Delivered); return new DeliveryMapping(envelope, target); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs index 3cc0e6e6a1..8c2162508d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Checkpointing; @@ -65,7 +66,7 @@ public EdgeMap(IRunnerContext runContext, this._stepTracer = stepTracer; } - public ValueTask PrepareDeliveryForEdgeAsync(Edge edge, MessageEnvelope message) + public ValueTask PrepareDeliveryForEdgeAsync(Edge edge, MessageEnvelope message, CancellationToken cancellationToken = default) { EdgeId id = edge.Data.Id; if (!this._edgeRunners.TryGetValue(id, out EdgeRunner? edgeRunner)) @@ -73,25 +74,25 @@ public EdgeMap(IRunnerContext runContext, throw new InvalidOperationException($"Edge {edge} not found in the edge map."); } - return edgeRunner.ChaseEdgeAsync(message, this._stepTracer); + return edgeRunner.ChaseEdgeAsync(message, this._stepTracer, cancellationToken); } public bool TryRegisterPort(IRunnerContext runContext, string executorId, RequestPort port) => this._portEdgeRunners.TryAdd(port.Id, ResponseEdgeRunner.ForPort(runContext, executorId, port)); - public ValueTask PrepareDeliveryForInputAsync(MessageEnvelope message) + public ValueTask PrepareDeliveryForInputAsync(MessageEnvelope message, CancellationToken cancellationToken = default) { - return this._inputRunner.ChaseEdgeAsync(message, this._stepTracer); + return this._inputRunner.ChaseEdgeAsync(message, this._stepTracer, cancellationToken); } - public ValueTask PrepareDeliveryForResponseAsync(ExternalResponse response) + public ValueTask PrepareDeliveryForResponseAsync(ExternalResponse response, CancellationToken cancellationToken = default) { if (!this._portEdgeRunners.TryGetValue(response.PortInfo.PortId, out ResponseEdgeRunner? portRunner)) { throw new InvalidOperationException($"Port {response.PortInfo.PortId} not found in the edge map."); } - return portRunner.ChaseEdgeAsync(new MessageEnvelope(response, ExecutorIdentity.None), this._stepTracer); + return portRunner.ChaseEdgeAsync(new MessageEnvelope(response, ExecutorIdentity.None), this._stepTracer, cancellationToken); } internal async ValueTask> ExportStateAsync() diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeRunner.cs index 309072c32f..481929d643 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeRunner.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Diagnostics; +using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -14,8 +16,7 @@ internal interface IStatefulEdgeRunner internal abstract class EdgeRunner { - // TODO: Can this be sync? - protected internal abstract ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer); + protected internal abstract ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken = default); } internal abstract class EdgeRunner( @@ -24,5 +25,46 @@ internal abstract class EdgeRunner( protected IRunnerContext RunContext { get; } = Throw.IfNull(runContext); protected TEdgeData EdgeData { get; } = Throw.IfNull(edgeData); + protected async ValueTask FindSourceProtocolAsync(string sourceId, IStepTracer? stepTracer, CancellationToken cancellationToken = default) + { + Executor sourceExecutor = await this.RunContext.EnsureExecutorAsync(Throw.IfNull(sourceId), stepTracer, cancellationToken) + .ConfigureAwait(false); + + return sourceExecutor.Protocol; + } + + protected async ValueTask GetMessageRuntimeTypeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken = default) + { + // The only difficulty occurs when we have gone through a checkpoint cycle, because the messages turn into PortableValue objects. + if (envelope.Message is PortableValue portableValue) + { + if (envelope.SourceId == null) + { + return null; + } + + ExecutorProtocol protocol = await this.FindSourceProtocolAsync(envelope.SourceId, stepTracer, cancellationToken).ConfigureAwait(false); + return protocol.SendTypeTranslator.MapTypeId(portableValue.TypeId); + } + + return envelope.Message.GetType(); + } + + protected static bool CanHandle(Executor target, Type? runtimeType) + { + // If we have a runtimeType, this is either a non-serialized object, or we successfully mapped a PortableValue back to its original type. + // In either case, we can check if the target can handle that type. Alternatively, even if we do not have a type, if the target has a catch-all, + // we can still route to it, since it should be able to handle anything. + return runtimeType != null ? target.CanHandle(runtimeType) : target.Router.HasCatchAll; + } + + protected async ValueTask CanHandleAsync(string candidateTargetId, Type? runtimeType, IStepTracer? stepTracer, CancellationToken cancellationToken = default) + { + Executor candidateTarget = await this.RunContext.EnsureExecutorAsync(Throw.IfNull(candidateTargetId), stepTracer, cancellationToken) + .ConfigureAwait(false); + + return CanHandle(candidateTarget, runtimeType); + } + protected Activity? StartActivity() => this.RunContext.TelemetryContext.StartEdgeGroupProcessActivity(); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeRunner.cs index be80ef34de..a53bee0041 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeRunner.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Observability; @@ -15,7 +16,7 @@ internal sealed class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData e { private FanInEdgeState _state = new(edgeData); - protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer) + protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken) { Debug.Assert(!envelope.IsExternal, "FanIn edges should never be chased from external input"); @@ -31,7 +32,7 @@ internal sealed class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData e } // source.Id is guaranteed to be non-null here because source is not None. - IEnumerable? releasedMessages = this._state.ProcessMessage(envelope.SourceId, envelope); + List>? releasedMessages = this._state.ProcessMessage(envelope.SourceId, envelope)?.ToList(); if (releasedMessages is null) { // Not ready to process yet. @@ -41,11 +42,22 @@ internal sealed class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData e try { - // TODO: Filter messages based on accepted input types? - Executor target = await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, stepTracer) + // Right now, for serialization purposes every message through FanInEdge goes through the PortableMessageEnvelope state, meaning + // we lose type information for all of them, potentially. + (ExecutorProtocol, IGrouping)[] + protocolGroupings = await Task.WhenAll(releasedMessages.Select(MapProtocolsAsync)) + .ConfigureAwait(false); + + IEnumerable<(Type? RuntimeType, MessageEnvelope MessageEnvelope)> + typedEnvelopes = protocolGroupings.SelectMany(MapRuntimeTypes); + + Executor target = await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, stepTracer, cancellationToken) .ConfigureAwait(false); + // Materialize the filtered list via ToList() to avoid multiple enumerations - var finalReleasedMessages = releasedMessages.Where(envelope => target.CanHandle(envelope.MessageType)).ToList(); + List finalReleasedMessages = typedEnvelopes.Where(te => CanHandle(target, te.RuntimeType)) + .Select(te => te.MessageEnvelope) + .ToList(); if (finalReleasedMessages.Count == 0) { activity?.SetEdgeRunnerDeliveryStatus(EdgeRunnerDeliveryStatus.DroppedTypeMismatch); @@ -53,6 +65,28 @@ internal sealed class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData e } return new DeliveryMapping(finalReleasedMessages, target); + + async Task<(ExecutorProtocol, IGrouping)> MapProtocolsAsync(IGrouping grouping) + { + ExecutorProtocol protocol = await this.FindSourceProtocolAsync(grouping.Key.Id!, stepTracer, cancellationToken).ConfigureAwait(false); + return (protocol, grouping); + } + + IEnumerable<(Type?, MessageEnvelope)> MapRuntimeTypes((ExecutorProtocol, IGrouping) input) + { + (ExecutorProtocol protocol, IGrouping grouping) = input; + return grouping.Select(envelope => (ResolveEnvelopeType(envelope), envelope)); + + Type? ResolveEnvelopeType(MessageEnvelope messageEnvelope) + { + if (envelope.Message is PortableValue portableValue) + { + return protocol.SendTypeTranslator.MapTypeId(portableValue.TypeId); + } + + return envelope.Message.GetType(); + } + } } catch (Exception) when (activity is not null) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeState.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeState.cs index 9c6a941a11..db8241c13d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeState.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanInEdgeState.cs @@ -32,7 +32,7 @@ public FanInEdgeState(string[] sourceIds, HashSet unseen, List? ProcessMessage(string sourceId, MessageEnvelope envelope) + public IEnumerable>? ProcessMessage(string sourceId, MessageEnvelope envelope) { this.PendingMessages.Add(new(envelope)); this.Unseen.Remove(sourceId); @@ -47,7 +47,8 @@ public FanInEdgeState(string[] sourceIds, HashSet unseen, List portable.ToMessageEnvelope()); + return takenMessages.Select(portable => portable.ToMessageEnvelope()) + .GroupBy(keySelector: messageEnvelope => messageEnvelope.Source); } return null; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanOutEdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanOutEdgeRunner.cs index d1102d6554..3ff3469f1f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanOutEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/FanOutEdgeRunner.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Observability; @@ -11,7 +12,7 @@ namespace Microsoft.Agents.AI.Workflows.Execution; internal sealed class FanOutEdgeRunner(IRunnerContext runContext, FanOutEdgeData edgeData) : EdgeRunner(runContext, edgeData) { - protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer) + protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken) { using var activity = this.StartActivity(); activity? @@ -39,7 +40,10 @@ this.EdgeData.EdgeAssigner is null return null; } - IEnumerable validTargets = result.Where(t => t.CanHandle(envelope.MessageType)); + Type? runtimeType = await this.GetMessageRuntimeTypeAsync(envelope, stepTracer, cancellationToken) + .ConfigureAwait(false); + + IEnumerable validTargets = result.Where(t => CanHandle(t, runtimeType)); if (!validTargets.Any()) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/MessageRouter.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/MessageRouter.cs index 10ce345ad8..628c707576 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/MessageRouter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/MessageRouter.cs @@ -27,8 +27,9 @@ namespace Microsoft.Agents.AI.Workflows.Execution; internal sealed class MessageRouter { + private readonly HashSet _interfaceHandlers = new(); private readonly Dictionary _typedHandlers; - private readonly Dictionary _runtimeTypeMap; + private readonly Dictionary _runtimeTypeMap = new(); private readonly CatchAllF? _catchAllFunc; @@ -37,7 +38,17 @@ internal MessageRouter(Dictionary handlers, HashSet Throw.IfNull(handlers); this._typedHandlers = handlers; - this._runtimeTypeMap = handlers.Keys.ToDictionary(t => new TypeId(t), t => t); + + foreach (Type type in handlers.Keys) + { + this._runtimeTypeMap[new(type)] = type; + + if (type.IsInterface) + { + this._interfaceHandlers.Add(type); + } + } + this._catchAllFunc = catchAllFunc; this.IncomingTypes = [.. handlers.Keys]; @@ -49,15 +60,42 @@ internal MessageRouter(Dictionary handlers, HashSet [MemberNotNullWhen(true, nameof(_catchAllFunc))] internal bool HasCatchAll => this._catchAllFunc is not null; - public bool CanHandle(object message) => this.CanHandle(new TypeId(Throw.IfNull(message).GetType())); - public bool CanHandle(Type candidateType) => this.CanHandle(new TypeId(Throw.IfNull(candidateType))); + public bool CanHandle(object message) => this.CanHandle(Throw.IfNull(message).GetType()); + public bool CanHandle(Type candidateType) => this.HasCatchAll || this.FindHandler(candidateType) is not null; + + public HashSet DefaultOutputTypes { get; } - public bool CanHandle(TypeId candidateType) + private MessageHandlerF? FindHandler(Type messageType) { - return this.HasCatchAll || this._runtimeTypeMap.ContainsKey(candidateType); - } + for (Type? candidateType = messageType; candidateType != null; candidateType = candidateType.BaseType) + { + if (this._typedHandlers.TryGetValue(candidateType, out MessageHandlerF? handler)) + { + if (candidateType != messageType) + { + // Cache the handler for future lookups. + this._typedHandlers[messageType] = handler; + this._runtimeTypeMap[new TypeId(messageType)] = candidateType; + } + + return handler; + } + else if (this._interfaceHandlers.Count > 0) + { + foreach (Type interfaceType in this._interfaceHandlers.Where(it => it.IsAssignableFrom(candidateType))) + { + handler = this._typedHandlers[interfaceType]; + this._typedHandlers[messageType] = handler; + + // TODO: This could cause some consternation with Checkpointing (need to ensure we surface errors well) + this._runtimeTypeMap[new TypeId(messageType)] = interfaceType; + return handler; + } + } + } - public HashSet DefaultOutputTypes { get; } + return null; + } public async ValueTask RouteMessageAsync(object message, IWorkflowContext context, bool requireRoute = false, CancellationToken cancellationToken = default) { @@ -75,7 +113,8 @@ public bool CanHandle(TypeId candidateType) try { - if (this._typedHandlers.TryGetValue(message.GetType(), out MessageHandlerF? handler)) + MessageHandlerF? handler = this.FindHandler(message.GetType()); + if (handler != null) { result = await handler(message, context, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs index 969509e40b..cdf80c0cd8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs @@ -2,6 +2,7 @@ using System; using System.Diagnostics; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Observability; using Microsoft.Shared.Diagnostics; @@ -21,7 +22,7 @@ public static ResponseEdgeRunner ForPort(IRunnerContext runContext, string execu public string ExecutorId => executorId; - protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer) + protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer, CancellationToken cancellationToken) { Debug.Assert(envelope.IsExternal, "Input edges should only be chased from external input"); @@ -34,7 +35,10 @@ public static ResponseEdgeRunner ForPort(IRunnerContext runContext, string execu try { Executor target = await this.FindExecutorAsync(stepTracer).ConfigureAwait(false); - if (target.CanHandle(envelope.MessageType)) + + Type? runtimeType = await this.GetMessageRuntimeTypeAsync(envelope, stepTracer, cancellationToken).ConfigureAwait(false); + + if (CanHandle(target, runtimeType)) { activity?.SetEdgeRunnerDeliveryStatus(EdgeRunnerDeliveryStatus.Delivered); return new DeliveryMapping(envelope, target); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs index 3dba017fa7..4c3847428d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -15,6 +16,130 @@ namespace Microsoft.Agents.AI.Workflows; +internal sealed class DelayedExternalRequestContext : IExternalRequestContext +{ + public DelayedExternalRequestContext(IExternalRequestContext? targetContext = null) + { + this._targetContext = targetContext; + } + + private sealed class DelayRegisteredSink : IExternalRequestSink + { + internal IExternalRequestSink? TargetSink { get; set; } + + public ValueTask PostAsync(ExternalRequest request) => + this.TargetSink is null + ? throw new InvalidOperationException("The external request sink has not been registered yet.") + : this.TargetSink.PostAsync(request); + } + + private readonly Dictionary _requestPorts = []; + private IExternalRequestContext? _targetContext; + + public void ApplyPortRegistrations(IExternalRequestContext targetContext) + { + this._targetContext = targetContext; + + foreach ((RequestPort requestPort, DelayRegisteredSink? sink) in this._requestPorts.Values) + { + sink?.TargetSink = targetContext.RegisterPort(requestPort); + } + } + + public IExternalRequestSink RegisterPort(RequestPort port) + { + DelayRegisteredSink delaySink = new() + { + TargetSink = this._targetContext?.RegisterPort(port), + }; + + this._requestPorts.Add(port.Id, (port, delaySink)); + + return delaySink; + } +} + +internal sealed class MessageTypeTranslator +{ + private readonly Dictionary _typeLookupMap = []; + private readonly Dictionary _delcaredTypeMap = []; + + // The types that can always be sent; this is a very inelegant solution to the following problem: + // Even with code analysis it is impossible to statically know all of the types that get sent via SendMessage, because + // IWorkflowContext can always be sent out of the current assembly (to say nothing of Reflection). This means at some + // level we have to register all the types being sent somewhere. Since we have to do dynamic serialization/deserialization + // at runtime with dependency-defined types (which we do not statically know) we need to have these types at runtime. + // At the same time, we should not force users to declare types to interact with core system concepts like RequestInfo. + // So the solution for now is to register a set of known types, at the cost of duplicating this per Executor. + // + // - TODO: Create a static translation map, and keep a set of "allowed" TypeIds per Excutor. + private static IEnumerable KnownSentTypes => + [ + typeof(ExternalRequest), + typeof(ExternalResponse), + + // TurnToken? + ]; + + public MessageTypeTranslator(ISet types) + { + foreach (Type type in KnownSentTypes.Concat(types)) + { + TypeId typeId = new(type); + if (this._typeLookupMap.ContainsKey(typeId)) + { + continue; + } + + this._typeLookupMap[typeId] = type; + this._delcaredTypeMap[type] = typeId; + } + } + + public TypeId? GetDeclaredType(Type messageType) + { + // If the user declares a base type, the user is expected to set up any serialization to be able to deal with + // the polymorphism transparently to the framework, or be expecting to deal with the appropriate truncation. + for (Type? candidateType = messageType; candidateType != null; candidateType = candidateType.BaseType) + { + if (this._delcaredTypeMap.TryGetValue(candidateType, out TypeId? declaredTypeId)) + { + if (candidateType != messageType) + { + // Add an entry for the derived type to speed up future lookups. + this._delcaredTypeMap[messageType] = declaredTypeId; + } + + return declaredTypeId; + } + } + + return null; + } + + public Type? MapTypeId(TypeId candidateTypeId) => + this._typeLookupMap.TryGetValue(candidateTypeId, out Type? mappedType) + ? mappedType + : null; +} + +internal sealed class ExecutorProtocol(MessageRouter router, ISet sendTypes, ISet yieldTypes) +{ + private readonly HashSet _yieldTypes = new(yieldTypes.Select(type => new TypeId(type))); + + public MessageTypeTranslator SendTypeTranslator => field ??= new MessageTypeTranslator(sendTypes); + + internal MessageRouter Router => router; + + //public bool CanHandle(TypeId typeId) => router.CanHandle(typeId); + public bool CanHandle(Type type) => router.CanHandle(type); + + //public bool CanOutput(TypeId typeId) => this._yieldTypes.Contains(typeId); + public bool CanOutput(Type type) => this._yieldTypes.Contains(new(type)); + + public ProtocolDescriptor Describe() => new(this.Router.IncomingTypes, yieldTypes, sendTypes, this.Router.HasCatchAll); +} + /// /// A component that processes messages in a . /// @@ -50,6 +175,10 @@ protected Executor(string id, ExecutorOptions? options = null, bool declareCross this.IsCrossRunShareable = declareCrossRunShareable; } + private DelayedExternalRequestContext DelayedPortRegistrations { get; } = new(); + + internal ExecutorProtocol Protocol => field ??= this.ConfigureProtocol(new(this.DelayedPortRegistrations)).Build(this.Options); + internal bool IsCrossRunShareable { get; } /// @@ -57,28 +186,29 @@ protected Executor(string id, ExecutorOptions? options = null, bool declareCross /// protected ExecutorOptions Options { get; } + //private bool _configuringProtocol; + /// - /// Override this method to register handlers for the executor. + /// Configures the protocol by setting up routes and declaring the message types used for sending and yielding + /// output. /// - protected abstract RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder); + /// This method serves as the primary entry point for protocol configuration. It integrates route + /// setup and message type declarations. For backward compatibility, it is currently invoked from the + /// RouteBuilder. + /// An instance of that represents the fully configured protocol. + protected abstract ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder); - internal void Configure(IExternalRequestContext externalRequestContext) + internal void AttachRequestContext(IExternalRequestContext externalRequestContext) { // TODO: This is an unfortunate pattern (pending the ability to rework the Configure APIs a bit): // new() - // >>> will throw InvalidOperationException if Configure() is not invoked when using PortHandlers - // .Configure() + // >>> will throw InvalidOperationException if AttachRequestContext() is not invoked when using PortHandlers + // .AttachRequestContext() // >>> only usable now - // The fix would be to change the API surface of Executor to have Configure return the contract that the workflow - // will use to invoke the executor (currently the MessageRouter). (Ideally we would rename Executor to Node or similar, - // and the actual Executor class will represent that Contract object) - // Not a terrible issue right now because only InProcessExecution exists right now, and the InProccessRunContext centralizes - // executor instantiation in EnsureExecutorAsync. - this.Router = this.CreateRouter(externalRequestContext); - } - private MessageRouter CreateRouter(IExternalRequestContext? externalRequestContext = null) - => this.ConfigureRoutes(new RouteBuilder(externalRequestContext)).Build(); + this.DelayedPortRegistrations.ApplyPortRegistrations(externalRequestContext); + _ = this.Protocol; // Force protocol to be built if not already done. + } /// /// Perform any asynchronous initialization required by the executor. This method is called once per executor instance, @@ -90,42 +220,7 @@ private MessageRouter CreateRouter(IExternalRequestContext? externalRequestConte protected internal virtual ValueTask InitializeAsync(IWorkflowContext context, CancellationToken cancellationToken = default) => default; - /// - /// Override this method to declare the types of messages this executor can send. - /// - /// - protected virtual ISet ConfigureSentTypes() => new HashSet([typeof(object)]); - - /// - /// Override this method to declare the types of messages this executor can yield as workflow outputs. - /// - /// - protected virtual ISet ConfigureYieldTypes() - { - if (this.Options.AutoYieldOutputHandlerResultObject) - { - return this.Router.DefaultOutputTypes; - } - - return new HashSet(); - } - - internal MessageRouter Router - { - get - { - if (field is null) - { - field = this.CreateRouter(); - } - - return field; - } - private set - { - field = value; - } - } + internal MessageRouter Router => this.Protocol.Router; /// /// Process an incoming message using the registered handlers. @@ -224,41 +319,22 @@ private set /// /// A set of s, representing the messages this executor can produce as output. /// - public ISet OutputTypes { get; } = new HashSet([typeof(object)]); + public ISet OutputTypes => field ??= new HashSet(this.Protocol.Describe().Yields); /// /// Describes the protocol for communication with this . /// /// - public ProtocolDescriptor DescribeProtocol() - { - // TODO: Once burden of annotating yield/output messages becomes easier for the non-Auto case, - // we should (1) start checking for validity on output/send side, and (2) add the Yield/Send - // types to the ProtocolDescriptor. - return new(this.InputTypes, this.Router.HasCatchAll); - } + public ProtocolDescriptor DescribeProtocol() => this.Protocol.Describe(); /// /// Checks if the executor can handle a specific message type. /// /// /// - public bool CanHandle(Type messageType) => this.Router.CanHandle(messageType); + public bool CanHandle(Type messageType) => this.Protocol.CanHandle(messageType); - internal bool CanHandle(TypeId messageType) => this.Router.CanHandle(messageType); - - internal bool CanOutput(Type messageType) - { - foreach (Type type in this.OutputTypes) - { - if (type.IsAssignableFrom(messageType)) - { - return true; - } - } - - return false; - } + internal bool CanOutput(Type messageType) => this.Protocol.CanOutput(messageType); } /// @@ -272,8 +348,13 @@ public abstract class Executor(string id, ExecutorOptions? options = nul : Executor(id, options, declareCrossRunShareable), IMessageHandler { /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.HandleAsync); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + Func handlerDelegate = this.HandleAsync; + + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler(handlerDelegate)) + .AddHandlerAttributeTypes(handlerDelegate.Method); + } /// public abstract ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken = default); @@ -292,8 +373,13 @@ public abstract class Executor(string id, ExecutorOptions? opti IMessageHandler { /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.HandleAsync); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + Func> handlerDelegate = this.HandleAsync; + + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler(handlerDelegate)) + .AddHandlerAttributeTypes(handlerDelegate.Method); + } /// public abstract ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken = default); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/FunctionExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/FunctionExecutor.cs index a3371dc302..3b74aefe4c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/FunctionExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/FunctionExecutor.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -13,14 +16,28 @@ namespace Microsoft.Agents.AI.Workflows; /// A unique identifier for the executor. /// A delegate that defines the asynchronous function to execute for each input message. /// Configuration options for the executor. If null, default options will be used. +/// Message types sent by the handler. Defaults to empty, and will filter out non-matching messages. +/// Message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. public class FunctionExecutor(string id, Func handlerAsync, ExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, bool declareCrossRunShareable = false) : Executor(id, options, declareCrossRunShareable) { - internal static Func WrapAction(Action handlerSync) + internal static Func WrapAction(Action handlerSync, out IEnumerable sentTypes, out IEnumerable yieldedTypes) { + if (handlerSync.Method != null) + { + MethodInfo method = handlerSync.Method; + (sentTypes, yieldedTypes) = method.GetAttributeTypes(); + } + else + { + sentTypes = yieldedTypes = []; + } + return RunActionAsync; ValueTask RunActionAsync(TInput input, IWorkflowContext workflowContext, CancellationToken cancellationToken) @@ -30,6 +47,16 @@ ValueTask RunActionAsync(TInput input, IWorkflowContext workflowContext, Cancell } } + /// + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + (IEnumerable attributeSentTypes, IEnumerable attributeYieldTypes) = handlerAsync.Method.GetAttributeTypes(); + + return base.ConfigureProtocol(protocolBuilder) + .SendsMessageTypes(attributeSentTypes.Concat(sentMessageTypes ?? [])) + .YieldsOutputTypes(attributeYieldTypes.Concat(outputTypes ?? [])); + } + /// public override ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken) => handlerAsync(message, context, cancellationToken); @@ -39,8 +66,15 @@ ValueTask RunActionAsync(TInput input, IWorkflowContext workflowContext, Cancell /// A unique identifier for the executor. /// A synchronous function to execute for each input message and workflow context. /// Configuration options for the executor. If null, default options will be used. + /// Message types sent by the handler. Defaults to empty, and will filter out non-matching messages. + /// Message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. - public FunctionExecutor(string id, Action handlerSync, ExecutorOptions? options = null, bool declareCrossRunShareable = false) : this(id, WrapAction(handlerSync), options, declareCrossRunShareable) + public FunctionExecutor(string id, + Action handlerSync, + ExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, + bool declareCrossRunShareable = false) : this(id, WrapAction(handlerSync, out var attributeSentTypes, out var attributeYieldTypes), options, attributeSentTypes.Concat(sentMessageTypes ?? []), attributeYieldTypes.Concat(outputTypes ?? []), declareCrossRunShareable) { } } @@ -53,10 +87,14 @@ public FunctionExecutor(string id, ActionA unique identifier for the executor. /// A delegate that defines the asynchronous function to execute for each input message. /// Configuration options for the executor. If null, default options will be used. +/// Additional message types sent by the handler. Defaults to empty, and will filter out non-matching messages. +/// Additional message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. public class FunctionExecutor(string id, Func> handlerAsync, ExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, bool declareCrossRunShareable = false) : Executor(id, options, declareCrossRunShareable) { internal static Func> WrapFunc(Func handlerSync) @@ -70,6 +108,12 @@ ValueTask RunFuncAsync(TInput input, IWorkflowContext workflowContext, } } + /// + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => base.ConfigureProtocol(protocolBuilder) + .SendsMessageTypes(sentMessageTypes ?? []) + .YieldsOutputTypes(outputTypes ?? []); + /// public override ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken) => handlerAsync(message, context, cancellationToken); @@ -79,8 +123,15 @@ ValueTask RunFuncAsync(TInput input, IWorkflowContext workflowContext, /// A unique identifier for the executor. /// A synchronous function to execute for each input message and workflow context. /// Configuration options for the executor. If null, default options will be used. + /// Additional message types sent by the handler. Defaults to empty, and will filter out non-matching messages. + /// Additional message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. - public FunctionExecutor(string id, Func handlerSync, ExecutorOptions? options = null, bool declareCrossRunShareable = false) : this(id, WrapFunc(handlerSync), options, declareCrossRunShareable) + public FunctionExecutor(string id, + Func handlerSync, + ExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, + bool declareCrossRunShareable = false) : this(id, WrapFunc(handlerSync), options, sentMessageTypes, outputTypes, declareCrossRunShareable) { } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs index 58e1890eed..4cf9e53954 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs @@ -201,14 +201,43 @@ private async ValueTask DeliverMessagesAsync(string receiverId, ConcurrentQueue< this.StepTracer.TraceActivated(receiverId); while (envelopes.TryDequeue(out var envelope)) { + (object message, TypeId messageType) = await TranslateMessageAsync(envelope).ConfigureAwait(false); + await executor.ExecuteAsync( - envelope.Message, - envelope.MessageType, + message, + messageType, this.RunContext.BindWorkflowContext(receiverId, envelope.TraceContext), this.TelemetryContext, cancellationToken ).ConfigureAwait(false); } + + async ValueTask<(object, TypeId)> TranslateMessageAsync(MessageEnvelope envelope) + { + object? value = envelope.Message; + TypeId messageType = envelope.MessageType; + + if (!envelope.IsExternal) + { + Executor source = await this.RunContext.EnsureExecutorAsync(envelope.SourceId, this.StepTracer, cancellationToken).ConfigureAwait(false); + Type? actualType = source.Protocol.SendTypeTranslator.MapTypeId(envelope.MessageType); + if (actualType == null) + { + // In principle, this should never happen, since we always use the SendTypeTranslator to generate the outgoing TypeId in the first place. + throw new InvalidOperationException($"Cannot translate message type ID '{envelope.MessageType}' from executor '{source.Id}'."); + } + + messageType = new(actualType); + + if (value is PortableValue portableValue && + !portableValue.IsType(actualType, out value)) + { + throw new InvalidOperationException($"Cannot interpret incoming message of type '{portableValue.TypeId}' as type '{actualType.FullName}'."); + } + } + + return (value, messageType); + } } private async ValueTask RunSuperstepAsync(StepContext currentStep, CancellationToken cancellationToken) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs index 419d46cd1b..5203bb9e82 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs @@ -95,7 +95,7 @@ async Task CreateExecutorAsync(string id) } Executor executor = await registration.CreateInstanceAsync(this._runId).ConfigureAwait(false); - executor.Configure(this.BindExternalRequestContext(executorId)); + executor.AttachRequestContext(this.BindExternalRequestContext(executorId)); await executor.InitializeAsync(this.BindWorkflowContext(executorId), cancellationToken: cancellationToken) .ConfigureAwait(false); @@ -182,7 +182,7 @@ public async ValueTask AdvanceAsync(CancellationToken cancellationT while (this._queuedExternalDeliveries.TryDequeue(out var deliveryPrep)) { - // It's important we do not try to run these in parallel, because they make be modifying + // It's important we do not try to run these in parallel, because they may be modifying // inner edge state, etc. await deliveryPrep().ConfigureAwait(false); } @@ -212,14 +212,23 @@ public async ValueTask SendMessageAsync(string sourceId, object message, string? } this.CheckEnded(); - MessageEnvelope envelope = new(message, sourceId, targetId: targetId, traceContext: traceContext); + + Debug.Assert(this._executors.ContainsKey(sourceId)); + Executor source = await this.EnsureExecutorAsync(sourceId, tracer: null, cancellationToken).ConfigureAwait(false); + TypeId? declaredType = source.Protocol.SendTypeTranslator.GetDeclaredType(message.GetType()); + if (declaredType is null) + { + throw new InvalidOperationException($"Executor '{sourceId}' cannot send messages of type '{message.GetType().FullName}'."); + } + + MessageEnvelope envelope = new(message, sourceId, declaredType, targetId: targetId, traceContext: traceContext); if (this._workflow.Edges.TryGetValue(sourceId, out HashSet? edges)) { foreach (Edge edge in edges) { DeliveryMapping? maybeMapping = - await this._edgeMap.PrepareDeliveryForEdgeAsync(edge, envelope) + await this._edgeMap.PrepareDeliveryForEdgeAsync(edge, envelope, cancellationToken) .ConfigureAwait(false); maybeMapping?.MapInto(this._nextStep); @@ -296,12 +305,12 @@ private sealed class BoundWorkflowContext( public ValueTask SendMessageAsync(object message, string? targetId = null, CancellationToken cancellationToken = default) { - return RunnerContext.SendMessageAsync(ExecutorId, message, targetId, cancellationToken); + return RunnerContext.SendMessageAsync(ExecutorId, Throw.IfNull(message), targetId, cancellationToken); } public ValueTask YieldOutputAsync(object output, CancellationToken cancellationToken = default) { - return RunnerContext.YieldOutputAsync(ExecutorId, output, cancellationToken); + return RunnerContext.YieldOutputAsync(ExecutorId, Throw.IfNull(output), cancellationToken); } public ValueTask RequestHaltAsync() => this.AddEventAsync(new RequestHaltEvent()); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolBuilder.cs new file mode 100644 index 0000000000..6309d98021 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolBuilder.cs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.Agents.AI.Workflows.Execution; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI.Workflows; + +internal static class MethodAttributeExtensions +{ + public static (IEnumerable Sent, IEnumerable Yielded) GetAttributeTypes(this MethodInfo method) + { + IEnumerable sendsMessageAttrs = method.GetCustomAttributes(); + IEnumerable yieldsOutputAttrs = method.GetCustomAttributes(); + // TODO: Should we include [MessageHandler]? + + return (Sent: sendsMessageAttrs.Select(attr => attr.Type), Yielded: yieldsOutputAttrs.Select(attr => attr.Type)); + } +} + +/// +/// . +/// +public sealed class ProtocolBuilder +{ + private readonly HashSet _sendTypes = []; + private readonly HashSet _yieldTypes = []; + + internal ProtocolBuilder(DelayedExternalRequestContext delayRequestContext) + { + this.RouteBuilder = new RouteBuilder(delayRequestContext); + } + + internal ProtocolBuilder AddHandlerAttributeTypes(MethodInfo method, bool registerSentTypes = true, bool registerYieldTypes = true) + { + (IEnumerable sentTypes, IEnumerable yieldTypes) = method.GetAttributeTypes(); + + if (registerSentTypes) + { + this._sendTypes.UnionWith(sentTypes); + } + + if (registerYieldTypes) + { + this._yieldTypes.UnionWith(yieldTypes); + } + + return this; + } + + /// + /// Adds the specified type to the set of declared "sent" message types for the protocol. Objects of these types will be allowed to be + /// sent through the Executor's outgoing edges, via . + /// + /// The type to be declared. + /// + public ProtocolBuilder SendsMessage() where TMessage : notnull => this.SendsMessageTypes([typeof(TMessage)]); + + /// + /// Adds the specified type to the set of declared "sent" messagetypes for the protocol. Objects of these types will be allowed to be + /// sent through the Executor's outgoing edges, via . + /// + /// The type to be declared. + /// + public ProtocolBuilder SendsMessageType(Type messageType) => this.SendsMessageTypes([messageType]); + + /// + /// Adds the specified types to the set of declared "sent" message types for the protocol. Objects of these types will be allowed to be + /// sent through the Executor's outgoing edges, via . + /// + /// A set of types to be declared. + /// + public ProtocolBuilder SendsMessageTypes(IEnumerable messageTypes) + { + Throw.IfNull(messageTypes); + this._sendTypes.UnionWith(messageTypes); + return this; + } + + /// + /// Adds the specified output type to the set of declared "yielded" output types for the protocol. Objects of this type will be + /// allowed to be output from the executor through the , via . + /// + /// The type to be declared. + /// + public ProtocolBuilder YieldsOutput() where TOutput : notnull => this.YieldsOutputTypes([typeof(TOutput)]); + + /// + /// Adds the specified output type to the set of declared "yielded" output types for the protocol. Objects of this type will be + /// allowed to be output from the executor through the , via . + /// + /// The type to be declared. + /// + public ProtocolBuilder YieldsOutputType(Type outputType) => this.YieldsOutputTypes([outputType]); + + /// + /// Adds the specified types to the set of declared "yielded" output types for the protocol. Objects of these types will be allowed to be + /// output from the executor through the , via . + /// + /// A set of types to be declared. + /// + public ProtocolBuilder YieldsOutputTypes(IEnumerable yieldedTypes) + { + Throw.IfNull(yieldedTypes); + this._yieldTypes.UnionWith(yieldedTypes); + return this; + } + + /// + /// Gets a route builder to configure message handlers. + /// + public RouteBuilder RouteBuilder { get; } + + /// + /// Fluently configures message handlers. + /// + /// The handler configuration callback. + /// + public ProtocolBuilder ConfigureRoutes(Action configureAction) + { + configureAction(this.RouteBuilder); + return this; + } + + internal ExecutorProtocol Build(ExecutorOptions options) + { + MessageRouter router = this.RouteBuilder.Build(); + + HashSet sendTypes = new(this._sendTypes); + if (options.AutoSendMessageHandlerResultObject) + { + sendTypes.UnionWith(router.DefaultOutputTypes); + } + + HashSet yieldTypes = new(this._yieldTypes); + if (options.AutoYieldOutputHandlerResultObject) + { + yieldTypes.UnionWith(router.DefaultOutputTypes); + } + + return new(router, sendTypes, yieldTypes); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolDescriptor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolDescriptor.cs index bb2663c100..655d1ad197 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolDescriptor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ProtocolDescriptor.cs @@ -16,14 +16,27 @@ public class ProtocolDescriptor /// public IEnumerable Accepts { get; } + /// + /// Gets the collection of types that could be yielded as output by the or . + /// + public IEnumerable Yields { get; } + + /// + /// Gets the collection of types that could be sent from the . This is always empty for a . + /// + public IEnumerable Sends { get; } + /// /// Gets a value indicating whether the or has a "catch-all" handler. /// public bool AcceptsAll { get; set; } - internal ProtocolDescriptor(IEnumerable acceptedTypes, bool acceptsAll) + internal ProtocolDescriptor(IEnumerable acceptedTypes, IEnumerable yieldedTypes, IEnumerable sentTypes, bool acceptsAll) { this.Accepts = acceptedTypes.ToArray(); + this.Yields = yieldedTypes.ToArray(); + this.Sends = sentTypes.ToArray(); + this.AcceptsAll = acceptsAll; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/ReflectingExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/ReflectingExecutor.cs index f4dcf1291f..3e8d4cfed9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/ReflectingExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/ReflectingExecutor.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; namespace Microsoft.Agents.AI.Workflows.Reflection; @@ -29,7 +32,45 @@ protected ReflectingExecutor(string id, ExecutorOptions? options = null, bool de { } - /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.ReflectHandlers(this); + /// + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + protocolBuilder.SendsMessageTypes(typeof(TExecutor).GetCustomAttributes(inherit: true) + .Select(attr => attr.Type)) + .YieldsOutputTypes(typeof(TExecutor).GetCustomAttributes(inherit: true) + .Select(attr => attr.Type)); + + List messageHandlers = typeof(TExecutor).GetHandlerInfos().ToList(); + foreach (MessageHandlerInfo handlerInfo in messageHandlers) + { + protocolBuilder.RouteBuilder.AddHandlerInternal(handlerInfo.InType, handlerInfo.Bind(this, checkType: true), handlerInfo.OutType); + + if (handlerInfo.OutType != null) + { + if (this.Options.AutoSendMessageHandlerResultObject) + { + protocolBuilder.SendsMessageType(handlerInfo.OutType); + } + + if (this.Options.AutoYieldOutputHandlerResultObject) + { + protocolBuilder.YieldsOutputType(handlerInfo.OutType); + } + } + } + + if (messageHandlers.Count > 0) + { + var handlerAnnotatedTypes = + messageHandlers.Select(mhi => (SendTypes: mhi.HandlerInfo.GetCustomAttributes().Select(attr => attr.Type), + YieldTypes: mhi.HandlerInfo.GetCustomAttributes().Select(attr => attr.Type))) + .Aggregate((accumulate, next) => (accumulate.SendTypes == null ? next.SendTypes : accumulate.SendTypes.Concat(next.SendTypes), + accumulate.YieldTypes == null ? next.YieldTypes : accumulate.YieldTypes.Concat(next.YieldTypes))); + + protocolBuilder.SendsMessageTypes(handlerAnnotatedTypes.SendTypes) + .YieldsOutputTypes(handlerAnnotatedTypes.YieldTypes); + } + + return protocolBuilder; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/RouteBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/RouteBuilderExtensions.cs index d554138f1e..a193edb4dc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/RouteBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Reflection/RouteBuilderExtensions.cs @@ -7,7 +7,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows.Reflection; @@ -45,7 +44,7 @@ internal static bool IsMessageHandlerType(this Type type) => internal static class RouteBuilderExtensions { - private static IEnumerable GetHandlerInfos( + public static IEnumerable GetHandlerInfos( [DynamicallyAccessedMembers(ReflectionDemands.RuntimeInterfaceDiscoveryAndInvocation)] this Type executorType) { @@ -77,25 +76,4 @@ private static IEnumerable GetHandlerInfos( } } } - - public static RouteBuilder ReflectHandlers< - [DynamicallyAccessedMembers( - ReflectionDemands.RuntimeInterfaceDiscoveryAndInvocation) - ] TExecutor> - (this RouteBuilder builder, ReflectingExecutor executor) - where TExecutor : ReflectingExecutor - { - Throw.IfNull(builder); - - Type executorType = typeof(TExecutor); - Debug.Assert(executorType.IsInstanceOfType(executor), - "executorType must be the same type or a base type of the executor instance."); - - foreach (MessageHandlerInfo handlerInfo in executorType.GetHandlerInfos()) - { - builder = builder.AddHandlerInternal(handlerInfo.InType, handlerInfo.Bind(executor, checkType: true), handlerInfo.OutType); - } - - return builder; - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs index cf9f7b814c..2b71169268 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs @@ -35,9 +35,6 @@ namespace Microsoft.Agents.AI.Workflows; /// /// Provides a builder for configuring message type handlers for an . /// -/// -/// Override the method to customize the routing of messages to handlers. -/// public class RouteBuilder { private readonly IExternalRequestContext? _externalRequestContext; @@ -631,6 +628,8 @@ private void RegisterPortHandlerRouter() } } + internal IEnumerable OutputTypes => this._outputTypes.Values; + internal MessageRouter Build() { if (this._portHandlers.Count > 0) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs index 6ec9c4dccb..47fd47a1eb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs @@ -36,27 +36,26 @@ public AIAgentHostExecutor(AIAgent agent, AIAgentHostOptions options) : base(id: this._options = options; } - private RouteBuilder ConfigureUserInputRoutes(RouteBuilder routeBuilder) + private ProtocolBuilder ConfigureUserInputHandling(ProtocolBuilder protocolBuilder) { this._userInputHandler = new AIContentExternalHandler( - ref routeBuilder, + ref protocolBuilder, portId: $"{this.Id}_UserInput", intercepted: this._options.InterceptUserInputRequests, handler: this.HandleUserInputResponseAsync); this._functionCallHandler = new AIContentExternalHandler( - ref routeBuilder, + ref protocolBuilder, portId: $"{this.Id}_FunctionCall", intercepted: this._options.InterceptUnterminatedFunctionCalls, handler: this.HandleFunctionResultAsync); - return routeBuilder; + return protocolBuilder; } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - routeBuilder = base.ConfigureRoutes(routeBuilder); - return this.ConfigureUserInputRoutes(routeBuilder); + return this.ConfigureUserInputHandling(base.ConfigureProtocol(protocolBuilder)); } private ValueTask HandleUserInputResponseAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs index eae1fd90f5..9173100b3e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs @@ -18,16 +18,28 @@ internal sealed class AIContentExternalHandler _pendingRequests = new(); - public AIContentExternalHandler(ref RouteBuilder routeBuilder, string portId, bool intercepted, Func handler) + public AIContentExternalHandler(ref ProtocolBuilder protocolBuilder, string portId, bool intercepted, Func handler) { + PortBinding? portBinding = null; + protocolBuilder = protocolBuilder.ConfigureRoutes(routeBuilder => ConfigureRoutes(routeBuilder, out portBinding)); + this._portBinding = portBinding; + if (intercepted) { - this._portBinding = null; - routeBuilder = routeBuilder.AddHandler(handler); + protocolBuilder = protocolBuilder.SendsMessage(); } - else + + void ConfigureRoutes(RouteBuilder routeBuilder, out PortBinding? portBinding) { - routeBuilder = routeBuilder.AddPortHandler(portId, handler, out this._portBinding); + if (intercepted) + { + portBinding = null; + routeBuilder.AddHandler(handler); + } + else + { + routeBuilder.AddPortHandler(portId, handler, out portBinding); + } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/ConcurrentEndExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/ConcurrentEndExecutor.cs index 2fc4030a5c..4374ba759b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/ConcurrentEndExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/ConcurrentEndExecutor.cs @@ -36,8 +36,9 @@ private void Reset() this._remaining = this._expectedInputs; } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler>(async (messages, context, cancellationToken) => + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + protocolBuilder.RouteBuilder.AddHandler>(async (messages, context, cancellationToken) => { // TODO: https://github.com/microsoft/agent-framework/issues/784 // This locking should not be necessary. @@ -58,6 +59,9 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => } }); + return protocolBuilder.YieldsOutput>(); + } + public ValueTask ResetAsync() { this.Reset(); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/GroupChatHost.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/GroupChatHost.cs index 76e3f10bd2..b902bf8ef1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/GroupChatHost.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/GroupChatHost.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -11,52 +12,51 @@ internal sealed class GroupChatHost( string id, AIAgent[] agents, Dictionary agentMap, - Func, GroupChatManager> managerFactory) : Executor(id), IResettableExecutor + Func, GroupChatManager> managerFactory) : ChatProtocolExecutor(id, s_options), IResettableExecutor { + private static readonly ChatProtocolExecutorOptions s_options = new() + { + StringMessageChatRole = ChatRole.User, + AutoSendTurnToken = false + }; + private readonly AIAgent[] _agents = agents; private readonly Dictionary _agentMap = agentMap; private readonly Func, GroupChatManager> _managerFactory = managerFactory; - private readonly List _pendingMessages = []; private GroupChatManager? _manager; - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => routeBuilder - .AddHandler((message, context, _) => this._pendingMessages.Add(new(ChatRole.User, message))) - .AddHandler((message, context, _) => this._pendingMessages.Add(message)) - .AddHandler>((messages, _, __) => this._pendingMessages.AddRange(messages)) - .AddHandler((messages, _, __) => this._pendingMessages.AddRange(messages)) // TODO: Remove once https://github.com/microsoft/agent-framework/issues/782 is addressed - .AddHandler>((messages, _, __) => this._pendingMessages.AddRange(messages)) // TODO: Remove once https://github.com/microsoft/agent-framework/issues/782 is addressed - .AddHandler(async (token, context, cancellationToken) => - { - List messages = [.. this._pendingMessages]; - this._pendingMessages.Clear(); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => base.ConfigureProtocol(protocolBuilder).YieldsOutput>(); + + protected override async ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) + { + this._manager ??= this._managerFactory(this._agents); - this._manager ??= this._managerFactory(this._agents); + if (!await this._manager.ShouldTerminateAsync(messages, cancellationToken).ConfigureAwait(false)) + { + var filtered = await this._manager.UpdateHistoryAsync(messages, cancellationToken).ConfigureAwait(false); + messages = filtered is null || ReferenceEquals(filtered, messages) ? messages : [.. filtered]; - if (!await this._manager.ShouldTerminateAsync(messages, cancellationToken).ConfigureAwait(false)) + if (await this._manager.SelectNextAgentAsync(messages, cancellationToken).ConfigureAwait(false) is AIAgent nextAgent && + this._agentMap.TryGetValue(nextAgent, out var executor)) { - var filtered = await this._manager.UpdateHistoryAsync(messages, cancellationToken).ConfigureAwait(false); - messages = filtered is null || ReferenceEquals(filtered, messages) ? messages : [.. filtered]; - - if (await this._manager.SelectNextAgentAsync(messages, cancellationToken).ConfigureAwait(false) is AIAgent nextAgent && - this._agentMap.TryGetValue(nextAgent, out var executor)) - { - this._manager.IterationCount++; - await context.SendMessageAsync(messages, executor.Id, cancellationToken).ConfigureAwait(false); - await context.SendMessageAsync(token, executor.Id, cancellationToken).ConfigureAwait(false); - return; - } + this._manager.IterationCount++; + await context.SendMessageAsync(messages, executor.Id, cancellationToken).ConfigureAwait(false); + await context.SendMessageAsync(new TurnToken(emitEvents), executor.Id, cancellationToken).ConfigureAwait(false); + return; } + } - this._manager = null; - await context.YieldOutputAsync(messages, cancellationToken).ConfigureAwait(false); - }); - - public ValueTask ResetAsync() + this._manager = null; + await context.YieldOutputAsync(messages, cancellationToken).ConfigureAwait(false); + } + protected override ValueTask ResetAsync() { - this._pendingMessages.Clear(); this._manager = null; - return default; + return base.ResetAsync(); } + + ValueTask IResettableExecutor.ResetAsync() => this.ResetAsync(); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index 8c608090f3..b49d204d21 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -14,7 +15,7 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; /// Executor used to represent an agent in a handoffs workflow, responding to events. internal sealed class HandoffAgentExecutor( AIAgent agent, - string? handoffInstructions) : Executor(agent.GetDescriptiveId(), declareCrossRunShareable: true), IResettableExecutor + string? handoffInstructions) : Executor(agent.GetDescriptiveId(), declareCrossRunShareable: true), IResettableExecutor { private static readonly JsonElement s_handoffSchema = AIFunctionFactory.Create( ([Description("The reason for the handoff")] string? reasonForHandoff) => { }).JsonSchema; @@ -60,59 +61,56 @@ public void Initialize( sb.WithDefault(end); }); - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(async (handoffState, context, cancellationToken) => - { - string? requestedHandoff = null; - List updates = []; - List allMessages = handoffState.Messages; + public override async ValueTask HandleAsync(HandoffState message, IWorkflowContext context, CancellationToken cancellationToken = default) + { + string? requestedHandoff = null; + List updates = []; + List allMessages = message.Messages; - List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); + List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); - await foreach (var update in this._agent.RunStreamingAsync(allMessages, - options: this._agentOptions, - cancellationToken: cancellationToken) - .ConfigureAwait(false)) - { - await AddUpdateAsync(update, cancellationToken).ConfigureAwait(false); + await foreach (var update in this._agent.RunStreamingAsync(allMessages, + options: this._agentOptions, + cancellationToken: cancellationToken) + .ConfigureAwait(false)) + { + await AddUpdateAsync(update, cancellationToken).ConfigureAwait(false); - foreach (var c in update.Contents) - { - if (c is FunctionCallContent fcc && this._handoffFunctionNames.Contains(fcc.Name)) - { - requestedHandoff = fcc.Name; - await AddUpdateAsync( - new AgentResponseUpdate - { - AgentId = this._agent.Id, - AuthorName = this._agent.Name ?? this._agent.Id, - Contents = [new FunctionResultContent(fcc.CallId, "Transferred.")], - CreatedAt = DateTimeOffset.UtcNow, - MessageId = Guid.NewGuid().ToString("N"), - Role = ChatRole.Tool, - }, - cancellationToken - ) - .ConfigureAwait(false); - } - } + foreach (var fcc in update.Contents.OfType() + .Where(fcc => this._handoffFunctionNames.Contains(fcc.Name))) + { + requestedHandoff = fcc.Name; + await AddUpdateAsync( + new AgentResponseUpdate + { + AgentId = this._agent.Id, + AuthorName = this._agent.Name ?? this._agent.Id, + Contents = [new FunctionResultContent(fcc.CallId, "Transferred.")], + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + Role = ChatRole.Tool, + }, + cancellationToken + ) + .ConfigureAwait(false); } + } - allMessages.AddRange(updates.ToAgentResponse().Messages); + allMessages.AddRange(updates.ToAgentResponse().Messages); - roleChanges.ResetUserToAssistantForChangedRoles(); + roleChanges.ResetUserToAssistantForChangedRoles(); - await context.SendMessageAsync(new HandoffState(handoffState.TurnToken, requestedHandoff, allMessages), cancellationToken: cancellationToken).ConfigureAwait(false); + return new(message.TurnToken, requestedHandoff, allMessages); - async Task AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellationToken) + async Task AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellationToken) + { + updates.Add(update); + if (message.TurnToken.EmitEvents is true) { - updates.Add(update); - if (handoffState.TurnToken.EmitEvents is true) - { - await context.AddEventAsync(new AgentResponseUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false); - } + await context.AddEventAsync(new AgentResponseUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false); } - }); + } + } public ValueTask ResetAsync() => default; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs index eeabeb5d5a..69f81376be 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.Threading.Tasks; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.Specialized; @@ -9,9 +11,10 @@ internal sealed class HandoffsEndExecutor() : Executor(ExecutorId, declareCrossR { public const string ExecutorId = "HandoffEnd"; - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler((handoff, context, cancellationToken) => - context.YieldOutputAsync(handoff.Messages, cancellationToken)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => + protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler((handoff, context, cancellationToken) => + context.YieldOutputAsync(handoff.Messages, cancellationToken))) + .YieldsOutput>(); public ValueTask ResetAsync() => default; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs index 982b8aabf2..9039e86f5b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs @@ -14,9 +14,13 @@ internal sealed class HandoffsStartExecutor() : ChatProtocolExecutor(ExecutorId, private static ChatProtocolExecutorOptions DefaultOptions => new() { - StringMessageChatRole = ChatRole.User + StringMessageChatRole = ChatRole.User, + AutoSendTurnToken = false }; + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => + base.ConfigureProtocol(protocolBuilder).SendsMessage(); + protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) => context.SendMessageAsync(new HandoffState(new(emitEvents), null, messages), cancellationToken: cancellationToken); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/OutputMessagesExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/OutputMessagesExecutor.cs index b3c714406d..17d0ffebc9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/OutputMessagesExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/OutputMessagesExecutor.cs @@ -15,6 +15,10 @@ internal sealed class OutputMessagesExecutor(ChatProtocolExecutorOptions? option { public const string ExecutorId = "OutputMessages"; + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => + base.ConfigureProtocol(protocolBuilder) + .YieldsOutput>(); + protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) => context.YieldOutputAsync(messages, cancellationToken); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs index 3dda4a85c6..b35d682f2c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs @@ -34,22 +34,29 @@ public RequestInfoExecutor(RequestPort port, bool allowWrapped = true) : base(po this._allowWrapped = allowWrapped; } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - routeBuilder = routeBuilder - // Handle incoming requests (as raw request payloads) - .AddHandlerUntyped(this.Port.Request, this.HandleAsync) - .AddCatchAll(this.HandleCatchAllAsync); + return protocolBuilder.ConfigureRoutes(ConfigureRoutes) + .SendsMessage() + .SendsMessageType(this.Port.Response); - if (this._allowWrapped) + void ConfigureRoutes(RouteBuilder routeBuilder) { routeBuilder = routeBuilder - .AddHandler(this.HandleAsync); + // Handle incoming requests (as raw request payloads) + .AddHandlerUntyped(this.Port.Request, this.HandleAsync) + .AddCatchAll(this.HandleCatchAllAsync); + + if (this._allowWrapped) + { + routeBuilder = routeBuilder + .AddHandler(this.HandleAsync); + } + + routeBuilder + // Handle incoming responses (as wrapped Response object) + .AddHandler(this.HandleAsync); } - - return routeBuilder - // Handle incoming responses (as wrapped Response object) - .AddHandler(this.HandleAsync); } internal void AttachRequestSink(IExternalRequestSink requestSink) => this.RequestSink = Throw.IfNull(requestSink); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/WorkflowHostExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/WorkflowHostExecutor.cs index ab8a499a75..df4af62bb8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/WorkflowHostExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/WorkflowHostExecutor.cs @@ -17,6 +17,7 @@ internal class WorkflowHostExecutor : Executor, IAsyncDisposable { private readonly string _runId; private readonly Workflow _workflow; + private readonly ProtocolDescriptor _workflowProtocol; private readonly object _ownershipToken; private InProcessRunner? _activeRunner; @@ -30,19 +31,26 @@ internal class WorkflowHostExecutor : Executor, IAsyncDisposable [MemberNotNullWhen(true, nameof(_checkpointManager))] private bool WithCheckpointing => this._checkpointManager != null; - public WorkflowHostExecutor(string id, Workflow workflow, string runId, object ownershipToken, ExecutorOptions? options = null) : base(id, options) + public WorkflowHostExecutor(string id, Workflow workflow, ProtocolDescriptor workflowProtocol, string runId, object ownershipToken, ExecutorOptions? options = null) : base(id, options) { this._options = options ?? new(); - Throw.IfNull(workflow); + //Throw.IfNull(workflow); this._runId = Throw.IfNull(runId); this._ownershipToken = Throw.IfNull(ownershipToken); this._workflow = Throw.IfNull(workflow); + this._workflowProtocol = Throw.IfNull(workflowProtocol); } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - return routeBuilder.AddCatchAll(this.QueueExternalMessageAsync); + if (this._options.AutoYieldOutputHandlerResultObject) + { + protocolBuilder = protocolBuilder.YieldsOutputTypes(this._workflowProtocol.Yields); + } + + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddCatchAll(this.QueueExternalMessageAsync)) + .SendsMessageTypes(this._workflowProtocol.Yields); } private async ValueTask QueueExternalMessageAsync(PortableValue portableValue, IWorkflowContext context, CancellationToken cancellationToken) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs index 234958a98a..3ed23cc019 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs @@ -3,6 +3,7 @@ #pragma warning disable CS0618 // Type or member is obsolete - Internal use of obsolete types for backward compatibility using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Reflection; @@ -136,7 +137,7 @@ await context.InvokeWithStateAsync(invocation, } /// - protected ValueTask ResetAsync() + protected virtual ValueTask ResetAsync() { this._stateCache = this._initialStateFactory(); @@ -153,13 +154,25 @@ protected ValueTask ResetAsync() /// A unique identifier for the executor. /// A factory to initialize the state value to be used by the executor. /// Configuration options for the executor. If null, default options will be used. +/// Message types sent by the handler. Defaults to empty, and will filter out non-matching messages. +/// Message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. -public abstract class StatefulExecutor(string id, Func initialStateFactory, StatefulExecutorOptions? options = null, bool declareCrossRunShareable = false) +public abstract class StatefulExecutor(string id, + Func initialStateFactory, + StatefulExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, + bool declareCrossRunShareable = false) : StatefulExecutor(id, initialStateFactory, options, declareCrossRunShareable), IMessageHandler { /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.HandleAsync); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + protocolBuilder.RouteBuilder.AddHandler(this.HandleAsync); + + return protocolBuilder.SendsMessageTypes(sentMessageTypes ?? []) + .YieldsOutputTypes(outputTypes ?? []); + } /// public abstract ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken = default); @@ -175,13 +188,35 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => /// A unique identifier for the executor. /// A factory to initialize the state value to be used by the executor. /// Configuration options for the executor. If null, default options will be used. +/// Message types sent by the handler. Defaults to empty, and will filter out non-matching messages. +/// Message types yielded as output by the handler. Defaults to empty. /// Declare that this executor may be used simultaneously by multiple runs safely. -public abstract class StatefulExecutor(string id, Func initialStateFactory, StatefulExecutorOptions? options = null, bool declareCrossRunShareable = false) +public abstract class StatefulExecutor(string id, + Func initialStateFactory, + StatefulExecutorOptions? options = null, + IEnumerable? sentMessageTypes = null, + IEnumerable? outputTypes = null, + bool declareCrossRunShareable = false) : StatefulExecutor(id, initialStateFactory, options, declareCrossRunShareable), IMessageHandler + where TOutput : notnull { /// - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.HandleAsync); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + protocolBuilder.RouteBuilder.AddHandler(this.HandleAsync); + + if (this.Options.AutoSendMessageHandlerResultObject) + { + protocolBuilder.SendsMessage(); + } + + if (this.Options.AutoYieldOutputHandlerResultObject) + { + protocolBuilder.YieldsOutput(); + } + + return protocolBuilder.SendsMessageTypes(sentMessageTypes ?? []).YieldsOutputTypes(outputTypes ?? []); + } /// public abstract ValueTask HandleAsync(TInput message, IWorkflowContext context, CancellationToken cancellationToken = default); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/SubworkflowBinding.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/SubworkflowBinding.cs index 389aa19afc..8d67e62e4a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/SubworkflowBinding.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/SubworkflowBinding.cs @@ -27,9 +27,11 @@ private static Func> CreateWorkflowExecutorFactory(W return InitHostExecutorAsync; - ValueTask InitHostExecutorAsync(string runId) + async ValueTask InitHostExecutorAsync(string runId) { - return new(new WorkflowHostExecutor(id, workflow, runId, ownershipToken, options)); + ProtocolDescriptor workflowProtocol = await workflow.DescribeProtocolAsync().ConfigureAwait(false); + + return new WorkflowHostExecutor(id, workflow, workflowProtocol, runId, ownershipToken, options); } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs index 6b26a403cf..eff1cfb9a3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs @@ -218,8 +218,14 @@ public async ValueTask DescribeProtocolAsync(CancellationTok ExecutorBinding startExecutorRegistration = this.ExecutorBindings[this.StartExecutorId]; Executor startExecutor = await startExecutorRegistration.CreateInstanceAsync(string.Empty) .ConfigureAwait(false); - startExecutor.Configure(new NoOpExternalRequestContext()); + startExecutor.AttachRequestContext(new NoOpExternalRequestContext()); - return startExecutor.DescribeProtocol(); + ProtocolDescriptor inputProtocol = startExecutor.DescribeProtocol(); + IEnumerable> outputExecutorTasks = this.OutputExecutors.Select(executorId => this.ExecutorBindings[executorId].CreateInstanceAsync(string.Empty).AsTask()); + + Executor[] outputExecutors = await Task.WhenAll(outputExecutorTasks).ConfigureAwait(false); + IEnumerable yieldedTypes = outputExecutors.SelectMany(executor => executor.DescribeProtocol().Yields); + + return new(inputProtocol.Accepts, yieldedTypes, [], inputProtocol.AcceptsAll); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.DevUI.UnitTests/DevUIIntegrationTests.cs b/dotnet/tests/Microsoft.Agents.AI.DevUI.UnitTests/DevUIIntegrationTests.cs index b8512a856e..d39839297e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DevUI.UnitTests/DevUIIntegrationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DevUI.UnitTests/DevUIIntegrationTests.cs @@ -17,9 +17,9 @@ public class DevUIIntegrationTests { private sealed class NoOpExecutor(string id) : Executor(id) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler( - (msg, ctx) => ctx.SendMessageAsync(msg)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler((msg, ctx) => ctx.SendMessageAsync(msg))); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/WorkflowActionExecutorTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/WorkflowActionExecutorTest.cs index cff93904b8..7c6351fdb6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/WorkflowActionExecutorTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/WorkflowActionExecutorTest.cs @@ -113,6 +113,7 @@ protected static TAction AssignParent(DialogAction.Builder actionBuilde internal sealed class TestWorkflowExecutor() : Executor("test_workflow") { + [SendsMessage(typeof(ActionExecutorResult))] public override async ValueTask HandleAsync(WorkflowFormulaState message, IWorkflowContext context, CancellationToken cancellationToken) => await context.SendResultMessageAsync(this.Id, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/ExecutorRouteGeneratorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/ExecutorRouteGeneratorTests.cs index c48ba9ffdf..d2160486cc 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/ExecutorRouteGeneratorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/ExecutorRouteGeneratorTests.cs @@ -38,9 +38,9 @@ private void HandleMessage(string message, IWorkflowContext context) result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)"); - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + var generated = result.RunResult.GeneratedTrees[0]; + + generated.Should().AddHandler("this.HandleMessage", "string"); } [Fact] @@ -205,9 +205,9 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("protected override ISet ConfigureYieldTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.OutputMessage))"); + var generated = result.RunResult.GeneratedTrees[0]; + + generated.Should().RegisterYieldedOutputType("global::TestNamespace.OutputMessage"); } [Fact] @@ -236,9 +236,8 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("protected override ISet ConfigureSentTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.SendMessage))"); + var generated = result.RunResult.GeneratedTrees[0]; + generated.Should().RegisterSentMessageType("global::TestNamespace.SendMessage"); } [Fact] @@ -268,9 +267,8 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("protected override ISet ConfigureSentTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.BroadcastMessage))"); + var generated = result.RunResult.GeneratedTrees[0]; + generated.Should().RegisterSentMessageType("global::TestNamespace.BroadcastMessage"); } [Fact] @@ -300,9 +298,8 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("protected override ISet ConfigureYieldTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.YieldedMessage))"); + var generated = result.RunResult.GeneratedTrees[0]; + generated.Should().RegisterYieldedOutputType("global::TestNamespace.YieldedMessage"); } #endregion @@ -336,20 +333,10 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Verify partial declarations are present - generated.Should().Contain("partial class OuterClass"); - generated.Should().Contain("partial class TestExecutor"); - - // Verify proper nesting structure with braces - // The outer class should open before the inner class - var outerIndex = generated.IndexOf("partial class OuterClass", StringComparison.Ordinal); - var innerIndex = generated.IndexOf("partial class TestExecutor", StringComparison.Ordinal); - outerIndex.Should().BeLessThan(innerIndex, "outer class should appear before inner class"); + var generated = result.RunResult.GeneratedTrees[0]; - // Verify handler registration is present - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + generated.Should().HaveHierarchy("OuterClass", "TestExecutor") + .And.AddHandler("this.HandleMessage", "string"); } [Fact] @@ -382,22 +369,10 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Verify all three partial declarations are present in correct order - generated.Should().Contain("partial class Outer"); - generated.Should().Contain("partial class Inner"); - generated.Should().Contain("partial class TestExecutor"); - - var outerIndex = generated.IndexOf("partial class Outer", StringComparison.Ordinal); - var innerIndex = generated.IndexOf("partial class Inner", StringComparison.Ordinal); - var executorIndex = generated.IndexOf("partial class TestExecutor", StringComparison.Ordinal); + var generated = result.RunResult.GeneratedTrees[0]; - outerIndex.Should().BeLessThan(innerIndex, "Outer should appear before Inner"); - innerIndex.Should().BeLessThan(executorIndex, "Inner should appear before TestExecutor"); - - // Verify handler registration - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + generated.Should().HaveHierarchy("Outer", "Inner", "TestExecutor") + .And.AddHandler("this.HandleMessage", "string"); } [Fact] @@ -433,26 +408,10 @@ private void HandleMessage(int message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // All four partial class declarations should be present - generated.Should().Contain("partial class Level1"); - generated.Should().Contain("partial class Level2"); - generated.Should().Contain("partial class Level3"); - generated.Should().Contain("partial class TestExecutor"); - - // Verify correct ordering - var level1Index = generated.IndexOf("partial class Level1", StringComparison.Ordinal); - var level2Index = generated.IndexOf("partial class Level2", StringComparison.Ordinal); - var level3Index = generated.IndexOf("partial class Level3", StringComparison.Ordinal); - var executorIndex = generated.IndexOf("partial class TestExecutor", StringComparison.Ordinal); + var generated = result.RunResult.GeneratedTrees[0]; - level1Index.Should().BeLessThan(level2Index); - level2Index.Should().BeLessThan(level3Index); - level3Index.Should().BeLessThan(executorIndex); - - // Verify handler registration - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + generated.Should().HaveHierarchy("Level1", "Level2", "Level3", "TestExecutor") + .And.AddHandler("this.HandleMessage", "int"); } [Fact] @@ -480,15 +439,11 @@ private void HandleMessage(string message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Should not contain namespace declaration - generated.Should().NotContain("namespace "); + var generated = result.RunResult.GeneratedTrees[0]; - // Should still have proper partial hierarchy - generated.Should().Contain("partial class OuterClass"); - generated.Should().Contain("partial class TestExecutor"); - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + generated.Should().NotHaveNamespace() + .And.HaveHierarchy("OuterClass", "TestExecutor") + .And.AddHandler("this.HandleMessage", "string"); } [Fact] @@ -576,7 +531,7 @@ private void HandleMessage(string message, IWorkflowContext context) { } // - 1 for Outer class // - 1 for Inner class // - 1 for TestExecutor class - // - 1 for ConfigureRoutes method + // - 1 for ConfigureProtocol method // = 4 pairs minimum openBraces.Should().BeGreaterThanOrEqualTo(4, "should have braces for all nested classes and method"); } @@ -633,11 +588,11 @@ private ValueTask HandleIntAsync(int message, IWorkflowContext context) result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); + var generated = result.RunResult.GeneratedTrees[0]; // Should have both handlers registered - generated.Should().Contain(".AddHandler(this.HandleString)"); - generated.Should().Contain(".AddHandler(this.HandleIntAsync)"); + generated.Should().AddHandler("this.HandleString", "string") + .And.AddHandler("this.HandleIntAsync", "int"); // Verify the generated code compiles with all three partials combined var compilationErrors = result.OutputCompilation.GetDiagnostics() @@ -688,11 +643,11 @@ private void HandleFromFile2(int message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); + var generated = result.RunResult.GeneratedTrees[0]; // Both handlers from different files should be registered - generated.Should().Contain(".AddHandler(this.HandleFromFile1)"); - generated.Should().Contain(".AddHandler(this.HandleFromFile2)"); + generated.Should().AddHandler("this.HandleFromFile1", "string") + .And.AddHandler("this.HandleFromFile2", "int"); } [Fact] @@ -739,29 +694,13 @@ private void HandleFromFile2(int message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Verify ConfigureSentTypes override - var sendsStart = generated.IndexOf("protected override ISet ConfigureSentTypes()", StringComparison.Ordinal); - sendsStart.Should().NotBe(-1, "should generate ConfigureSentTypes override"); - - var sendsEnd = generated.IndexOf("}", sendsStart, StringComparison.Ordinal); - sendsEnd.Should().NotBe(-1, "should close ConfigureSentTypes override"); + var generated = result.RunResult.GeneratedTrees[0]; - generated.Substring(sendsStart, sendsEnd - sendsStart).Should().ContainAll( - "types.Add(typeof(string));", - "types.Add(typeof(int));"); - - // Verify ConfigureYieldTypes override - var yieldsStart = generated.IndexOf("protected override ISet ConfigureYieldTypes()", StringComparison.Ordinal); - yieldsStart.Should().NotBe(-1, "should generate ConfigureYieldTypes override"); - - var yieldsEnd = generated.IndexOf("}", yieldsStart, StringComparison.Ordinal); - yieldsEnd.Should().NotBe(-1, "should close ConfigureYieldTypes override"); - - generated.Substring(yieldsStart, yieldsEnd - yieldsStart).Should().ContainAll( - "types.Add(typeof(string));", - "types.Add(typeof(int));"); + // Verify SendsMessage and YieldsOutput from both partials are combined correctly + generated.Should().RegisterSentMessageType("string") + .And.RegisterSentMessageType("int") + .And.RegisterYieldedOutputType("string") + .And.RegisterYieldedOutputType("string"); } #endregion @@ -896,7 +835,7 @@ private void HandleMessage(string message, string notContext) { } #region No Generation Tests [Fact] - public void ClassWithManualConfigureRoutes_DoesNotGenerate() + public void ClassWithManualConfigureProtocol_DoesNotGenerate() { var source = """ using System.Threading; @@ -909,9 +848,9 @@ public partial class TestExecutor : Executor { public TestExecutor() : base("test") { } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - return routeBuilder; + return protocolBuilder; } [MessageHandler] @@ -953,130 +892,6 @@ private void SomeOtherMethod(string message, IWorkflowContext context) { } #region Protocol-Only Generation Tests - [Fact] - public void ProtocolOnly_SendsMessage_WithManualRoutes_GeneratesConfigureSentTypes() - { - var source = """ - using System; - using System.Threading; - using System.Threading.Tasks; - using Microsoft.Agents.AI.Workflows; - - namespace TestNamespace; - - public class BroadcastMessage { } - - [SendsMessage(typeof(BroadcastMessage))] - public partial class TestExecutor : Executor - { - public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } - } - """; - - var result = GeneratorTestHelper.RunGenerator(source); - - result.RunResult.GeneratedTrees.Should().HaveCount(1); - result.RunResult.Diagnostics.Should().BeEmpty(); - - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Should NOT generate ConfigureRoutes (user has manual implementation) - generated.Should().NotContain("protected override RouteBuilder ConfigureRoutes"); - - // Should generate ConfigureSentTypes - generated.Should().Contain("protected override ISet ConfigureSentTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.BroadcastMessage))"); - } - - [Fact] - public void ProtocolOnly_YieldsOutput_WithManualRoutes_GeneratesConfigureYieldTypes() - { - var source = """ - using System; - using System.Threading; - using System.Threading.Tasks; - using Microsoft.Agents.AI.Workflows; - - namespace TestNamespace; - - public class OutputMessage { } - - [YieldsOutput(typeof(OutputMessage))] - public partial class TestExecutor : Executor - { - public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } - } - """; - - var result = GeneratorTestHelper.RunGenerator(source); - - result.RunResult.GeneratedTrees.Should().HaveCount(1); - result.RunResult.Diagnostics.Should().BeEmpty(); - - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Should NOT generate ConfigureRoutes (user has manual implementation) - generated.Should().NotContain("protected override RouteBuilder ConfigureRoutes"); - - // Should generate ConfigureYieldTypes - generated.Should().Contain("protected override ISet ConfigureYieldTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.OutputMessage))"); - } - - [Fact] - public void ProtocolOnly_BothAttributes_WithManualRoutes_GeneratesBothOverrides() - { - var source = """ - using System; - using System.Threading; - using System.Threading.Tasks; - using Microsoft.Agents.AI.Workflows; - - namespace TestNamespace; - - public class SendMessage { } - public class YieldMessage { } - - [SendsMessage(typeof(SendMessage))] - [YieldsOutput(typeof(YieldMessage))] - public partial class TestExecutor : Executor - { - public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } - } - """; - - var result = GeneratorTestHelper.RunGenerator(source); - - result.RunResult.GeneratedTrees.Should().HaveCount(1); - result.RunResult.Diagnostics.Should().BeEmpty(); - - var generated = result.RunResult.GeneratedTrees[0].ToString(); - - // Should NOT generate ConfigureRoutes - generated.Should().NotContain("protected override RouteBuilder ConfigureRoutes"); - - // Should generate both protocol overrides - generated.Should().Contain("protected override ISet ConfigureSentTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.SendMessage))"); - generated.Should().Contain("protected override ISet ConfigureYieldTypes()"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.YieldMessage))"); - } - [Fact] public void ProtocolOnly_MultipleSendsMessageAttributes_GeneratesAllTypes() { @@ -1098,11 +913,6 @@ public class MessageC { } public partial class TestExecutor : Executor { public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } } """; @@ -1110,10 +920,11 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.MessageA))"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.MessageB))"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.MessageC))"); + var generated = result.RunResult.GeneratedTrees[0]; + + generated.Should().RegisterSentMessageType("global::TestNamespace.MessageA") + .And.RegisterSentMessageType("global::TestNamespace.MessageB") + .And.RegisterSentMessageType("global::TestNamespace.MessageC"); } [Fact] @@ -1133,11 +944,6 @@ public class BroadcastMessage { } public class TestExecutor : Executor { public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } } """; @@ -1193,11 +999,6 @@ public partial class OuterClass public partial class TestExecutor : Executor { public TestExecutor() : base("test") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } } } """; @@ -1207,14 +1008,12 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) result.RunResult.GeneratedTrees.Should().HaveCount(1); result.RunResult.Diagnostics.Should().BeEmpty(); - var generated = result.RunResult.GeneratedTrees[0].ToString(); + var generated = result.RunResult.GeneratedTrees[0]; // Verify partial declarations are present - generated.Should().Contain("partial class OuterClass"); - generated.Should().Contain("partial class TestExecutor"); - + generated.Should().HaveHierarchy("OuterClass", "TestExecutor") // Verify protocol types are generated - generated.Should().Contain("types.Add(typeof(global::TestNamespace.BroadcastMessage))"); + .And.RegisterSentMessageType("global::TestNamespace.BroadcastMessage"); } [Fact] @@ -1234,11 +1033,6 @@ public class BroadcastMessage { } public partial class GenericExecutor : Executor where T : class { public GenericExecutor() : base("generic") { } - - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder; - } } """; @@ -1246,9 +1040,10 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("partial class GenericExecutor"); - generated.Should().Contain("types.Add(typeof(global::TestNamespace.BroadcastMessage))"); + var generated = result.RunResult.GeneratedTrees[0]; + + generated.Should().HaveHierarchy("GenericExecutor") + .And.RegisterSentMessageType("global::TestNamespace.BroadcastMessage"); } #endregion @@ -1278,9 +1073,10 @@ private void HandleMessage(T message, IWorkflowContext context) { } result.RunResult.GeneratedTrees.Should().HaveCount(1); - var generated = result.RunResult.GeneratedTrees[0].ToString(); - generated.Should().Contain("partial class GenericExecutor"); - generated.Should().Contain(".AddHandler(this.HandleMessage)"); + var generated = result.RunResult.GeneratedTrees[0]; + + generated.Should().HaveHierarchy("GenericExecutor") + .And.AddHandler("this.HandleMessage", "T"); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/SyntaxTreeFluentExtensions.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/SyntaxTreeFluentExtensions.cs new file mode 100644 index 0000000000..3da1e7d891 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Generators.UnitTests/SyntaxTreeFluentExtensions.cs @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using FluentAssertions; +using FluentAssertions.Execution; +using FluentAssertions.Primitives; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Agents.AI.Workflows.Generators.UnitTests; + +internal sealed class SyntaxTreeAssertions : ObjectAssertions +{ + private readonly string _syntaxString; + + public SyntaxTreeAssertions(SyntaxTree instance, AssertionChain assertionChain) : base(instance, assertionChain) + { + this._syntaxString = instance.ToString(); + } + + public AndConstraint AddHandler(string handlerName) + { + string expectedRegistration = $".AddHandler({handlerName})"; + + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains(expectedRegistration)) + .BecauseOf($"expected handler {handlerName} to be registered") + .FailWith("Expected {context} to contain handler registration {0}{reason}, but it was not found. Actual syntax: {1}", + expectedRegistration, this._syntaxString); + + return new(this); + } + + public AndConstraint AddHandler(string handlerName, string inTypeParam) + { + string expectedRegistration = $".AddHandler<{inTypeParam}>({handlerName})"; + + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains(expectedRegistration)) + .BecauseOf($"expected handler {handlerName} to be registered") + .FailWith("Expected {context} to contain handler registration {0}{reason}, but it was not found. Actual syntax: {1}", + expectedRegistration, this._syntaxString); + + return new(this); + } + + public AndConstraint AddHandler(string handlerName, string inTypeParam, string outTypeParam) + { + string expectedRegistration = $".AddHandler<{inTypeParam},{outTypeParam}>({handlerName})"; + + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains(expectedRegistration)) + .BecauseOf($"expected handler {handlerName} to be registered") + .FailWith("Expected {context} to contain handler registration {0}{reason}, but it was not found. Actual syntax: {1}", + expectedRegistration, this._syntaxString); + + return new(this); + } + + public AndConstraint AddHandler(string handlerName, bool globalQualified = false) + { + Type inType = typeof(TIn); + string inTypeParam = globalQualified ? $"global::{inType.FullName}" : inType.Name; + return this.AddHandler(handlerName, inTypeParam); + } + + public AndConstraint AddHandler(string handlerName, bool globalQualified = false) + { + Type inType = typeof(TIn), outType = typeof(TOut); + string inTypeParam = globalQualified ? $"global::{inType.FullName}" : inType.Name; + string outTypeParam = globalQualified ? $"global::{outType.FullName}" : outType.Name; + return this.AddHandler(handlerName, inTypeParam, outTypeParam); + } + + public AndConstraint HaveNoHandlers() + { + this.CurrentAssertionChain + .ForCondition(!this._syntaxString.Contains(".AddHandler(")) + .BecauseOf("expected no handlers to be registered") + .FailWith("Expected {context} to have no handler registrations{reason}, but found at least one. Actual syntax: {1}", + this._syntaxString); + + return new(this); + } + + public AndConstraint RegisterSentMessageType(string messageTypeParam) + { + string expectedRegistration = $".SendsMessage<{messageTypeParam}>()"; + + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains(expectedRegistration)) + .BecauseOf($"expected message type {messageTypeParam} to be registered") + .FailWith("Expected {context} to contain message type registration {0}{reason}, but it was not found. Actual syntax: {1}", + expectedRegistration, this._syntaxString); + + return new(this); + } + + public AndConstraint RegisterSentMessageType(bool globalQualified = true) + { + Type messageType = typeof(TMessage); + string messageTypeParam = globalQualified ? $"global::{messageType.FullName}" : messageType.Name; + return this.RegisterSentMessageType(messageTypeParam); + } + + public AndConstraint NotRegisterSentMessageTypes() + { + this.CurrentAssertionChain + .ForCondition(!this._syntaxString.Contains(".SendsMessage<")) + .BecauseOf("expected no message types to be registered") + .FailWith("Expected {context} to have no message type registrations{reason}, but found at least one. Actual syntax: {1}", + this._syntaxString); + + return new(this); + } + + public AndConstraint RegisterYieldedOutputType(string outputTypeParam) + { + string expectedRegistration = $".YieldsOutput<{outputTypeParam}>()"; + + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains(expectedRegistration)) + .BecauseOf($"expected output type {outputTypeParam} to be registered") + .FailWith("Expected {context} to contain output type registration {0}{reason}, but it was not found. Actual syntax: {1}", + expectedRegistration, this._syntaxString); + + return new(this); + } + + public AndConstraint RegisterYieldedOutputType(bool globalQualified = true) + { + Type outputType = typeof(TOutput); + string outputTypeParam = globalQualified ? $"global::{outputType.FullName}" : outputType.Name; + return this.RegisterYieldedOutputType(outputTypeParam); + } + + public AndConstraint NotRegisterYieldedOutputTypes() + { + this.CurrentAssertionChain + .ForCondition(!this._syntaxString.Contains(".YieldsOutput<")) + .BecauseOf("expected no output types to be registered") + .FailWith("Expected {context} to have no output type registrations{reason}, but found at least one. Actual syntax: {1}", + this._syntaxString); + + return new(this); + } + + private AndConstraint ContainPartialDeclaration(int level, int index, string className) + { + this.CurrentAssertionChain + .ForCondition(index > 0) + .BecauseOf($"expected \"partial class {className}\" at nesting level {level}") + .FailWith("Expected {context} to contain \"partial class {0}\" at nesting level {1}{reason}, but it was not found. Actual syntax: {2}", + className, level, this._syntaxString); + + return new(this); + } + + private AndConstraint DeclarePartialsInCorrectOrder(int prevIndex, int currIndex, string prevClass, string currClass) + { + this.CurrentAssertionChain + .ForCondition(prevIndex < currIndex) + .BecauseOf($"expected \"partial class {prevClass}\" before \"partial class {currClass}\"") + .FailWith("Expected {context} to have \"partial class {0}\" before \"partial class {1}\"{reason}, but the order was incorrect. Actual syntax: {2}", + prevClass, currClass, this._syntaxString); + + return new(this); + } + + public AndConstraint HaveHierarchy(params string[] expectedNesting) + { + if (expectedNesting.Length == 0) + { + return new AndConstraint(this); + } + + int[] indicies = new int[expectedNesting.Length]; + + for (int i = 0; i < expectedNesting.Length; i++) + { + indicies[i] = this._syntaxString.IndexOf($"partial class {expectedNesting[i]}", StringComparison.Ordinal); + } + + // Verify partial declarations are present + AndConstraint runningResult = this.ContainPartialDeclaration(0, indicies[0], expectedNesting[0]); + for (int i = 1; i < expectedNesting.Length; i++) + { + runningResult = runningResult.And.ContainPartialDeclaration(i, indicies[i], expectedNesting[i]) + .And.DeclarePartialsInCorrectOrder(indicies[i - 1], indicies[i], expectedNesting[i - 1], expectedNesting[i]); + } + + return runningResult; + } + + public AndConstraint HaveNamespace() + { + this.CurrentAssertionChain + .ForCondition(this._syntaxString.Contains("namespace ")) + .BecauseOf("expected namespace declaration") + .FailWith("Expected {context} to contain a namespace declaration{reason}, but it was found. Actual syntax: {0}", + this._syntaxString); + + return new(this); + } + + public AndConstraint NotHaveNamespace() + { + this.CurrentAssertionChain + .ForCondition(!this._syntaxString.Contains("namespace ")) + .BecauseOf("expected no namespace declaration") + .FailWith("Expected {context} to not contain a namespace declaration{reason}, but it was found. Actual syntax: {0}", + this._syntaxString); + + return new(this); + } +} + +internal static class SyntaxTreeFluentExtensions +{ + public static SyntaxTreeAssertions Should(this SyntaxTree syntaxTree) => new(syntaxTree, AssertionChain.GetOrCreate()); +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index eb017f07a3..6a5043f055 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -405,6 +405,10 @@ public async Task BuildGroupChat_AgentsRunInOrderAsync(int maxIterations) output = e; break; } + else if (evt is WorkflowErrorEvent errorEvent) + { + Assert.Fail($"Workflow execution failed with error: {errorEvent.Exception}"); + } } return (sb.ToString(), output?.As>()); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs index cbeb13c86b..2ddc0d1eea 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs @@ -12,22 +12,25 @@ internal sealed class DynamicPortsExecutor(string id, param public ConcurrentDictionary> ReceivedResponses { get; } = new(); - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - foreach (string portId in ports) + return protocolBuilder.ConfigureRoutes(ConfigureRoutes); + + void ConfigureRoutes(RouteBuilder routeBuilder) { - routeBuilder = routeBuilder - .AddPortHandler(portId, - (response, context, cancellationToken) => - { - this.ReceivedResponses.GetOrAdd(portId, _ => new()).Enqueue(response); - return default; - }, out PortBinding? binding); - - this.PortBindings[portId] = binding; + foreach (string portId in ports) + { + routeBuilder = routeBuilder + .AddPortHandler(portId, + (response, context, cancellationToken) => + { + this.ReceivedResponses.GetOrAdd(portId, _ => new()).Enqueue(response); + return default; + }, out PortBinding? binding); + + this.PortBindings[portId] = binding; + } } - - return routeBuilder; } public ValueTask PostRequestAsync(string portId, TRequest request, TestRunContext testContext, string? requestId = null) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs index 568bab8120..faccbef803 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs @@ -21,7 +21,7 @@ private sealed class RequestPortTestContext public RequestPortTestContext() { this.Executor = new(ExecutorId, PortId); - this.Executor.Configure(this.ExternalRequestContext); + this.Executor.AttachRequestContext(this.ExternalRequestContext); } public TestRunContext RunContext { get; } = new(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs index ef065db4d5..70210fac41 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using FluentAssertions; using Microsoft.Agents.AI.Workflows.Execution; @@ -39,7 +40,7 @@ private static async Task CreateAndRunDirectedEdgeTestAsync(bool? conditionMatch MessageEnvelope envelope = new(MessageVariant1, "executor1", targetId: targetId); - DeliveryMapping? mapping = await runner.ChaseEdgeAsync(envelope, stepTracer: null); + DeliveryMapping? mapping = await runner.ChaseEdgeAsync(envelope, stepTracer: null, CancellationToken.None); bool expectMessage = (!conditionMatch.HasValue || conditionMatch.Value) && (!targetMatch.HasValue || targetMatch.Value); @@ -101,7 +102,7 @@ private static async Task CreateAndRunFanOutEdgeTestAsync(bool? assignerSelectsE MessageEnvelope envelope = new("test", "executor1", targetId: targetId); - DeliveryMapping? mapping = await runner.ChaseEdgeAsync(envelope, stepTracer: null); + DeliveryMapping? mapping = await runner.ChaseEdgeAsync(envelope, stepTracer: null, CancellationToken.None); bool expectForwardFrom2 = (!assignerSelectsEmpty.HasValue || !assignerSelectsEmpty.Value) && (!targetMatch.HasValue || targetMatch.Value); @@ -178,22 +179,22 @@ async ValueTask RunIterationAsync() { //await runner.ChaseAsync("executor1", new("part1"), state, tracer: null); //MessageDeliveryValidation.CheckForwarded(runContext.QueuedMessages); - DeliveryMapping? mapping = await runner.ChaseEdgeAsync(new("part1", "executor1"), stepTracer: null); + DeliveryMapping? mapping = await runner.ChaseEdgeAsync(new("part1", "executor1"), stepTracer: null, CancellationToken.None); mapping.Should().BeNull(); //await runner.ChaseAsync("executor2", new("part-for-1", targetId: "executor1"), state, tracer: null); //MessageDeliveryValidation.CheckForwarded(runContext.QueuedMessages); - mapping = await runner.ChaseEdgeAsync(new("part-for-1", "executor2", targetId: "executor1"), stepTracer: null); + mapping = await runner.ChaseEdgeAsync(new("part-for-1", "executor2", targetId: "executor1"), stepTracer: null, CancellationToken.None); mapping.Should().BeNull(); //await runner.ChaseAsync("executor1", new("part2", targetId: "executor3"), state, tracer: null); //MessageDeliveryValidation.CheckForwarded(runContext.QueuedMessages); - mapping = await runner.ChaseEdgeAsync(new("part2", "executor1", targetId: "executor3"), stepTracer: null); + mapping = await runner.ChaseEdgeAsync(new("part2", "executor1", targetId: "executor3"), stepTracer: null, CancellationToken.None); mapping.Should().BeNull(); //await runner.ChaseAsync("executor2", new("final part"), state, tracer: null); //MessageDeliveryValidation.CheckForwarded(runContext.QueuedMessages, ("executor3", ["part1", "part2", "final part"])); - mapping = await runner.ChaseEdgeAsync(new("final part", "executor2"), stepTracer: null); + mapping = await runner.ChaseEdgeAsync(new("final part", "executor2"), stepTracer: null, CancellationToken.None); mapping.Should().NotBeNull(); mapping.CheckDeliveries(["executor3"], ["part1", "part2", "final part"]); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ForwardMessageExecutor.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ForwardMessageExecutor.cs index 85f7a4491e..d78f12a67a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ForwardMessageExecutor.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ForwardMessageExecutor.cs @@ -4,6 +4,10 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; internal sealed class ForwardMessageExecutor(string id) : Executor(id) where TMessage : notnull { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler((message, ctx) => ctx.SendMessageAsync(message)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + protocolBuilder.RouteBuilder.AddHandler((message, ctx) => ctx.SendMessageAsync(message)); + + return protocolBuilder.SendsMessage(); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessStateTests.cs index 0ecd6bfac1..8037cbabbb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessStateTests.cs @@ -7,7 +7,7 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; -public class InProcessStateTests +public partial class InProcessStateTests { private sealed class TurnToken { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Microsoft.Agents.AI.Workflows.UnitTests.csproj b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Microsoft.Agents.AI.Workflows.UnitTests.csproj index 60dac38ecd..58979a4f1b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Microsoft.Agents.AI.Workflows.UnitTests.csproj +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Microsoft.Agents.AI.Workflows.UnitTests.csproj @@ -7,6 +7,10 @@ + diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs index 391b6a3371..26983c930e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs @@ -19,7 +19,7 @@ public class RepresentationTests { private sealed class TestExecutor() : Executor("TestExecutor") { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => routeBuilder; + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => protocolBuilder; } private sealed class TestAgent : AIAgent diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01_Simple_Workflow_Sequential.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01_Simple_Workflow_Sequential.cs index 5af52874f6..9935b224bf 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01_Simple_Workflow_Sequential.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01_Simple_Workflow_Sequential.cs @@ -7,7 +7,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.Agents.AI.Workflows.Reflection; namespace Microsoft.Agents.AI.Workflows.Sample; @@ -29,6 +28,7 @@ public static Workflow WorkflowInstance public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvironment environment) { + // TODO: Potentially normalize terminology viz Agent.RunStreamingAsync StreamingRun run = await environment.StreamAsync(WorkflowInstance, input: "Hello, World!").ConfigureAwait(false); await foreach (WorkflowEvent evt in run.WatchStreamAsync().ConfigureAwait(false)) @@ -41,14 +41,26 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi } } -internal sealed class UppercaseExecutor() : ReflectingExecutor("UppercaseExecutor", declareCrossRunShareable: true), IMessageHandler +internal sealed class UppercaseExecutor() : Executor(nameof(UppercaseExecutor), declareCrossRunShareable: true) { - public async ValueTask HandleAsync(string message, IWorkflowContext context, CancellationToken cancellationToken = default) => + public override async ValueTask HandleAsync(string message, IWorkflowContext context, CancellationToken cancellationToken = default) => message.ToUpperInvariant(); } -internal sealed class ReverseTextExecutor() : ReflectingExecutor("ReverseTextExecutor", declareCrossRunShareable: true), IMessageHandler +//internal partial sealed class UppercaseExecutorEx() : Executor(nameof(UppercaseExecutorEx), declareCrossRunShareable: true) +//{ +// [MessageHandler(Send = [typeof(string)])] +// public async ValueTask MyHandlerMethod(string message, IWorkflowContext context, CancellationToken cancellationToken = default) => +// message.ToUpperInvariant(); +//} + +internal sealed class ReverseTextExecutor() : Executor("ReverseTextExecutor", declareCrossRunShareable: true) { + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler(this.HandleAsync)); + } + public async ValueTask HandleAsync(string message, IWorkflowContext context, CancellationToken cancellationToken = default) { string result = string.Concat(message.Reverse()); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01a_Simple_Workflow_Sequential.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01a_Simple_Workflow_Sequential.cs index ffa798126e..1e3677e9f2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01a_Simple_Workflow_Sequential.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/01a_Simple_Workflow_Sequential.cs @@ -8,6 +8,7 @@ namespace Microsoft.Agents.AI.Workflows.Sample; internal static class Step1aEntryPoint { + // TODO: Maybe env.CreateRunAsync? public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvironment environment) { Run run = await environment.RunAsync(WorkflowInstance, "Hello, World!").ConfigureAwait(false); @@ -23,3 +24,75 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi } } } + +/* +internal class Example +{ + const string GreeterAgent = nameof(GreeterAgent); + const string TaskAgent1 = nameof(TaskAgent1); + const string ErrorAgent = nameof(ErrorAgent); + public WorkflowBuilder CreateTemplate() + { + return new WorkflowBuilder(GreeterAgent) + .AddEdge(GreeterAgent, TaskAgent1) + .AddEdge(TaskAgent1, ErrorAgent, condition: (object? message) => message is Exception) + .WithOutputFrom(TaskAgent1); + } + public void BuildWorkflow() + { + WorkflowBuilder template = CreateTemplate(); + + AIAgent myGreeter = new AIAgentBuilder().Build(); // with id = "GreeterAgent" + AIAgent myTaskAgent = new AIAgentBuilder().Build(); // with id = "TaskAgent1" + AIAgent myErrorAgent = new AIAgentBuilder().Build(); // with id = "ErrorAgent" + + template.BindExecutor(myGreeter).BindExecutor(myTaskAgent).BindExecutor(myErrorAgent); + + WorkflowBuilder directBuilder = new WorkflowBuilder(myGreeter) + .AddEdge(myGreeter, myTaskAgent) + .AddEdge(myTaskAgent, myErrorAgent, condition: (object? message) => message is Exception) + .WithOutputFrom(myTaskAgent); + + // TODO: Add Id remapping to BindExecutor() + //ExecutorBinding myRenamedGreeter = myGreeter.BindAsExecutor(IDataDiscoverer: "NewGreeterAgent") + + string executorPlaceholder = "PLACEHODLER"; + WorkflowBuilder builder = new(executorPlaceholder); + + builder.AddEdge(executorPlaceholder, executorPlaceholder); // Direct Edge, Unconditional + builder.AddEdge(executorPlaceholder, executorPlaceholder, + condition: (string? myString) => myString?.Contains("Hello") is true); + + builder.AddFanOutEdge(executorPlaceholder, [executorPlaceholder]); // FanOut Edge, Simple/Unconditional + // ~equivalent to foreach (executor in [executor])... builder.AddEdge(...) + + static IEnumerable targetSelector(object? message, int potentialTargetCount) => + Enumerable.Range(0, potentialTargetCount); + + builder.AddFanOutEdge(executorPlaceholder, [executorPlaceholder], (Func>)targetSelector); + + builder.AddSwitch(executorPlaceholder, + (SwitchBuilder sb) => + sb.AddCase(predicate: (string? s) => s?.Contains("Hello") is true, executorPlaceholder) + .WithDefault(executorPlaceholder)); + // FanIn + builder.AddFanInEdge([executorPlaceholder], executorPlaceholder); // builder.WaitAll([executor], then: executor) + // TODO: FanIn + /* + builder.AddFanInEdge([executor], executor, FanInStrategy); + FanInStrategy: => views incoming messages and decides whether or not to send anything + => potentially aggregates messages / filters them + //* / + + AIAgent myAgent = new AIAgentBuilder().Build(); + var myAgentExecutor = myAgent.BindAsExecutor(); + + Func> myFactory; + ExecutorBinding myExecutor = myFactory.BindAsExecutor("PLACEHOLDER"); // TODO: AsExecutorBinding() + + //builder.AddEdge(myAgent, myAgent); + builder.BindExecutor(myExecutor); // TODO: Lots of "Bind" - better name? + // Needed at all? + } +} +// */ diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs index d44b0babcd..79135e4906 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs @@ -46,6 +46,9 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecu case ExecutorCompletedEvent executorCompletedEvt: writer.WriteLine($"'{executorCompletedEvt.ExecutorId}: {executorCompletedEvt.Data}"); break; + case WorkflowErrorEvent errorEvent: + Assert.Fail($"Workflow failed with error: {errorEvent.Exception}"); + break; } } @@ -60,10 +63,11 @@ public async ValueTask HandleAsync(string message, IWorkflowContext contex spamKeywords.Any(keyword => message.IndexOf(keyword, StringComparison.OrdinalIgnoreCase) >= 0); } -internal sealed class RespondToMessageExecutor(string id) : ReflectingExecutor(id, declareCrossRunShareable: true), IMessageHandler +internal sealed partial class RespondToMessageExecutor(string id) : Executor(id, declareCrossRunShareable: true), IMessageHandler { public const string ActionResult = "Message processed successfully."; + [MessageHandler(Yield = [typeof(string)])] public async ValueTask HandleAsync(bool message, IWorkflowContext context, CancellationToken cancellationToken = default) { if (message) @@ -79,10 +83,11 @@ await context.YieldOutputAsync(ActionResult, cancellationToken) } } -internal sealed class RemoveSpamExecutor(string id) : ReflectingExecutor(id, declareCrossRunShareable: true), IMessageHandler +internal sealed partial class RemoveSpamExecutor(string id) : Executor(id, declareCrossRunShareable: true), IMessageHandler { public const string ActionResult = "Spam message removed."; + [MessageHandler(Yield = [typeof(string)])] public async ValueTask HandleAsync(bool message, IWorkflowContext context, CancellationToken cancellationToken = default) { if (!message) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs index 61e063df32..7dadad1847 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs @@ -6,7 +6,6 @@ using System.IO; using System.Threading; using System.Threading.Tasks; -using Microsoft.Agents.AI.Workflows.Reflection; namespace Microsoft.Agents.AI.Workflows.Sample; @@ -68,7 +67,8 @@ internal enum NumberSignal Matched } -internal sealed class GuessNumberExecutor : ReflectingExecutor, IMessageHandler +[YieldsOutput(typeof(string))] +internal sealed partial class GuessNumberExecutor : Executor { private readonly int _initialLowerBound; private readonly int _initialUpperBound; @@ -84,6 +84,7 @@ internal sealed class GuessNumberExecutor : ReflectingExecutor HandleAsync(NumberSignal message, IWorkflowContext context, CancellationToken cancellationToken = default) { NumberBounds bounds = await context.ReadStateAsync(nameof(NumberBounds), cancellationToken: cancellationToken) @@ -111,7 +112,8 @@ await context.YieldOutputAsync($"Guessed the number: {bounds.CurrGuess}", cancel } } -internal sealed class JudgeExecutor : ReflectingExecutor, IMessageHandler +[YieldsOutput(typeof(TryCount))] +internal sealed partial class JudgeExecutor : Executor { private readonly int _targetNumber; @@ -120,6 +122,7 @@ public JudgeExecutor(string id, int targetNumber) : base(id, declareCrossRunShar this._targetNumber = targetNumber; } + [MessageHandler] public async ValueTask HandleAsync(int message, IWorkflowContext context, CancellationToken cancellationToken = default) { // This works properly because the default when unset is 0, and we increment before use. diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/08_Subworkflow_Simple.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/08_Subworkflow_Simple.cs index 98f46cf551..1e5b4f7a5d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/08_Subworkflow_Simple.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/08_Subworkflow_Simple.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -15,7 +16,7 @@ internal sealed record class TextProcessingResult(string TaskId, string Text, in //internal sealed class AllTasksCompletedEvent(IEnumerable results) : WorkflowEvent(results); -internal static class Step8EntryPoint +internal static partial class Step8EntryPoint { public static List TextsToProcess => [ "Hello world! This is a simple test.", @@ -29,6 +30,7 @@ internal static class Step8EntryPoint public static async ValueTask> RunAsync(TextWriter writer, IWorkflowExecutionEnvironment environment, List textsToProcess) { Func processTextAsyncFunc = ProcessTextAsync; + ExecutorBinding processText = processTextAsyncFunc.BindAsExecutor("TextProcessor", threadsafe: true); Workflow subWorkflow = new WorkflowBuilder(processText).WithOutputFrom(processText).Build(); @@ -46,6 +48,22 @@ public static async ValueTask> RunAsync(TextWriter wr Run workflowRun = await environment.RunAsync(workflow, textsToProcess); RunStatus status = await workflowRun.GetStatusAsync(); + List errors = workflowRun.OutgoingEvents.OfType() + .Select(errorEvent => errorEvent.Exception) + .Where(e => e is not null).ToList(); + if (errors.Count > 0) + { + StringBuilder errorBuilder = new(); + errorBuilder.AppendLine($"Workflow execution failed. ({errors.Count} errors.):"); + + foreach (Exception? error in errors) + { + errorBuilder.Append('\t').AppendLine(error!.ToString()); + } + + Assert.Fail(errorBuilder.ToString()); + } + status.Should().Be(RunStatus.Idle); WorkflowOutputEvent? maybeOutput = workflowRun.OutgoingEvents.OfType() @@ -62,6 +80,7 @@ public static async ValueTask> RunAsync(TextWriter wr return results; } + [YieldsOutput(typeof(TextProcessingResult))] private static ValueTask ProcessTextAsync(TextProcessingRequest request, IWorkflowContext context, CancellationToken cancellationToken = default) { int wordCount = 0; @@ -76,7 +95,7 @@ private static ValueTask ProcessTextAsync(TextProcessingRequest request, IWorkfl return context.YieldOutputAsync(new TextProcessingResult(request.TaskId, request.Text, wordCount, charCount), cancellationToken); } - private sealed class TextProcessingOrchestrator(string id) + private sealed partial class TextProcessingOrchestrator(string id) : StatefulExecutor(id, () => new(), declareCrossRunShareable: false) { internal sealed class State @@ -90,13 +109,8 @@ internal sealed class State public bool CompletePending(string taskId) => this.PendingTaskIds.Remove(taskId); } - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - { - return routeBuilder.AddHandler>(this.StartProcessingAsync) - .AddHandler(this.CollectResultAsync); - } - - private async ValueTask StartProcessingAsync(List texts, IWorkflowContext context, CancellationToken cancellationToken) + [MessageHandler(Send = [typeof(TextProcessingRequest)])] + public async ValueTask StartProcessingAsync(List texts, IWorkflowContext context, CancellationToken cancellationToken) { await this.InvokeWithStateAsync(QueueProcessingTasksAsync, context, cancellationToken: cancellationToken); @@ -112,7 +126,8 @@ private async ValueTask StartProcessingAsync(List texts, IWorkflowContex } } - private async ValueTask CollectResultAsync(TextProcessingResult result, IWorkflowContext context, CancellationToken cancellationToken = default) + [MessageHandler(Yield = [typeof(List)])] + public async ValueTask CollectResultAsync(TextProcessingResult result, IWorkflowContext context, CancellationToken cancellationToken = default) { await this.InvokeWithStateAsync(CollectResultAndCheckCompletionAsync, context, cancellationToken: cancellationToken); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/09_Subworkflow_ExternalRequest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/09_Subworkflow_ExternalRequest.cs index 56c7f0a157..9e16b8a45d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/09_Subworkflow_ExternalRequest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/09_Subworkflow_ExternalRequest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -256,6 +257,22 @@ public static async ValueTask> RunAsync(TextWriter writer, await workflowRun.ResumeAsync(responses: responses).ConfigureAwait(false); runStatus = await workflowRun.GetStatusAsync(); + List errors = workflowRun.OutgoingEvents.OfType() + .Select(errorEvent => errorEvent.Exception) + .Where(e => e is not null).ToList(); + if (errors.Count > 0) + { + StringBuilder errorBuilder = new(); + errorBuilder.AppendLine($"Workflow execution failed. ({errors.Count} errors.):"); + + foreach (Exception? error in errors) + { + errorBuilder.Append('\t').AppendLine(error!.ToString()); + } + + Assert.Fail(errorBuilder.ToString()); + } + runStatus.Should().Be(RunStatus.Idle); results = finishedRequests; @@ -277,18 +294,26 @@ public static async ValueTask> RunAsync(TextWriter writer, internal sealed class ResourceRequestor() : Executor(nameof(ResourceRequestor), declareCrossRunShareable: true) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - return routeBuilder.AddHandler>(this.RequestResourcesAsync) - .AddHandler(InvokeResourceRequestAsync) - .AddHandler(this.HandleResponseAsync) - .AddHandler(this.HandleResponseAsync); - - // For some reason, using a lambda here causes the analyzer to generate a spurious - // VSTHRD110: "Observe the awaitable result of this method call by awaiting it, assigning - // to a variable, or passing it to another method" - ValueTask InvokeResourceRequestAsync(UserRequest request, IWorkflowContext context) - => this.RequestResourcesAsync([request], context); + return protocolBuilder.ConfigureRoutes(ConfigureRoutes) + .SendsMessage() + .SendsMessage() + .YieldsOutput(); + + void ConfigureRoutes(RouteBuilder routeBuilder) + { + routeBuilder.AddHandler>(this.RequestResourcesAsync) + .AddHandler(InvokeResourceRequestAsync) + .AddHandler(this.HandleResponseAsync) + .AddHandler(this.HandleResponseAsync); + + // For some reason, using a lambda here causes the analyzer to generate a spurious + // VSTHRD110: "Observe the awaitable result of this method call by awaiting it, assigning + // to a variable, or passing it to another method" + ValueTask InvokeResourceRequestAsync(UserRequest request, IWorkflowContext context) + => this.RequestResourcesAsync([request], context); + } } private async ValueTask RequestResourcesAsync(List requests, IWorkflowContext context) @@ -332,12 +357,17 @@ private static Dictionary InitializeResourceCache() ["disk"] = 100, }; - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - // Note the disbalance here - we could also handle ExternalResponse here instead, but we would have - // to do the exact same type check on it, so we might as well handle - return routeBuilder.AddHandler(this.UnwrapAndHandleRequestAsync) - .AddHandler(this.CollectResultAsync); + return protocolBuilder.ConfigureRoutes(ConfigureRoutes); + + void ConfigureRoutes(RouteBuilder routeBuilder) + { + // Note the disbalance here - we could also handle ExternalResponse here instead, but we would have + // to do the exact same type check on it, so we might as well handle + routeBuilder.AddHandler(this.UnwrapAndHandleRequestAsync) + .AddHandler(this.CollectResultAsync); + } } private async ValueTask UnwrapAndHandleRequestAsync(ExternalRequest request, IWorkflowContext context, CancellationToken cancellationToken = default) @@ -414,10 +444,17 @@ private static Dictionary InitializePolicyQuotas() ["disk"] = 1000, }; - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - return routeBuilder.AddHandler(this.UnwrapAndHandleRequestAsync) - .AddHandler(this.CollectAndForwardAsync); + return protocolBuilder.ConfigureRoutes(ConfigureRoutes); + + void ConfigureRoutes(RouteBuilder routeBuilder) + { + // Note the disbalance here - we could also handle ExternalResponse here instead, but we would have + // to do the exact same type check on it, so we might as well handle + routeBuilder.AddHandler(this.UnwrapAndHandleRequestAsync) + .AddHandler(this.CollectAndForwardAsync); + } } private async ValueTask UnwrapAndHandleRequestAsync(ExternalRequest request, IWorkflowContext context) @@ -483,17 +520,24 @@ internal sealed class Coordinator() : Executor(nameof(Coordinator), declareCross { private const string StateKey = nameof(StateKey); - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - return routeBuilder.AddHandler>(this.StartAsync) - .AddHandler(InvokeStartAsync) - .AddHandler(this.HandleFinishedRequestAsync); - - // For some reason, using a lambda here causes the analyzer to generate a spurious - // VSTHRD110: "Observe the awaitable result of this method call by awaiting it, assigning - // to a variable, or passing it to another method" - ValueTask InvokeStartAsync(UserRequest request, IWorkflowContext context, CancellationToken cancellationToken) - => this.StartAsync([request], context, cancellationToken); + return protocolBuilder.ConfigureRoutes(ConfigureRoutes) + .SendsMessage() + .YieldsOutput(); + + void ConfigureRoutes(RouteBuilder routeBuilder) + { + routeBuilder.AddHandler>(this.StartAsync) + .AddHandler(InvokeStartAsync) + .AddHandler(this.HandleFinishedRequestAsync); + + // For some reason, using a lambda here causes the analyzer to generate a spurious + // VSTHRD110: "Observe the awaitable result of this method call by awaiting it, assigning + // to a variable, or passing it to another method" + ValueTask InvokeStartAsync(UserRequest request, IWorkflowContext context, CancellationToken cancellationToken) + => this.StartAsync([request], context, cancellationToken); + } } private ValueTask HandleFinishedRequestAsync(RequestFinished finished, IWorkflowContext context, CancellationToken cancellationToken) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/14_Subworkflow_SharedState.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/14_Subworkflow_SharedState.cs index c4219d58a3..77ad3c63c7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/14_Subworkflow_SharedState.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/14_Subworkflow_SharedState.cs @@ -11,7 +11,7 @@ namespace Microsoft.Agents.AI.Workflows.Sample; /// Tests for shared state preservation across subworkflow boundaries. /// Validates fix for issue #2419: ".NET: Shared State is not preserved in Subworkflows" /// -internal static class Step14EntryPoint +internal static partial class Step14EntryPoint { public const string WordStateScope = "WordStateScope"; @@ -106,12 +106,10 @@ public static async ValueTask RunSubworkflowInternalStateAsync(string text, /// /// Executor that reads text and stores it in shared state with a generated key. /// - internal sealed class TextReadExecutor() : Executor("TextReadExecutor") + internal sealed partial class TextReadExecutor() : Executor("TextReadExecutor") { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - => routeBuilder.AddHandler(this.HandleAsync); - - private async ValueTask HandleAsync(string text, IWorkflowContext context, CancellationToken cancellationToken = default) + [MessageHandler] + public async ValueTask HandleAsync(string text, IWorkflowContext context, CancellationToken cancellationToken = default) { string key = Guid.NewGuid().ToString(); await context.QueueStateUpdateAsync(key, text, scopeName: WordStateScope, cancellationToken); @@ -122,12 +120,10 @@ private async ValueTask HandleAsync(string text, IWorkflowContext contex /// /// Executor that reads text from shared state, trims it, and updates the state. /// - internal sealed class TextTrimExecutor() : Executor("TextTrimExecutor") + internal sealed partial class TextTrimExecutor() : Executor("TextTrimExecutor") { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - => routeBuilder.AddHandler(this.HandleAsync); - - private async ValueTask HandleAsync(string key, IWorkflowContext context, CancellationToken cancellationToken = default) + [MessageHandler] + public async ValueTask HandleAsync(string key, IWorkflowContext context, CancellationToken cancellationToken = default) { string? content = await context.ReadStateAsync(key, scopeName: WordStateScope, cancellationToken); if (content is null) @@ -144,12 +140,10 @@ private async ValueTask HandleAsync(string key, IWorkflowContext context /// /// Executor that reads text from shared state and returns its character count. /// - internal sealed class CharCountingExecutor() : Executor("CharCountingExecutor") + internal sealed partial class CharCountingExecutor() : Executor("CharCountingExecutor") { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) - => routeBuilder.AddHandler(this.HandleAsync); - - private async ValueTask HandleAsync(string key, IWorkflowContext context, CancellationToken cancellationToken = default) + [MessageHandler] + public async ValueTask HandleAsync(string key, IWorkflowContext context, CancellationToken cancellationToken = default) { string? content = await context.ReadStateAsync(key, scopeName: WordStateScope, cancellationToken); return content?.Length ?? 0; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs index 9782e68f4f..7eded8ea70 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs @@ -27,7 +27,7 @@ public IExternalRequestSink RegisterPort(RequestPort port) internal TestRunContext ConfigureExecutor(Executor executor, EdgeMap? map = null) { - executor.Configure(new TestExternalRequestContext(this, executor.Id, map)); + executor.AttachRequestContext(new TestExternalRequestContext(this, executor.Id, map)); this.Executors.Add(executor.Id, executor); return this; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestingExecutor.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestingExecutor.cs index 210f3aa89b..37807e9257 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestingExecutor.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestingExecutor.cs @@ -8,7 +8,7 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; -internal abstract class TestingExecutor : Executor, IDisposable +internal abstract partial class TestingExecutor : Executor, IDisposable { private readonly bool _loop; private readonly Func>[] _actions; @@ -39,11 +39,10 @@ public void LinkCancellation(CancellationToken cancellationToken) public void SetCancel() => Volatile.Read(ref this._internalCts).Cancel(); - protected sealed override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler(this.RouteToActionsAsync); - private int _nextActionIndex; - private ValueTask RouteToActionsAsync(TIn message, IWorkflowContext context) + + [MessageHandler] + public ValueTask RouteToActionsAsync(TIn message, IWorkflowContext context) { if (this.AtEnd) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowBuilderSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowBuilderSmokeTests.cs index 8bc5455b70..2b370de99e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowBuilderSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowBuilderSmokeTests.cs @@ -9,16 +9,16 @@ public partial class WorkflowBuilderSmokeTests { private sealed class NoOpExecutor(string id) : Executor(id) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler( - (msg, ctx) => ctx.SendMessageAsync(msg)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler((msg, ctx) => ctx.SendMessageAsync(msg))); } private sealed class SomeOtherNoOpExecutor(string id) : Executor(id) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler( - (msg, ctx) => ctx.SendMessageAsync(msg)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler((msg, ctx) => ctx.SendMessageAsync(msg))); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs index 447c52a66e..15f0216230 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs @@ -9,14 +9,16 @@ public class WorkflowVisualizerTests { private sealed class MockExecutor(string id) : Executor(id) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler((msg, ctx) => ctx.SendMessageAsync(msg)); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler((msg, ctx) => ctx.SendMessageAsync(msg))); } private sealed class ListStrTargetExecutor(string id) : Executor(id) { - protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => - routeBuilder.AddHandler((msgs, ctx) => ctx.SendMessageAsync(string.Join(",", msgs))); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + => protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler((msgs, ctx) => ctx.SendMessageAsync(string.Join(",", msgs)))); } [Fact]