From 1ec9e4972c5791e8671479575a0609ed1589126e Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Wed, 18 Feb 2026 07:53:39 -0500 Subject: [PATCH] feat: add flagtype package with common flag.Value implementations --- CHANGELOG.md | 5 + docs/design/001-flagtype-api.md | 95 +++++++++++++++ flagtype/doc.go | 25 ++++ flagtype/enum.go | 37 ++++++ flagtype/flagtype_test.go | 208 ++++++++++++++++++++++++++++++++ flagtype/regexp.go | 38 ++++++ flagtype/string_map.go | 57 +++++++++ flagtype/string_slice.go | 31 +++++ flagtype/url.go | 42 +++++++ 9 files changed, 538 insertions(+) create mode 100644 docs/design/001-flagtype-api.md create mode 100644 flagtype/doc.go create mode 100644 flagtype/enum.go create mode 100644 flagtype/flagtype_test.go create mode 100644 flagtype/regexp.go create mode 100644 flagtype/string_map.go create mode 100644 flagtype/string_slice.go create mode 100644 flagtype/url.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ffc6e3..204f0c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added + +- New `flagtype` package with common `flag.Value` implementations: `StringSlice`, `Enum`, + `StringMap`, `URL`, and `Regexp` + ## [v0.5.0] - 2026-02-17 ### Changed diff --git a/docs/design/001-flagtype-api.md b/docs/design/001-flagtype-api.md new file mode 100644 index 0000000..972e8e5 --- /dev/null +++ b/docs/design/001-flagtype-api.md @@ -0,0 +1,95 @@ +# 001 - flagtype API + +**Date:** 2026-02-18 + +## Context + +Users of pressly/cli must manually implement `flag.Value` (and `flag.Getter`) for common types like +string slices, enums, and maps. This is repetitive boilerplate that most CLI tools need. + +## Decision + +Use stdlib-native constructors that return `flag.Value`, registered via `f.Var()`. + +```go +Flags: cli.FlagsFunc(func(f *flag.FlagSet) { + f.Bool("verbose", false, "enable verbose output") + f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)") + f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format") + f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)") +}) +``` + +The flagtype package has no knowledge of `flag.FlagSet`. Each constructor returns a value that +implements `flag.Value` and `flag.Getter`. Storage is internal -- no destination pointers needed +since values are retrieved via `cli.GetFlag[T]`. + +## Alternatives considered + +### A: flagtype takes a FlagSet + +```go +Flags: cli.FlagsFunc(func(f *flag.FlagSet) { + f.Bool("verbose", false, "enable verbose output") + flagtype.StringSlice(f, "tag", "add a tag (repeatable)") + flagtype.Enum(f, "format", "output format", "json", "yaml", "table") +}) +``` + +One-liner registration, no `f.Var()` ceremony. Rejected because it introduces a second calling +convention in the same block -- stdlib flags use `f.Type(name, default, usage)` while flagtype would +use `flagtype.Type(f, name, usage)`. The argument ordering inconsistency makes it harder to read at +a glance. + +### B: FlagSet wrapper + +```go +Flags: cli.FlagsFunc(func(f *flag.FlagSet) { + f.Bool("verbose", false, "enable verbose output") + ft := flagtype.From(f) + ft.StringSlice("tag", "add a tag (repeatable)") + ft.Enum("format", "output format", "json", "yaml", "table") +}) +``` + +Feels like a natural extension of FlagSet. Rejected because it requires managing two objects in the +same closure -- `f` for standard types and `ft` for custom types. Also adds a layer of indirection +that doesn't pull its weight. + +### C: Declarative flag list + +```go +Flags: []cli.Flag{ + cli.String("output", "", "output file"), + cli.Bool("verbose", false, "enable verbose output"), + flagtype.StringSlice("tag", "add a tag (repeatable)"), + flagtype.Enum("format", "output format", "json", "yaml", "table"), +} +``` + +Fully declarative, no callback, no FlagSet. Rejected because it's a significant departure from the +stdlib `flag` package and would require rethinking the core `Command` type. Essentially a different +framework. + +### D: Destination pointer pattern + +```go +var tags []string +var re *regexp.Regexp +f.Var(flagtype.StringSlice(&tags), "tag", "add a tag (repeatable)") +f.Var(flagtype.Regexp(&re), "pattern", "regex pattern") +``` + +The initial implementation. Each constructor takes a pointer to the destination variable. Rejected +because pointer types like `*regexp.Regexp` and `*url.URL` require double pointers +(`**regexp.Regexp`), which is awkward. Since values are always retrieved via `cli.GetFlag[T]`, the +destination pointer serves no purpose. + +## Why this approach + +- **Zero new concepts.** Anyone who knows `flag.Var` already knows how to use flagtype. +- **No coupling.** flagtype has no dependency on the cli package or `flag.FlagSet`. +- **Consistent with stdlib.** Custom flag types in Go have always been registered via `f.Var()`. + This follows that convention exactly. +- **No double pointers.** Internal storage means the API is clean for all types, including pointer + types like `*url.URL` and `*regexp.Regexp`. diff --git a/flagtype/doc.go b/flagtype/doc.go new file mode 100644 index 0000000..88b3ee3 --- /dev/null +++ b/flagtype/doc.go @@ -0,0 +1,25 @@ +// Package flagtype provides common [flag.Value] implementations for use with [flag.FlagSet.Var]. +// +// All types implement [flag.Getter] so they work with [cli.GetFlag]. +// +// The following types are available: +// - [StringSlice] - repeatable flag that collects values into []string +// - [Enum] - restricts values to a predefined set, retrieved as string +// - [StringMap] - repeatable flag that parses key=value pairs into map[string]string +// - [URL] - parses and validates a URL (must have scheme and host), retrieved as *url.URL +// - [Regexp] - compiles a regular expression, retrieved as *regexp.Regexp +// +// Example registration: +// +// Flags: cli.FlagsFunc(func(f *flag.FlagSet) { +// f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)") +// f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format") +// f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)") +// }) +// +// Example retrieval in Exec: +// +// tags := cli.GetFlag[[]string](s, "tag") +// format := cli.GetFlag[string](s, "format") +// labels := cli.GetFlag[map[string]string](s, "label") +package flagtype diff --git a/flagtype/enum.go b/flagtype/enum.go new file mode 100644 index 0000000..e73ed19 --- /dev/null +++ b/flagtype/enum.go @@ -0,0 +1,37 @@ +package flagtype + +import ( + "flag" + "fmt" + "slices" + "strings" +) + +type enumValue struct { + val string + allowed []string +} + +// Enum returns a [flag.Value] that restricts the flag to one of the allowed values. If a value not +// in the allowed list is provided, an error is returned listing valid options. +// +// Use [cli.GetFlag] with type string to retrieve the value. +func Enum(allowed ...string) flag.Value { + return &enumValue{allowed: allowed} +} + +func (v *enumValue) String() string { + return v.val +} + +func (v *enumValue) Set(s string) error { + if !slices.Contains(v.allowed, s) { + return fmt.Errorf("invalid value %q, must be one of: %s", s, strings.Join(v.allowed, ", ")) + } + v.val = s + return nil +} + +func (v *enumValue) Get() any { + return v.val +} diff --git a/flagtype/flagtype_test.go b/flagtype/flagtype_test.go new file mode 100644 index 0000000..70ede80 --- /dev/null +++ b/flagtype/flagtype_test.go @@ -0,0 +1,208 @@ +package flagtype + +import ( + "flag" + "net/url" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStringSlice(t *testing.T) { + t.Parallel() + + t.Run("single value", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(StringSlice(), "tag", "") + err := fs.Parse([]string{"--tag=foo"}) + require.NoError(t, err) + got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string) + assert.Equal(t, []string{"foo"}, got) + }) + t.Run("multiple values", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(StringSlice(), "tag", "") + err := fs.Parse([]string{"--tag=foo", "--tag=bar", "--tag=baz"}) + require.NoError(t, err) + got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string) + assert.Equal(t, []string{"foo", "bar", "baz"}, got) + }) + t.Run("string output", func(t *testing.T) { + t.Parallel() + v := StringSlice() + require.NoError(t, v.Set("a")) + require.NoError(t, v.Set("b")) + assert.Equal(t, "a,b", v.String()) + }) + t.Run("empty", func(t *testing.T) { + t.Parallel() + v := StringSlice() + assert.Equal(t, "", v.String()) + got := v.(flag.Getter).Get().([]string) + assert.Nil(t, got) + }) +} + +func TestEnum(t *testing.T) { + t.Parallel() + + t.Run("valid value", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(Enum("json", "yaml", "table"), "format", "") + err := fs.Parse([]string{"--format=yaml"}) + require.NoError(t, err) + got := fs.Lookup("format").Value.(flag.Getter).Get().(string) + assert.Equal(t, "yaml", got) + }) + t.Run("invalid value", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(nopWriter{}) + fs.Var(Enum("json", "yaml"), "format", "") + err := fs.Parse([]string{"--format=xml"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be one of") + assert.Contains(t, err.Error(), "json, yaml") + }) + t.Run("empty default", func(t *testing.T) { + t.Parallel() + v := Enum("a", "b") + assert.Equal(t, "", v.String()) + assert.Equal(t, "", v.(flag.Getter).Get()) + }) +} + +func TestStringMap(t *testing.T) { + t.Parallel() + + t.Run("single pair", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(StringMap(), "label", "") + err := fs.Parse([]string{"--label=env=prod"}) + require.NoError(t, err) + got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string) + assert.Equal(t, map[string]string{"env": "prod"}, got) + }) + t.Run("multiple pairs", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(StringMap(), "label", "") + err := fs.Parse([]string{"--label=env=prod", "--label=tier=web"}) + require.NoError(t, err) + got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string) + assert.Equal(t, map[string]string{"env": "prod", "tier": "web"}, got) + }) + t.Run("value contains equals", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(StringMap(), "label", "") + err := fs.Parse([]string{"--label=query=a=b"}) + require.NoError(t, err) + got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string) + assert.Equal(t, map[string]string{"query": "a=b"}, got) + }) + t.Run("missing equals", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(nopWriter{}) + fs.Var(StringMap(), "label", "") + err := fs.Parse([]string{"--label=nope"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing '='") + }) + t.Run("empty key", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(nopWriter{}) + fs.Var(StringMap(), "label", "") + err := fs.Parse([]string{"--label==value"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty key") + }) + t.Run("string output sorted", func(t *testing.T) { + t.Parallel() + v := StringMap() + require.NoError(t, v.Set("b=2")) + require.NoError(t, v.Set("a=1")) + assert.Equal(t, "a=1,b=2", v.String()) + }) + t.Run("empty", func(t *testing.T) { + t.Parallel() + v := StringMap() + assert.Equal(t, "", v.String()) + assert.Nil(t, v.(flag.Getter).Get()) + }) +} + +func TestURL(t *testing.T) { + t.Parallel() + + t.Run("valid url", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(URL(), "endpoint", "") + err := fs.Parse([]string{"--endpoint=https://example.com/api"}) + require.NoError(t, err) + got := fs.Lookup("endpoint").Value.(flag.Getter).Get().(*url.URL) + require.NotNil(t, got) + assert.Equal(t, "https", got.Scheme) + assert.Equal(t, "example.com", got.Host) + assert.Equal(t, "/api", got.Path) + }) + t.Run("missing scheme", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(nopWriter{}) + fs.Var(URL(), "endpoint", "") + err := fs.Parse([]string{"--endpoint=example.com"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "must have a scheme and host") + }) + t.Run("empty", func(t *testing.T) { + t.Parallel() + v := URL() + assert.Equal(t, "", v.String()) + assert.Nil(t, v.(flag.Getter).Get()) + }) +} + +func TestRegexp(t *testing.T) { + t.Parallel() + + t.Run("valid pattern", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.Var(Regexp(), "pattern", "") + err := fs.Parse([]string{"--pattern=^foo.*bar$"}) + require.NoError(t, err) + got := fs.Lookup("pattern").Value.(flag.Getter).Get().(*regexp.Regexp) + require.NotNil(t, got) + assert.True(t, got.MatchString("fooXbar")) + assert.False(t, got.MatchString("baz")) + }) + t.Run("invalid pattern", func(t *testing.T) { + t.Parallel() + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(nopWriter{}) + fs.Var(Regexp(), "pattern", "") + err := fs.Parse([]string{"--pattern=[invalid"}) + require.Error(t, err) + }) + t.Run("empty", func(t *testing.T) { + t.Parallel() + v := Regexp() + assert.Equal(t, "", v.String()) + assert.Nil(t, v.(flag.Getter).Get()) + }) +} + +// nopWriter discards all writes, used to suppress flag.FlagSet error output in tests. +type nopWriter struct{} + +func (nopWriter) Write(p []byte) (int, error) { return len(p), nil } diff --git a/flagtype/regexp.go b/flagtype/regexp.go new file mode 100644 index 0000000..6df9931 --- /dev/null +++ b/flagtype/regexp.go @@ -0,0 +1,38 @@ +package flagtype + +import ( + "flag" + "regexp" +) + +type regexpValue struct { + re *regexp.Regexp +} + +// Regexp returns a [flag.Value] that compiles the flag value as a regular expression. If the +// pattern is invalid, an error is returned. +// +// Use [cli.GetFlag] with type *regexp.Regexp to retrieve the value. +func Regexp() flag.Value { + return ®expValue{} +} + +func (v *regexpValue) String() string { + if v.re == nil { + return "" + } + return v.re.String() +} + +func (v *regexpValue) Set(s string) error { + re, err := regexp.Compile(s) + if err != nil { + return err + } + v.re = re + return nil +} + +func (v *regexpValue) Get() any { + return v.re +} diff --git a/flagtype/string_map.go b/flagtype/string_map.go new file mode 100644 index 0000000..8d28e3f --- /dev/null +++ b/flagtype/string_map.go @@ -0,0 +1,57 @@ +package flagtype + +import ( + "flag" + "fmt" + "sort" + "strings" +) + +type stringMapValue struct { + m map[string]string +} + +// StringMap returns a [flag.Value] that parses key=value pairs into a map. The flag can be repeated +// to add multiple entries, like --label=env=prod --label=tier=web. The value is split on the first +// "=" character, so values may contain additional "=" characters. +// +// Use [cli.GetFlag] with type map[string]string to retrieve the value. +func StringMap() flag.Value { + return &stringMapValue{} +} + +func (v *stringMapValue) String() string { + if v.m == nil { + return "" + } + // Sort keys for deterministic output. + keys := make([]string, 0, len(v.m)) + for k := range v.m { + keys = append(keys, k) + } + sort.Strings(keys) + pairs := make([]string, 0, len(keys)) + for _, k := range keys { + pairs = append(pairs, k+"="+v.m[k]) + } + return strings.Join(pairs, ",") +} + +func (v *stringMapValue) Set(s string) error { + key, value, ok := strings.Cut(s, "=") + if !ok { + return fmt.Errorf("invalid key=value pair: %q (missing '=')", s) + } + if key == "" { + return fmt.Errorf("invalid key=value pair: %q (empty key)", s) + } + if v.m == nil { + v.m = make(map[string]string) + } + v.m[key] = value + return nil +} + +func (v *stringMapValue) Get() any { + return v.m +} diff --git a/flagtype/string_slice.go b/flagtype/string_slice.go new file mode 100644 index 0000000..047a925 --- /dev/null +++ b/flagtype/string_slice.go @@ -0,0 +1,31 @@ +package flagtype + +import ( + "flag" + "strings" +) + +type stringSliceValue struct { + vals []string +} + +// StringSlice returns a [flag.Value] that collects values into a string slice. Each time the flag +// is set, the value is appended. This allows repeatable flags like --tag=foo --tag=bar. +// +// Use [cli.GetFlag] with type []string to retrieve the value. +func StringSlice() flag.Value { + return &stringSliceValue{} +} + +func (v *stringSliceValue) String() string { + return strings.Join(v.vals, ",") +} + +func (v *stringSliceValue) Set(s string) error { + v.vals = append(v.vals, s) + return nil +} + +func (v *stringSliceValue) Get() any { + return v.vals +} diff --git a/flagtype/url.go b/flagtype/url.go new file mode 100644 index 0000000..27c7443 --- /dev/null +++ b/flagtype/url.go @@ -0,0 +1,42 @@ +package flagtype + +import ( + "flag" + "fmt" + "net/url" +) + +type urlValue struct { + u *url.URL +} + +// URL returns a [flag.Value] that parses the flag value as a URL. The URL must have both a scheme +// and a host, otherwise an error is returned. +// +// Use [cli.GetFlag] with type *url.URL to retrieve the value. +func URL() flag.Value { + return &urlValue{} +} + +func (v *urlValue) String() string { + if v.u == nil { + return "" + } + return v.u.String() +} + +func (v *urlValue) Set(s string) error { + u, err := url.Parse(s) + if err != nil { + return fmt.Errorf("invalid URL %q: %w", s, err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("invalid URL %q: must have a scheme and host", s) + } + v.u = u + return nil +} + +func (v *urlValue) Get() any { + return v.u +}