Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ type FlagMetadata struct {

// Required indicates whether the flag is required.
Required bool

// Local indicates that the flag should not be inherited by child commands. When true, the
// flag is only available on the command that defines it.
Local bool
}

// FlagsFunc is a helper function that creates a new [flag.FlagSet] and applies the given function
Expand Down
36 changes: 34 additions & 2 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,21 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
name := strings.TrimLeft(arg, "-")
skipValue := false
for _, cmd := range root.state.path {
localFlags := localFlagSet(cmd.FlagsMetadata)
// Skip local flags on ancestor commands (any command already in the
// path is an ancestor of the not-yet-resolved terminal command).
if localFlags[name] {
continue
}
// First try direct lookup.
f := cmd.Flags.Lookup(name)
// If not found, check if it's a short alias.
if f == nil {
for _, fm := range cmd.FlagsMetadata {
if fm.Short == name {
if localFlags[fm.Name] {
break
}
f = cmd.Flags.Lookup(fm.Name)
break
}
Expand Down Expand Up @@ -161,13 +170,20 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
func combineFlags(path []*Command) *flag.FlagSet {
combined := flag.NewFlagSet(path[0].Name, flag.ContinueOnError)
combined.SetOutput(io.Discard)
for i := len(path) - 1; i >= 0; i-- {
terminalIdx := len(path) - 1
for i := terminalIdx; i >= 0; i-- {
cmd := path[i]
if cmd.Flags == nil {
continue
}
localFlags := localFlagSet(cmd.FlagsMetadata)
shortMap := shortFlagMap(cmd.FlagsMetadata)
isAncestor := i < terminalIdx
cmd.Flags.VisitAll(func(f *flag.Flag) {
// Skip local flags from ancestor commands — they are not inherited.
if isAncestor && localFlags[f.Name] {
return
}
if combined.Lookup(f.Name) == nil {
combined.Var(f.Value, f.Name, f.Usage)
}
Expand All @@ -182,6 +198,17 @@ func combineFlags(path []*Command) *flag.FlagSet {
return combined
}

// localFlagSet builds a set of flag names that are marked as local in FlagsMetadata.
func localFlagSet(metadata []FlagMetadata) map[string]bool {
m := make(map[string]bool, len(metadata))
for _, fm := range metadata {
if fm.Local {
m[fm.Name] = true
}
}
return m
}

// shortFlagMap builds a map from long flag name to short alias from FlagsMetadata.
func shortFlagMap(metadata []FlagMetadata) map[string]string {
m := make(map[string]string, len(metadata))
Expand All @@ -203,12 +230,17 @@ func checkRequiredFlags(path []*Command, combined *flag.FlagSet) error {
setFlags[f.Name] = struct{}{}
})

terminalIdx := len(path) - 1
var missingFlags []string
for _, cmd := range path {
for i, cmd := range path {
for _, flagMetadata := range cmd.FlagsMetadata {
if !flagMetadata.Required {
continue
}
// Skip required-flag checks for local flags on ancestor commands.
if flagMetadata.Local && i < terminalIdx {
continue
}
if combined.Lookup(flagMetadata.Name) == nil {
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(path), formatFlagName(flagMetadata.Name))
}
Expand Down
158 changes: 158 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,164 @@ func TestShortFlags(t *testing.T) {
})
}

func TestLocalFlags(t *testing.T) {
t.Parallel()

t.Run("local flag on parent not available to child", func(t *testing.T) {
t.Parallel()
child := &Command{
Name: "child",
Exec: func(ctx context.Context, s *State) error { return nil },
}
root := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("version", false, "show version")
f.Bool("verbose", false, "enable verbose output")
}),
FlagsMetadata: []FlagMetadata{
{Name: "version", Local: true},
},
SubCommands: []*Command{child},
Exec: func(ctx context.Context, s *State) error { return nil },
}
// --version on child should fail because it's local to root
err := Parse(root, []string{"child", "--version"})
require.Error(t, err)
require.ErrorContains(t, err, "flag provided but not defined")

// --verbose on child should still work (not local)
root2 := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("version", false, "show version")
f.Bool("verbose", false, "enable verbose output")
}),
FlagsMetadata: []FlagMetadata{
{Name: "version", Local: true},
},
SubCommands: []*Command{{
Name: "child",
Exec: func(ctx context.Context, s *State) error { return nil },
}},
Exec: func(ctx context.Context, s *State) error { return nil },
}
err = Parse(root2, []string{"child", "--verbose"})
require.NoError(t, err)
assert.True(t, GetFlag[bool](root2.state, "verbose"))
})

t.Run("local flag works on defining command", func(t *testing.T) {
t.Parallel()
root := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("version", false, "show version")
}),
FlagsMetadata: []FlagMetadata{
{Name: "version", Local: true},
},
Exec: func(ctx context.Context, s *State) error { return nil },
}
err := Parse(root, []string{"--version"})
require.NoError(t, err)
assert.True(t, GetFlag[bool](root.state, "version"))
})

t.Run("local required flag only enforced on defining command", func(t *testing.T) {
t.Parallel()
child := &Command{
Name: "child",
Exec: func(ctx context.Context, s *State) error { return nil },
}
root := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.String("token", "", "auth token")
}),
FlagsMetadata: []FlagMetadata{
{Name: "token", Required: true, Local: true},
},
SubCommands: []*Command{child},
Exec: func(ctx context.Context, s *State) error { return nil },
}
// Child command should not require parent's local required flag
err := Parse(root, []string{"child"})
require.NoError(t, err)

// But root command itself should still require it
root2 := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.String("token", "", "auth token")
}),
FlagsMetadata: []FlagMetadata{
{Name: "token", Required: true, Local: true},
},
Exec: func(ctx context.Context, s *State) error { return nil },
}
err = Parse(root2, []string{})
require.Error(t, err)
require.ErrorContains(t, err, "required flag")
})

t.Run("usage excludes local parent flags from inherited flags", func(t *testing.T) {
t.Parallel()
child := &Command{
Name: "child",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("dry-run", false, "dry run mode")
}),
Exec: func(ctx context.Context, s *State) error { return nil },
}
root := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("version", false, "show version")
f.Bool("verbose", false, "enable verbose output")
}),
FlagsMetadata: []FlagMetadata{
{Name: "version", Local: true},
},
SubCommands: []*Command{child},
Exec: func(ctx context.Context, s *State) error { return nil },
}
err := Parse(root, []string{"child", "--help"})
require.ErrorIs(t, err, flag.ErrHelp)

usage := DefaultUsage(root)
// --verbose should appear in inherited flags (not local)
assert.Contains(t, usage, "--verbose")
// --version should NOT appear (local to root, not inherited)
assert.NotContains(t, usage, "--version")
// --dry-run should appear in local flags
assert.Contains(t, usage, "--dry-run")
})

t.Run("local flag with short alias not inherited", func(t *testing.T) {
t.Parallel()
child := &Command{
Name: "child",
Exec: func(ctx context.Context, s *State) error { return nil },
}
root := &Command{
Name: "root",
Flags: FlagsFunc(func(f *flag.FlagSet) {
f.Bool("version", false, "show version")
}),
FlagsMetadata: []FlagMetadata{
{Name: "version", Short: "V", Local: true},
},
SubCommands: []*Command{child},
Exec: func(ctx context.Context, s *State) error { return nil },
}
// Short alias -V should also not work on child
err := Parse(root, []string{"child", "-V"})
require.Error(t, err)
require.ErrorContains(t, err, "flag provided but not defined")
})
}

func getCommand(t *testing.T, c *Command) *Command {
require.NotNil(t, c)
require.NotNil(t, c.state)
Expand Down
47 changes: 27 additions & 20 deletions usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,26 @@ func DefaultUsage(root *Command) string {

var flags []flagInfo
if root.state != nil && len(root.state.path) > 0 {
terminalIdx := len(root.state.path) - 1
for i, cmd := range root.state.path {
if cmd.Flags == nil {
continue
}
isGlobal := i < len(root.state.path)-1
isInherited := i < terminalIdx
metaMap := flagMetadataMap(cmd.FlagsMetadata)
cmd.Flags.VisitAll(func(f *flag.Flag) {
// Skip local flags from ancestor commands — they don't appear in child help.
if isInherited {
if m, ok := metaMap[f.Name]; ok && m.Local {
return
}
}
fi := flagInfo{
name: "--" + f.Name,
usage: f.Usage,
defval: f.DefValue,
typeName: flagTypeName(f),
global: isGlobal,
name: "--" + f.Name,
usage: f.Usage,
defval: f.DefValue,
typeName: flagTypeName(f),
inherited: isInherited,
}
if m, ok := metaMap[f.Name]; ok {
fi.required = m.Required
Expand Down Expand Up @@ -150,10 +157,10 @@ func DefaultUsage(root *Command) string {
}

hasLocal := false
hasGlobal := false
hasInherited := false
for _, f := range flags {
if f.global {
hasGlobal = true
if f.inherited {
hasInherited = true
} else {
hasLocal = true
}
Expand All @@ -165,8 +172,8 @@ func DefaultUsage(root *Command) string {
b.WriteString("\n")
}

if hasGlobal {
b.WriteString("Global Flags:\n")
if hasInherited {
b.WriteString("Inherited Flags:\n")
writeFlagSection(&b, flags, maxFlagLen, true, hasAnyShort)
b.WriteString("\n")
}
Expand All @@ -184,12 +191,12 @@ func DefaultUsage(root *Command) string {
}

// writeFlagSection handles the formatting of flag descriptions
func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, global, hasAnyShort bool) {
func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, inherited, hasAnyShort bool) {
nameWidth := maxLen + 4
wrapWidth := defaultTerminalWidth - nameWidth

for _, f := range flags {
if f.global != global {
if f.inherited != inherited {
continue
}

Expand Down Expand Up @@ -222,13 +229,13 @@ func flagMetadataMap(metadata []FlagMetadata) map[string]FlagMetadata {
}

type flagInfo struct {
name string
short string
usage string
defval string
typeName string
global bool
required bool
name string
short string
usage string
defval string
typeName string
inherited bool
required bool
}

// displayName returns the flag name with optional short alias and type hint. When hasAnyShort is
Expand Down
4 changes: 2 additions & 2 deletions usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func TestUsageGeneration(t *testing.T) {
require.Contains(t, output, "custom [options] <file>")
})

t.Run("usage with global and local flags", func(t *testing.T) {
t.Run("usage with inherited and local flags", func(t *testing.T) {
t.Parallel()

child := &Command{
Expand Down Expand Up @@ -487,6 +487,6 @@ func TestWriteFlagSection(t *testing.T) {

output := DefaultUsage(cmd)
require.NotContains(t, output, "Flags:")
require.NotContains(t, output, "Global Flags:")
require.NotContains(t, output, "Inherited Flags:")
})
}