diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go index 6be776c22..cf367af01 100644 --- a/pkg/cli/runner.go +++ b/pkg/cli/runner.go @@ -31,6 +31,39 @@ func (e RuntimeError) Unwrap() error { return e.Err } +// maxAutoExtensions is the maximum number of times --yolo mode will +// auto-continue when max iterations is reached, to prevent infinite loops. +const maxAutoExtensions = 5 + +// maxIterAction describes what the caller should do after a MaxIterationsReachedEvent. +type maxIterAction int + +const ( + maxIterContinue maxIterAction = iota // auto-approved, keep running + maxIterStop // safety cap reached, caller should stop + maxIterPrompt // not in yolo mode, caller should prompt the user +) + +// handleMaxIterationsAutoApprove decides whether to auto-extend iterations in +// --yolo mode. Returns maxIterContinue (approved), maxIterStop (cap reached), +// or maxIterPrompt (not in auto-approve mode, caller should ask the user). +func handleMaxIterationsAutoApprove(autoApprove bool, autoExtensions *int, maxIter int) maxIterAction { + if !autoApprove { + return maxIterPrompt + } + *autoExtensions++ + if *autoExtensions <= maxAutoExtensions { + slog.Info("Auto-extending iterations in yolo mode", + "extension", *autoExtensions, + "max_extensions", maxAutoExtensions, + "current_max", maxIter) + return maxIterContinue + } + slog.Warn("Max auto-extensions reached in yolo mode, stopping", + "total_extensions", *autoExtensions) + return maxIterStop +} + // Config holds configuration for running an agent in CLI mode type Config struct { AppName string @@ -60,6 +93,8 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess var lastErr error oneLoop := func(text string, rd io.Reader) error { + autoExtensions := 0 + userInput := strings.TrimSpace(text) if userInput == "" { return nil @@ -74,6 +109,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess if !cfg.AutoApprove { rt.Resume(ctx, runtime.ResumeReject("")) } + case *runtime.MaxIterationsReachedEvent: + switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) { + case maxIterContinue: + rt.Resume(ctx, runtime.ResumeApprove()) + default: // maxIterStop or maxIterPrompt (no interactive prompt in JSON mode) + rt.Resume(ctx, runtime.ResumeReject("")) + return nil + } case *runtime.ErrorEvent: return fmt.Errorf("%s", e.Error) } @@ -153,16 +196,24 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess out.PrintError(lastErr) } case *runtime.MaxIterationsReachedEvent: - result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations) - switch result { - case ConfirmationApprove: + switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) { + case maxIterContinue: rt.Resume(ctx, runtime.ResumeApprove()) - case ConfirmationReject: - rt.Resume(ctx, runtime.ResumeReject("")) - return nil - case ConfirmationAbort: + case maxIterStop: rt.Resume(ctx, runtime.ResumeReject("")) return nil + case maxIterPrompt: + result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations) + switch result { + case ConfirmationApprove: + rt.Resume(ctx, runtime.ResumeApprove()) + case ConfirmationReject: + rt.Resume(ctx, runtime.ResumeReject("")) + return nil + case ConfirmationAbort: + rt.Resume(ctx, runtime.ResumeReject("")) + return nil + } } case *runtime.ElicitationRequestEvent: serverURL, ok := e.Meta["cagent/server_url"].(string) diff --git a/pkg/cli/runner_test.go b/pkg/cli/runner_test.go new file mode 100644 index 000000000..f92bc4c31 --- /dev/null +++ b/pkg/cli/runner_test.go @@ -0,0 +1,205 @@ +package cli + +import ( + "bytes" + "context" + "sync" + "testing" + + "gotest.tools/v3/assert" + + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/sessiontitle" + "github.com/docker/cagent/pkg/tools" + mcptools "github.com/docker/cagent/pkg/tools/mcp" +) + +// mockRuntime implements runtime.Runtime for testing the CLI runner. +// It emits pre-configured events from RunStream and records Resume calls. +type mockRuntime struct { + events []runtime.Event + + mu sync.Mutex + resumes []runtime.ResumeRequest +} + +func (m *mockRuntime) CurrentAgentName() string { return "test" } +func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo { + return runtime.CurrentAgentInfo{Name: "test"} +} +func (m *mockRuntime) SetCurrentAgent(string) error { return nil } +func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil } +func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {} +func (m *mockRuntime) ResetStartupInfo() {} +func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) { + return nil, nil +} + +func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error { + return nil +} +func (m *mockRuntime) SessionStore() session.Store { return nil } +func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {} +func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil } +func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false } +func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo { + return nil +} + +func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) { + return "", nil +} +func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil } +func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil } +func (m *mockRuntime) Close() error { return nil } +func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {} + +func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) { + m.mu.Lock() + defer m.mu.Unlock() + m.resumes = append(m.resumes, req) +} + +func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event { + ch := make(chan runtime.Event, len(m.events)) + for _, e := range m.events { + ch <- e + } + close(ch) + return ch +} + +func (m *mockRuntime) getResumes() []runtime.ResumeRequest { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]runtime.ResumeRequest, len(m.resumes)) + copy(result, m.resumes) + return result +} + +func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent { + return &runtime.MaxIterationsReachedEvent{ + Type: "max_iterations_reached", + MaxIterations: maxIter, + } +} + +func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) { + t.Parallel() + + rt := &mockRuntime{ + events: []runtime.Event{maxIterEvent(60)}, + } + + var buf bytes.Buffer + out := NewPrinter(&buf) + sess := session.New() + cfg := Config{AutoApprove: true} + + err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) + assert.NilError(t, err) + + resumes := rt.getResumes() + assert.Equal(t, len(resumes), 1) + assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove) +} + +func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) { + t.Parallel() + + // Emit maxAutoExtensions+1 events to trigger the safety cap + events := make([]runtime.Event, maxAutoExtensions+1) + for i := range events { + events[i] = maxIterEvent(60 + i*10) + } + + rt := &mockRuntime{events: events} + + var buf bytes.Buffer + out := NewPrinter(&buf) + sess := session.New() + cfg := Config{AutoApprove: true} + + err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) + assert.NilError(t, err) + + resumes := rt.getResumes() + assert.Equal(t, len(resumes), maxAutoExtensions+1) + + // First maxAutoExtensions should be approved + for i := range maxAutoExtensions { + assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove, + "extension %d should be approved", i+1) + } + // Last one should be rejected (safety cap) + assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject, + "extension beyond cap should be rejected") +} + +func TestMaxIterationsAutoApproveJSONMode(t *testing.T) { + t.Parallel() + + rt := &mockRuntime{ + events: []runtime.Event{maxIterEvent(60)}, + } + + var buf bytes.Buffer + out := NewPrinter(&buf) + sess := session.New() + cfg := Config{AutoApprove: true, OutputJSON: true} + + err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) + assert.NilError(t, err) + + resumes := rt.getResumes() + assert.Equal(t, len(resumes), 1) + assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove) +} + +func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) { + t.Parallel() + + rt := &mockRuntime{ + events: []runtime.Event{maxIterEvent(60)}, + } + + var buf bytes.Buffer + out := NewPrinter(&buf) + sess := session.New() + cfg := Config{AutoApprove: false, OutputJSON: true} + + err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) + assert.NilError(t, err) + + resumes := rt.getResumes() + assert.Equal(t, len(resumes), 1) + assert.Equal(t, resumes[0].Type, runtime.ResumeTypeReject) +} + +func TestMaxIterationsSafetyCapJSONMode(t *testing.T) { + t.Parallel() + + events := make([]runtime.Event, maxAutoExtensions+1) + for i := range events { + events[i] = maxIterEvent(60 + i*10) + } + + rt := &mockRuntime{events: events} + + var buf bytes.Buffer + out := NewPrinter(&buf) + sess := session.New() + cfg := Config{AutoApprove: true, OutputJSON: true} + + err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"}) + assert.NilError(t, err) + + resumes := rt.getResumes() + assert.Equal(t, len(resumes), maxAutoExtensions+1) + + for i := range maxAutoExtensions { + assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove) + } + assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject) +}