diff --git a/pkg/runtime/persistent_runtime.go b/pkg/runtime/persistent_runtime.go index 30df8d7b2..7ae06c4aa 100644 --- a/pkg/runtime/persistent_runtime.go +++ b/pkg/runtime/persistent_runtime.go @@ -23,6 +23,7 @@ type streamingState struct { reasoningContent strings.Builder agentName string messageID int64 // ID of the current streaming message (0 if none) + subSessionDepth int // >0 when inside a sub-session (task transfer); skip parent persistence } // New creates a new runtime for an agent and its team. @@ -72,7 +73,21 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi } switch e := event.(type) { + case *AgentSwitchingEvent: + switch { + case e.Switching: + streaming.subSessionDepth++ + case streaming.subSessionDepth > 0: + streaming.subSessionDepth-- + default: + slog.Warn("Received AgentSwitching(false) without matching AgentSwitching(true)", + "session_id", sess.ID, "from_agent", e.FromAgent, "to_agent", e.ToAgent) + } + case *AgentChoiceEvent: + if streaming.subSessionDepth > 0 { + return + } // Accumulate streaming content streaming.content.WriteString(e.Content) streaming.agentName = e.AgentName @@ -80,6 +95,9 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi r.persistStreamingContent(ctx, sess.ID, streaming) case *AgentChoiceReasoningEvent: + if streaming.subSessionDepth > 0 { + return + } // Accumulate streaming reasoning content streaming.reasoningContent.WriteString(e.Content) streaming.agentName = e.AgentName @@ -98,6 +116,9 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi } case *MessageAddedEvent: + if streaming.subSessionDepth > 0 { + return + } // Finalize the streaming message with complete metadata if streaming.messageID != 0 { // Update the existing streaming message with final content diff --git a/pkg/runtime/persistent_runtime_test.go b/pkg/runtime/persistent_runtime_test.go new file mode 100644 index 000000000..5d13749ef --- /dev/null +++ b/pkg/runtime/persistent_runtime_test.go @@ -0,0 +1,142 @@ +package runtime + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/agent" + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/model/provider/base" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/team" + "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/tools/builtin" +) + +// multiStreamProvider returns different streams on consecutive calls. +type multiStreamProvider struct { + id string + mu sync.Mutex + streams []chat.MessageStream + idx int +} + +func (m *multiStreamProvider) ID() string { return m.id } + +func (m *multiStreamProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.idx >= len(m.streams) { + return m.streams[len(m.streams)-1], nil + } + s := m.streams[m.idx] + m.idx++ + return s, nil +} + +func (m *multiStreamProvider) BaseConfig() base.Config { return base.Config{} } + +func (m *multiStreamProvider) MaxTokens() int { return 0 } + +func TestPersistentRuntime_SubAgentMessagesNotPersistedToParent(t *testing.T) { + // Stream 1 (root): produces a transfer_task tool call to "worker" + rootStream := newStreamBuilder(). + AddToolCallName("call_transfer", "transfer_task"). + AddToolCallArguments("call_transfer", `{"agent":"worker","task":"do work","expected_output":"result"}`). + AddStopWithUsage(10, 5). + Build() + + // Stream 2 (worker sub-agent): produces streaming content simulating work + workerStream := newStreamBuilder(). + AddContent("I am doing "). + AddContent("the work now."). + AddStopWithUsage(5, 10). + Build() + + prov := &multiStreamProvider{ + id: "test/mock-model", + streams: []chat.MessageStream{rootStream, workerStream}, + } + + worker := agent.New("worker", "Worker agent", agent.WithModel(prov)) + root := agent.New("root", "Root coordinator", + agent.WithModel(prov), + agent.WithToolSets(builtin.NewTransferTaskTool()), + ) + agent.WithSubAgents(worker)(root) + + tm := team.New(team.WithAgents(root, worker)) + + store := session.NewInMemorySessionStore() + + rt, err := New(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithSessionStore(store), + ) + require.NoError(t, err) + + sess := session.New( + session.WithUserMessage("Please delegate work to the worker"), + session.WithToolsApproved(true), + ) + sess.Title = "Test Transfer Persistence" + + err = store.AddSession(t.Context(), sess) + require.NoError(t, err) + + evCh := rt.RunStream(t.Context(), sess) + for range evCh { + } + + parentSess, err := store.GetSession(t.Context(), sess.ID) + require.NoError(t, err) + + // Verify no sub-agent messages leaked into the parent session + for _, item := range parentSess.Messages { + if !item.IsMessage() { + continue + } + assert.NotEqual(t, "worker", item.Message.AgentName, + "Sub-agent 'worker' messages should not be in the parent session. "+ + "Found message with role=%s content=%q", + item.Message.Message.Role, item.Message.Message.Content) + } + + // Verify the sub-session was persisted and contains the worker's messages + var subSess *session.Session + for _, item := range parentSess.Messages { + if item.IsSubSession() { + subSess = item.SubSession + break + } + } + require.NotNil(t, subSess, + "Sub-session should be persisted in the parent session") + + var workerMsgCount int + for _, item := range subSess.Messages { + if item.IsMessage() && item.Message.AgentName == "worker" { + workerMsgCount++ + } + } + assert.Positive(t, workerMsgCount, + "Worker messages should be in the sub-session") + + // Verify the root agent's assistant message (with transfer_task tool call) + // and the tool result are both persisted in the parent + var roles []chat.MessageRole + for _, item := range parentSess.Messages { + if item.IsMessage() { + roles = append(roles, item.Message.Message.Role) + } + } + assert.Contains(t, roles, chat.MessageRoleAssistant, + "Parent session should contain root's assistant message with the transfer_task tool call") + assert.Contains(t, roles, chat.MessageRoleTool, + "Parent session should contain the tool result for transfer_task") +}