diff --git a/run.go b/run.go index c35cf22..4bd15eb 100644 --- a/run.go +++ b/run.go @@ -69,7 +69,12 @@ func ParseAndRun(ctx context.Context, root *Command, args []string, options *Run if err := Parse(root, args); err != nil { if errors.Is(err, ErrHelp) { options = checkAndSetRunOptions(options) - _, _ = fmt.Fprintln(options.Stdout, DefaultUsage(root)) + cmd := root.terminal() + if cmd.UsageFunc != nil { + _, _ = fmt.Fprintln(options.Stdout, cmd.UsageFunc(cmd)) + } else { + _, _ = fmt.Fprintln(options.Stdout, DefaultUsage(root)) + } return nil } return err diff --git a/run_test.go b/run_test.go index 4b1639e..e7f0d41 100644 --- a/run_test.go +++ b/run_test.go @@ -228,4 +228,36 @@ func TestRun(t *testing.T) { require.Equal(t, val, GetFlag[string](root.state, "text")) } }) + t.Run("ParseAndRun uses UsageFunc on help", func(t *testing.T) { + t.Parallel() + + root := &Command{ + Name: "myapp", + ShortHelp: "my application", + UsageFunc: func(c *Command) string { + return "custom usage output" + }, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + + stdout := bytes.NewBuffer(nil) + err := ParseAndRun(context.Background(), root, []string{"-help"}, &RunOptions{Stdout: stdout}) + require.NoError(t, err) + require.Contains(t, stdout.String(), "custom usage output") + }) + t.Run("ParseAndRun falls back to DefaultUsage without UsageFunc", func(t *testing.T) { + t.Parallel() + + root := &Command{ + Name: "myapp", + ShortHelp: "my application", + Exec: func(ctx context.Context, s *State) error { return nil }, + } + + stdout := bytes.NewBuffer(nil) + err := ParseAndRun(context.Background(), root, []string{"-help"}, &RunOptions{Stdout: stdout}) + require.NoError(t, err) + require.Contains(t, stdout.String(), "my application") + require.Contains(t, stdout.String(), "Usage:") + }) } diff --git a/usage.go b/usage.go index 58e9b4e..8ba87d7 100644 --- a/usage.go +++ b/usage.go @@ -26,10 +26,6 @@ func DefaultUsage(root *Command) string { var b strings.Builder - if terminalCmd.UsageFunc != nil { - return terminalCmd.UsageFunc(terminalCmd) - } - if terminalCmd.ShortHelp != "" { b.WriteString(terminalCmd.ShortHelp) b.WriteString("\n\n") diff --git a/usage_test.go b/usage_test.go index 8bacfb0..917caa9 100644 --- a/usage_test.go +++ b/usage_test.go @@ -490,3 +490,27 @@ func TestWriteFlagSection(t *testing.T) { require.NotContains(t, output, "Inherited Flags:") }) } + +func TestDefaultUsageComposableFromUsageFunc(t *testing.T) { + t.Parallel() + + cmd := &Command{ + Name: "myapp", + ShortHelp: "my application", + Exec: func(ctx context.Context, s *State) error { return nil }, + } + cmd.UsageFunc = func(c *Command) string { + // Calling DefaultUsage from within UsageFunc should not recurse infinitely. + s := DefaultUsage(c) + return s + "\n\nExamples:\n myapp --verbose" + } + + err := Parse(cmd, []string{}) + require.NoError(t, err) + + output := cmd.UsageFunc(cmd) + require.Contains(t, output, "my application") + require.Contains(t, output, "Usage:") + require.Contains(t, output, "Examples:") + require.Contains(t, output, "myapp --verbose") +}