Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions pkg/cli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ func (e RuntimeError) Unwrap() error {
return e.Err
}

// maxAutoExtensions is the maximum number of times --yolo mode will
// auto-continue when max iterations is reached, to prevent infinite loops.
const maxAutoExtensions = 5

// maxIterAction describes what the caller should do after a MaxIterationsReachedEvent.
type maxIterAction int

const (
maxIterContinue maxIterAction = iota // auto-approved, keep running
maxIterStop // safety cap reached, caller should stop
maxIterPrompt // not in yolo mode, caller should prompt the user
)

// handleMaxIterationsAutoApprove decides whether to auto-extend iterations in
// --yolo mode. Returns maxIterContinue (approved), maxIterStop (cap reached),
// or maxIterPrompt (not in auto-approve mode, caller should ask the user).
func handleMaxIterationsAutoApprove(autoApprove bool, autoExtensions *int, maxIter int) maxIterAction {
if !autoApprove {
return maxIterPrompt
}
*autoExtensions++
if *autoExtensions <= maxAutoExtensions {
slog.Info("Auto-extending iterations in yolo mode",
"extension", *autoExtensions,
"max_extensions", maxAutoExtensions,
"current_max", maxIter)
return maxIterContinue
}
slog.Warn("Max auto-extensions reached in yolo mode, stopping",
"total_extensions", *autoExtensions)
return maxIterStop
}

// Config holds configuration for running an agent in CLI mode
type Config struct {
AppName string
Expand Down Expand Up @@ -60,6 +93,8 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
var lastErr error

oneLoop := func(text string, rd io.Reader) error {
autoExtensions := 0

userInput := strings.TrimSpace(text)
if userInput == "" {
return nil
Expand All @@ -74,6 +109,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
if !cfg.AutoApprove {
rt.Resume(ctx, runtime.ResumeReject(""))
}
case *runtime.MaxIterationsReachedEvent:
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
case maxIterContinue:
rt.Resume(ctx, runtime.ResumeApprove())
default: // maxIterStop or maxIterPrompt (no interactive prompt in JSON mode)
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
}
case *runtime.ErrorEvent:
return fmt.Errorf("%s", e.Error)
}
Expand Down Expand Up @@ -153,16 +196,24 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
out.PrintError(lastErr)
}
case *runtime.MaxIterationsReachedEvent:
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
switch result {
case ConfirmationApprove:
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
case maxIterContinue:
rt.Resume(ctx, runtime.ResumeApprove())
case ConfirmationReject:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case ConfirmationAbort:
case maxIterStop:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case maxIterPrompt:
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
switch result {
case ConfirmationApprove:
rt.Resume(ctx, runtime.ResumeApprove())
case ConfirmationReject:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
case ConfirmationAbort:
rt.Resume(ctx, runtime.ResumeReject(""))
return nil
}
}
case *runtime.ElicitationRequestEvent:
serverURL, ok := e.Meta["cagent/server_url"].(string)
Expand Down
205 changes: 205 additions & 0 deletions pkg/cli/runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package cli

import (
"bytes"
"context"
"sync"
"testing"

"gotest.tools/v3/assert"

"github.com/docker/cagent/pkg/runtime"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/sessiontitle"
"github.com/docker/cagent/pkg/tools"
mcptools "github.com/docker/cagent/pkg/tools/mcp"
)

// mockRuntime implements runtime.Runtime for testing the CLI runner.
// It emits pre-configured events from RunStream and records Resume calls.
type mockRuntime struct {
events []runtime.Event

mu sync.Mutex
resumes []runtime.ResumeRequest
}

func (m *mockRuntime) CurrentAgentName() string { return "test" }
func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo {
return runtime.CurrentAgentInfo{Name: "test"}
}
func (m *mockRuntime) SetCurrentAgent(string) error { return nil }
func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil }
func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {}
func (m *mockRuntime) ResetStartupInfo() {}
func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) {
return nil, nil
}

func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error {
return nil
}
func (m *mockRuntime) SessionStore() session.Store { return nil }
func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {}
func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil }
func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false }
func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo {
return nil
}

func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) {
return "", nil
}
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
func (m *mockRuntime) Close() error { return nil }
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}

func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
m.mu.Lock()
defer m.mu.Unlock()
m.resumes = append(m.resumes, req)
}

func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
ch := make(chan runtime.Event, len(m.events))
for _, e := range m.events {
ch <- e
}
close(ch)
return ch
}

func (m *mockRuntime) getResumes() []runtime.ResumeRequest {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]runtime.ResumeRequest, len(m.resumes))
copy(result, m.resumes)
return result
}

func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent {
return &runtime.MaxIterationsReachedEvent{
Type: "max_iterations_reached",
MaxIterations: maxIter,
}
}

func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) {
t.Parallel()

rt := &mockRuntime{
events: []runtime.Event{maxIterEvent(60)},
}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), 1)
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
}

func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) {
t.Parallel()

// Emit maxAutoExtensions+1 events to trigger the safety cap
events := make([]runtime.Event, maxAutoExtensions+1)
for i := range events {
events[i] = maxIterEvent(60 + i*10)
}

rt := &mockRuntime{events: events}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), maxAutoExtensions+1)

// First maxAutoExtensions should be approved
for i := range maxAutoExtensions {
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove,
"extension %d should be approved", i+1)
}
// Last one should be rejected (safety cap)
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject,
"extension beyond cap should be rejected")
}

func TestMaxIterationsAutoApproveJSONMode(t *testing.T) {
t.Parallel()

rt := &mockRuntime{
events: []runtime.Event{maxIterEvent(60)},
}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true, OutputJSON: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), 1)
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
}

func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) {
t.Parallel()

rt := &mockRuntime{
events: []runtime.Event{maxIterEvent(60)},
}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: false, OutputJSON: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), 1)
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeReject)
}

func TestMaxIterationsSafetyCapJSONMode(t *testing.T) {
t.Parallel()

events := make([]runtime.Event, maxAutoExtensions+1)
for i := range events {
events[i] = maxIterEvent(60 + i*10)
}

rt := &mockRuntime{events: events}

var buf bytes.Buffer
out := NewPrinter(&buf)
sess := session.New()
cfg := Config{AutoApprove: true, OutputJSON: true}

err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
assert.NilError(t, err)

resumes := rt.getResumes()
assert.Equal(t, len(resumes), maxAutoExtensions+1)

for i := range maxAutoExtensions {
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove)
}
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject)
}