Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ That's it. ShiftAPI reflects your Go types into an OpenAPI 3.1 spec at `/openapi

### Generic type-safe handlers

Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`) from body fields (`json:"..."`). For routes without input, use `_ struct{}`.
Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`), HTTP headers (`header:"..."`), and body fields (`json:"..."`). For routes without input, use `_ struct{}`.

```go
// POST with body — input is decoded and passed as *CreateUser
Expand Down Expand Up @@ -127,6 +127,25 @@ shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*Result, err
})
```

### Typed HTTP headers

Define a struct with `header` tags. Headers are parsed, validated, and documented in the OpenAPI spec automatically — just like query params.

```go
type AuthInput struct {
Token string `header:"Authorization" validate:"required"`
Q string `query:"q"`
}

shiftapi.Get(api, "/search", func(r *http.Request, in AuthInput) (*Results, error) {
// in.Token parsed from the Authorization header
// in.Q parsed from ?q= query param
return doSearch(in.Token, in.Q), nil
})
```

Supports `string`, `bool`, `int*`, `uint*`, `float*` scalars and `*T` pointers for optional headers. Parse errors return `400`; validation failures return `422`. Header, query, and body fields can be freely combined in one struct.

### Validation

Built-in validation via [go-playground/validator](https://github.com/go-playground/validator). Struct tags are enforced at runtime *and* reflected into the OpenAPI schema.
Expand Down Expand Up @@ -234,6 +253,14 @@ const { data: results } = await client.GET("/search", {
params: { query: { q: "hello", page: 1, limit: 10 } },
});
// query params are fully typed too — { q: string, page?: number, limit?: number }

const { data: authResults } = await client.GET("/search", {
params: {
query: { q: "hello" },
header: { Authorization: "Bearer token" },
},
});
// header params are fully typed as well
```

In dev mode the plugin also starts the Go server, proxies API requests through Vite, watches `.go` files, and hot-reloads the frontend when types change.
Expand Down
19 changes: 17 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
// HandlerFunc is a typed handler for routes.
// The In struct's fields are discriminated by struct tags:
// fields with `query:"..."` tags are parsed from query parameters,
// and fields with `json:"..."` tags (or no query tag) are parsed from the request body.
// fields with `header:"..."` tags are parsed from HTTP headers,
// and remaining fields (with `json:"..."` tags or untagged) are parsed from the request body.
// For routes without input, use struct{} as the In type.
type HandlerFunc[In, Resp any] func(r *http.Request, in In) (Resp, error)

func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasQuery, hasBody bool) http.HandlerFunc {
func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasQuery, hasHeader, hasBody bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var in In
rv := reflect.ValueOf(&in).Elem()
Expand All @@ -36,6 +37,12 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any
resetQueryFields(rv)
}

// Reset any header-tagged fields that body decode may have
// inadvertently set, so they only come from HTTP headers.
if hasBody && hasHeader {
resetHeaderFields(rv)
}

// Parse query params if there are query fields
if hasQuery {
if err := parseQueryInto(rv, r.URL.Query()); err != nil {
Expand All @@ -44,6 +51,14 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any
}
}

// Parse headers if there are header fields
if hasHeader {
if err := parseHeadersInto(rv, r.Header); err != nil {
writeError(w, Error(http.StatusBadRequest, err.Error()))
return
}
}

if err := validate(in); err != nil {
writeError(w, err)
return
Expand Down
10 changes: 7 additions & 3 deletions handlerFuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ func registerRoute[In, Resp any](
rawInType = rawInType.Elem()
}

hasQuery, hasBody := partitionFields(rawInType)
hasQuery, hasHeader, hasBody := partitionFields(rawInType)

var queryType reflect.Type
if hasQuery {
queryType = rawInType
}
var headerType reflect.Type
if hasHeader {
headerType = rawInType
}
// POST/PUT/PATCH conventionally carry a request body, so always attempt
// body decode for these methods — even when the input is struct{}.
// This means Post(api, path, func(r, _ struct{}) ...) requires at least "{}".
Expand All @@ -45,12 +49,12 @@ func registerRoute[In, Resp any](
var resp Resp
outType := reflect.TypeOf(resp)

if err := api.updateSchema(method, path, queryType, bodyType, outType, cfg.info, cfg.status); err != nil {
if err := api.updateSchema(method, path, queryType, headerType, bodyType, outType, cfg.info, cfg.status); err != nil {
panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err))
}

pattern := fmt.Sprintf("%s %s", method, path)
api.mux.HandleFunc(pattern, adapt(fn, cfg.status, api.validateBody, hasQuery, decodeBody))
api.mux.HandleFunc(pattern, adapt(fn, cfg.status, api.validateBody, hasQuery, hasHeader, decodeBody))
}

// Get registers a GET handler.
Expand Down
104 changes: 104 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package shiftapi

import (
"fmt"
"net/http"
"reflect"
"strings"
)

// hasHeaderTag returns true if the struct field has a `header` tag.
func hasHeaderTag(f reflect.StructField) bool {
return f.Tag.Get("header") != ""
}

// headerFieldName returns the header name for a struct field.
func headerFieldName(f reflect.StructField) string {
name, _, _ := strings.Cut(f.Tag.Get("header"), ",")
if name == "" {
return f.Name
}
return name
}

// resetHeaderFields zeros out any header-tagged fields on a struct value.
// This is called after body decode so that header-tagged fields are only
// populated by parseHeadersInto, not by JSON keys that happen to match.
func resetHeaderFields(rv reflect.Value) {
for rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
return
}
rt := rv.Type()
for i := range rt.NumField() {
f := rt.Field(i)
if f.IsExported() && hasHeaderTag(f) {
rv.Field(i).SetZero()
}
}
}

// parseHeadersInto populates header-tagged fields on an existing struct value
// from HTTP headers. Non-header fields are left untouched.
// Only scalar types and pointer-to-scalar types are supported (no slices).
func parseHeadersInto(rv reflect.Value, header http.Header) error {
for rv.Kind() == reflect.Pointer {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
rv = rv.Elem()
}

rt := rv.Type()
if rt.Kind() != reflect.Struct {
return fmt.Errorf("header type must be a struct, got %s", rt.Kind())
}

for i := range rt.NumField() {
field := rt.Field(i)
if !field.IsExported() || !hasHeaderTag(field) {
continue
}

name := headerFieldName(field)
fv := rv.Field(i)
ft := field.Type

// Handle pointer fields (optional headers)
if ft.Kind() == reflect.Pointer {
raw := header.Get(name)
if raw == "" {
continue
}
ptr := reflect.New(ft.Elem())
if err := setScalarValue(ptr.Elem(), raw); err != nil {
return &headerParseError{Field: name, Err: err}
}
fv.Set(ptr)
continue
}

// Handle scalar fields
raw := header.Get(name)
if raw == "" {
continue
}
if err := setScalarValue(fv, raw); err != nil {
return &headerParseError{Field: name, Err: err}
}
}

return nil
}

// headerParseError is returned when a header value cannot be parsed.
type headerParseError struct {
Field string
Err error
}

func (e *headerParseError) Error() string {
return fmt.Sprintf("invalid header %q: %v", e.Field, e.Err)
}
13 changes: 8 additions & 5 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ func hasQueryTag(f reflect.StructField) bool {
}

// partitionFields inspects a struct type and reports whether it contains
// query-tagged fields and/or body (json-tagged or untagged non-query) fields.
func partitionFields(t reflect.Type) (hasQuery, hasBody bool) {
// query-tagged fields, header-tagged fields, and/or body (json-tagged or
// untagged non-query/non-header) fields.
func partitionFields(t reflect.Type) (hasQuery, hasHeader, hasBody bool) {
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return false, false
return false, false, false
}
for i := range t.NumField() {
f := t.Field(i)
Expand All @@ -29,8 +30,10 @@ func partitionFields(t reflect.Type) (hasQuery, hasBody bool) {
}
if hasQueryTag(f) {
hasQuery = true
} else if hasHeaderTag(f) {
hasHeader = true
} else {
// Any exported field without a query tag is a body field
// Any exported field without a query or header tag is a body field
jsonTag := f.Tag.Get("json")
if jsonTag == "-" {
continue
Expand Down Expand Up @@ -171,7 +174,7 @@ func setScalarValue(v reflect.Value, raw string) error {
}
v.SetFloat(n)
default:
return fmt.Errorf("unsupported query parameter type %s", v.Kind())
return fmt.Errorf("unsupported parameter type %s", v.Kind())
}
return nil
}
Expand Down
83 changes: 81 additions & 2 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`)

func (a *API) updateSchema(method, path string, queryType, inType, outType reflect.Type, info *RouteInfo, status int) error {
func (a *API) updateSchema(method, path string, queryType, headerType, inType, outType reflect.Type, info *RouteInfo, status int) error {
op := &openapi3.Operation{
OperationID: operationID(method, path),
Responses: openapi3.NewResponses(),
Expand Down Expand Up @@ -43,6 +43,15 @@ func (a *API) updateSchema(method, path string, queryType, inType, outType refle
op.Parameters = append(op.Parameters, queryParams...)
}

// Header parameters
if headerType != nil {
headerParams, err := a.generateHeaderParams(headerType)
if err != nil {
return err
}
op.Parameters = append(op.Parameters, headerParams...)
}

// Response schema
statusStr := fmt.Sprintf("%d", status)
outSchema, err := a.generateSchemaRef(outType)
Expand Down Expand Up @@ -121,8 +130,9 @@ func (a *API) updateSchema(method, path string, queryType, inType, outType refle
return err
}
if inSchema != nil {
// Strip query-tagged fields from the body schema
// Strip query-tagged and header-tagged fields from the body schema
stripQueryFields(inType, inSchema.Value)
stripHeaderFields(inType, inSchema.Value)

if len(inSchema.Value.Properties) > 0 {
// Named body schema with properties
Expand Down Expand Up @@ -353,6 +363,75 @@ func stripQueryFields(t reflect.Type, schema *openapi3.Schema) {
}
}

// generateHeaderParams produces OpenAPI parameter definitions for a header struct type.
// Only fields with `header` tags are included. Slices are not supported for headers.
func (a *API) generateHeaderParams(t reflect.Type) ([]*openapi3.ParameterRef, error) {
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("header type must be a struct, got %s", t.Kind())
}

var params []*openapi3.ParameterRef
for i := range t.NumField() {
field := t.Field(i)
if !field.IsExported() {
continue
}
if !hasHeaderTag(field) {
continue
}
name := headerFieldName(field)
schema := scalarToOpenAPISchema(field.Type)

// Apply validation constraints
if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil {
return nil, err
}

required := hasRule(field.Tag.Get("validate"), "required")

params = append(params, &openapi3.ParameterRef{
Value: &openapi3.Parameter{
Name: name,
In: "header",
Required: required,
Schema: schema,
},
})
}
return params, nil
}

// stripHeaderFields removes header-tagged fields from a body schema's Properties and Required.
func stripHeaderFields(t reflect.Type, schema *openapi3.Schema) {
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct || schema == nil {
return
}
for i := range t.NumField() {
f := t.Field(i)
if !f.IsExported() || !hasHeaderTag(f) {
continue
}
jname := jsonFieldName(f)
if jname == "" || jname == "-" {
continue
}
delete(schema.Properties, jname)
// Remove from Required slice
for j, req := range schema.Required {
if req == jname {
schema.Required = append(schema.Required[:j], schema.Required[j+1:]...)
break
}
}
}
}

func scrubRefs(s *openapi3.SchemaRef) {
if s == nil || s.Value == nil || len(s.Value.Properties) == 0 {
return
Expand Down
Loading