From 7818451b7d601c085a1cd7159a2923560d41ff8b Mon Sep 17 00:00:00 2001 From: "Frank Chiarulli Jr." Date: Mon, 23 Feb 2026 10:14:36 -0500 Subject: [PATCH] feat: add typed header support --- README.md | 29 ++- handler.go | 19 +- handlerFuncs.go | 10 +- header.go | 104 ++++++++++ query.go | 13 +- schema.go | 83 +++++++- shiftapi_test.go | 494 +++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 739 insertions(+), 13 deletions(-) create mode 100644 header.go diff --git a/README.md b/README.md index 0743342..367a08c 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ That's it. ShiftAPI reflects your Go types into an OpenAPI 3.1 spec at `/openapi ### Generic type-safe handlers -Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`) from body fields (`json:"..."`). For routes without input, use `_ struct{}`. +Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`), HTTP headers (`header:"..."`), and body fields (`json:"..."`). For routes without input, use `_ struct{}`. ```go // POST with body — input is decoded and passed as *CreateUser @@ -127,6 +127,25 @@ shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*Result, err }) ``` +### Typed HTTP headers + +Define a struct with `header` tags. Headers are parsed, validated, and documented in the OpenAPI spec automatically — just like query params. + +```go +type AuthInput struct { + Token string `header:"Authorization" validate:"required"` + Q string `query:"q"` +} + +shiftapi.Get(api, "/search", func(r *http.Request, in AuthInput) (*Results, error) { + // in.Token parsed from the Authorization header + // in.Q parsed from ?q= query param + return doSearch(in.Token, in.Q), nil +}) +``` + +Supports `string`, `bool`, `int*`, `uint*`, `float*` scalars and `*T` pointers for optional headers. Parse errors return `400`; validation failures return `422`. Header, query, and body fields can be freely combined in one struct. + ### Validation Built-in validation via [go-playground/validator](https://github.com/go-playground/validator). Struct tags are enforced at runtime *and* reflected into the OpenAPI schema. @@ -234,6 +253,14 @@ const { data: results } = await client.GET("/search", { params: { query: { q: "hello", page: 1, limit: 10 } }, }); // query params are fully typed too — { q: string, page?: number, limit?: number } + +const { data: authResults } = await client.GET("/search", { + params: { + query: { q: "hello" }, + header: { Authorization: "Bearer token" }, + }, +}); +// header params are fully typed as well ``` In dev mode the plugin also starts the Go server, proxies API requests through Vite, watches `.go` files, and hot-reloads the frontend when types change. diff --git a/handler.go b/handler.go index a827de3..6d7eb46 100644 --- a/handler.go +++ b/handler.go @@ -11,11 +11,12 @@ import ( // HandlerFunc is a typed handler for routes. // The In struct's fields are discriminated by struct tags: // fields with `query:"..."` tags are parsed from query parameters, -// and fields with `json:"..."` tags (or no query tag) are parsed from the request body. +// fields with `header:"..."` tags are parsed from HTTP headers, +// and remaining fields (with `json:"..."` tags or untagged) are parsed from the request body. // For routes without input, use struct{} as the In type. type HandlerFunc[In, Resp any] func(r *http.Request, in In) (Resp, error) -func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasQuery, hasBody bool) http.HandlerFunc { +func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasQuery, hasHeader, hasBody bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var in In rv := reflect.ValueOf(&in).Elem() @@ -36,6 +37,12 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any resetQueryFields(rv) } + // Reset any header-tagged fields that body decode may have + // inadvertently set, so they only come from HTTP headers. + if hasBody && hasHeader { + resetHeaderFields(rv) + } + // Parse query params if there are query fields if hasQuery { if err := parseQueryInto(rv, r.URL.Query()); err != nil { @@ -44,6 +51,14 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any } } + // Parse headers if there are header fields + if hasHeader { + if err := parseHeadersInto(rv, r.Header); err != nil { + writeError(w, Error(http.StatusBadRequest, err.Error())) + return + } + } + if err := validate(in); err != nil { writeError(w, err) return diff --git a/handlerFuncs.go b/handlerFuncs.go index 5b7810e..c63a232 100644 --- a/handlerFuncs.go +++ b/handlerFuncs.go @@ -23,12 +23,16 @@ func registerRoute[In, Resp any]( rawInType = rawInType.Elem() } - hasQuery, hasBody := partitionFields(rawInType) + hasQuery, hasHeader, hasBody := partitionFields(rawInType) var queryType reflect.Type if hasQuery { queryType = rawInType } + var headerType reflect.Type + if hasHeader { + headerType = rawInType + } // POST/PUT/PATCH conventionally carry a request body, so always attempt // body decode for these methods — even when the input is struct{}. // This means Post(api, path, func(r, _ struct{}) ...) requires at least "{}". @@ -45,12 +49,12 @@ func registerRoute[In, Resp any]( var resp Resp outType := reflect.TypeOf(resp) - if err := api.updateSchema(method, path, queryType, bodyType, outType, cfg.info, cfg.status); err != nil { + if err := api.updateSchema(method, path, queryType, headerType, bodyType, outType, cfg.info, cfg.status); err != nil { panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) } pattern := fmt.Sprintf("%s %s", method, path) - api.mux.HandleFunc(pattern, adapt(fn, cfg.status, api.validateBody, hasQuery, decodeBody)) + api.mux.HandleFunc(pattern, adapt(fn, cfg.status, api.validateBody, hasQuery, hasHeader, decodeBody)) } // Get registers a GET handler. diff --git a/header.go b/header.go new file mode 100644 index 0000000..f220885 --- /dev/null +++ b/header.go @@ -0,0 +1,104 @@ +package shiftapi + +import ( + "fmt" + "net/http" + "reflect" + "strings" +) + +// hasHeaderTag returns true if the struct field has a `header` tag. +func hasHeaderTag(f reflect.StructField) bool { + return f.Tag.Get("header") != "" +} + +// headerFieldName returns the header name for a struct field. +func headerFieldName(f reflect.StructField) string { + name, _, _ := strings.Cut(f.Tag.Get("header"), ",") + if name == "" { + return f.Name + } + return name +} + +// resetHeaderFields zeros out any header-tagged fields on a struct value. +// This is called after body decode so that header-tagged fields are only +// populated by parseHeadersInto, not by JSON keys that happen to match. +func resetHeaderFields(rv reflect.Value) { + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return + } + rt := rv.Type() + for i := range rt.NumField() { + f := rt.Field(i) + if f.IsExported() && hasHeaderTag(f) { + rv.Field(i).SetZero() + } + } +} + +// parseHeadersInto populates header-tagged fields on an existing struct value +// from HTTP headers. Non-header fields are left untouched. +// Only scalar types and pointer-to-scalar types are supported (no slices). +func parseHeadersInto(rv reflect.Value, header http.Header) error { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + rt := rv.Type() + if rt.Kind() != reflect.Struct { + return fmt.Errorf("header type must be a struct, got %s", rt.Kind()) + } + + for i := range rt.NumField() { + field := rt.Field(i) + if !field.IsExported() || !hasHeaderTag(field) { + continue + } + + name := headerFieldName(field) + fv := rv.Field(i) + ft := field.Type + + // Handle pointer fields (optional headers) + if ft.Kind() == reflect.Pointer { + raw := header.Get(name) + if raw == "" { + continue + } + ptr := reflect.New(ft.Elem()) + if err := setScalarValue(ptr.Elem(), raw); err != nil { + return &headerParseError{Field: name, Err: err} + } + fv.Set(ptr) + continue + } + + // Handle scalar fields + raw := header.Get(name) + if raw == "" { + continue + } + if err := setScalarValue(fv, raw); err != nil { + return &headerParseError{Field: name, Err: err} + } + } + + return nil +} + +// headerParseError is returned when a header value cannot be parsed. +type headerParseError struct { + Field string + Err error +} + +func (e *headerParseError) Error() string { + return fmt.Sprintf("invalid header %q: %v", e.Field, e.Err) +} diff --git a/query.go b/query.go index d113f83..130703b 100644 --- a/query.go +++ b/query.go @@ -14,13 +14,14 @@ func hasQueryTag(f reflect.StructField) bool { } // partitionFields inspects a struct type and reports whether it contains -// query-tagged fields and/or body (json-tagged or untagged non-query) fields. -func partitionFields(t reflect.Type) (hasQuery, hasBody bool) { +// query-tagged fields, header-tagged fields, and/or body (json-tagged or +// untagged non-query/non-header) fields. +func partitionFields(t reflect.Type) (hasQuery, hasHeader, hasBody bool) { for t.Kind() == reflect.Pointer { t = t.Elem() } if t.Kind() != reflect.Struct { - return false, false + return false, false, false } for i := range t.NumField() { f := t.Field(i) @@ -29,8 +30,10 @@ func partitionFields(t reflect.Type) (hasQuery, hasBody bool) { } if hasQueryTag(f) { hasQuery = true + } else if hasHeaderTag(f) { + hasHeader = true } else { - // Any exported field without a query tag is a body field + // Any exported field without a query or header tag is a body field jsonTag := f.Tag.Get("json") if jsonTag == "-" { continue @@ -171,7 +174,7 @@ func setScalarValue(v reflect.Value, raw string) error { } v.SetFloat(n) default: - return fmt.Errorf("unsupported query parameter type %s", v.Kind()) + return fmt.Errorf("unsupported parameter type %s", v.Kind()) } return nil } diff --git a/schema.go b/schema.go index 4d6f852..8a45455 100644 --- a/schema.go +++ b/schema.go @@ -12,7 +12,7 @@ import ( var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`) -func (a *API) updateSchema(method, path string, queryType, inType, outType reflect.Type, info *RouteInfo, status int) error { +func (a *API) updateSchema(method, path string, queryType, headerType, inType, outType reflect.Type, info *RouteInfo, status int) error { op := &openapi3.Operation{ OperationID: operationID(method, path), Responses: openapi3.NewResponses(), @@ -43,6 +43,15 @@ func (a *API) updateSchema(method, path string, queryType, inType, outType refle op.Parameters = append(op.Parameters, queryParams...) } + // Header parameters + if headerType != nil { + headerParams, err := a.generateHeaderParams(headerType) + if err != nil { + return err + } + op.Parameters = append(op.Parameters, headerParams...) + } + // Response schema statusStr := fmt.Sprintf("%d", status) outSchema, err := a.generateSchemaRef(outType) @@ -121,8 +130,9 @@ func (a *API) updateSchema(method, path string, queryType, inType, outType refle return err } if inSchema != nil { - // Strip query-tagged fields from the body schema + // Strip query-tagged and header-tagged fields from the body schema stripQueryFields(inType, inSchema.Value) + stripHeaderFields(inType, inSchema.Value) if len(inSchema.Value.Properties) > 0 { // Named body schema with properties @@ -353,6 +363,75 @@ func stripQueryFields(t reflect.Type, schema *openapi3.Schema) { } } +// generateHeaderParams produces OpenAPI parameter definitions for a header struct type. +// Only fields with `header` tags are included. Slices are not supported for headers. +func (a *API) generateHeaderParams(t reflect.Type) ([]*openapi3.ParameterRef, error) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("header type must be a struct, got %s", t.Kind()) + } + + var params []*openapi3.ParameterRef + for i := range t.NumField() { + field := t.Field(i) + if !field.IsExported() { + continue + } + if !hasHeaderTag(field) { + continue + } + name := headerFieldName(field) + schema := scalarToOpenAPISchema(field.Type) + + // Apply validation constraints + if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil { + return nil, err + } + + required := hasRule(field.Tag.Get("validate"), "required") + + params = append(params, &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: name, + In: "header", + Required: required, + Schema: schema, + }, + }) + } + return params, nil +} + +// stripHeaderFields removes header-tagged fields from a body schema's Properties and Required. +func stripHeaderFields(t reflect.Type, schema *openapi3.Schema) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct || schema == nil { + return + } + for i := range t.NumField() { + f := t.Field(i) + if !f.IsExported() || !hasHeaderTag(f) { + continue + } + jname := jsonFieldName(f) + if jname == "" || jname == "-" { + continue + } + delete(schema.Properties, jname) + // Remove from Required slice + for j, req := range schema.Required { + if req == jname { + schema.Required = append(schema.Required[:j], schema.Required[j+1:]...) + break + } + } + } +} + func scrubRefs(s *openapi3.SchemaRef) { if s == nil || s.Value == nil || len(s.Value.Properties) == 0 { return diff --git a/shiftapi_test.go b/shiftapi_test.go index f73ffa7..5fe6508 100644 --- a/shiftapi_test.go +++ b/shiftapi_test.go @@ -2493,3 +2493,497 @@ func TestSpecQueryOnlyInputHasNoRequestBody(t *testing.T) { t.Error("GET with query-only input should not have a request body in the spec") } } + +// --- Header parameter test types --- + +type AuthHeader struct { + Token string `header:"Authorization" validate:"required"` +} + +type AuthResult struct { + Token string `json:"token"` +} + +type OptionalHeader struct { + Debug *bool `header:"X-Debug"` + Limit *int `header:"X-Limit"` +} + +type OptionalHeaderResult struct { + HasDebug bool `json:"has_debug"` + Debug bool `json:"debug"` + HasLimit bool `json:"has_limit"` + Limit int `json:"limit"` +} + +type HeaderAndBody struct { + Token string `header:"Authorization" validate:"required"` + Name string `json:"name" validate:"required"` +} + +type HeaderAndBodyResult struct { + Token string `json:"token"` + Name string `json:"name"` +} + +type HeaderAndQuery struct { + Token string `header:"Authorization" validate:"required"` + Q string `query:"q"` +} + +type HeaderAndQueryResult struct { + Token string `json:"token"` + Q string `json:"q"` +} + +type ScalarHeaders struct { + Flag bool `header:"X-Flag"` + Count uint `header:"X-Count"` + Score float64 `header:"X-Score"` +} + +type ScalarHeaderResult struct { + Flag bool `json:"flag"` + Count uint `json:"count"` + Score float64 `json:"score"` +} + +// --- Header parameter test helpers --- + +func doRequestWithHeaders(t *testing.T, api http.Handler, method, path, body string, headers map[string]string) *http.Response { + t.Helper() + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, bodyReader) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + for k, v := range headers { + req.Header.Set(k, v) + } + rec := httptest.NewRecorder() + api.ServeHTTP(rec, req) + return rec.Result() +} + +// --- Header parameter runtime tests --- + +func TestGetWithHeaderBasic(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{Token: in.Token}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/auth", "", map[string]string{ + "Authorization": "Bearer abc123", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[AuthResult](t, resp) + if result.Token != "Bearer abc123" { + t.Errorf("expected Token=%q, got %q", "Bearer abc123", result.Token) + } +} + +func TestGetWithHeaderMissingRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{Token: in.Token}, nil + }) + + // Missing required "Authorization" header + resp := doRequest(t, api, http.MethodGet, "/auth", "") + if resp.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidType(t *testing.T) { + api := newTestAPI(t) + type IntHeader struct { + Count int `header:"X-Count" validate:"required"` + } + shiftapi.Get(api, "/count", func(r *http.Request, in IntHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/count", "", map[string]string{ + "X-Count": "notanumber", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderOptionalPointers(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) { + result := &OptionalHeaderResult{} + if in.Debug != nil { + result.HasDebug = true + result.Debug = *in.Debug + } + if in.Limit != nil { + result.HasLimit = true + result.Limit = *in.Limit + } + return result, nil + }) + + // With both headers + resp := doRequestWithHeaders(t, api, http.MethodGet, "/optional", "", map[string]string{ + "X-Debug": "true", + "X-Limit": "50", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[OptionalHeaderResult](t, resp) + if !result.HasDebug || !result.Debug { + t.Error("expected Debug to be true") + } + if !result.HasLimit || result.Limit != 50 { + t.Errorf("expected Limit=50, got %d", result.Limit) + } + + // Without optional headers + resp2 := doRequest(t, api, http.MethodGet, "/optional", "") + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp2.StatusCode) + } + result2 := decodeJSON[OptionalHeaderResult](t, resp2) + if result2.HasDebug { + t.Error("expected HasDebug=false when header absent") + } + if result2.HasLimit { + t.Error("expected HasLimit=false when header absent") + } +} + +func TestPostWithHeaderAndBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget"}`, map[string]string{ + "Authorization": "Bearer xyz", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndBodyResult](t, resp) + if result.Token != "Bearer xyz" { + t.Errorf("expected Token=%q, got %q", "Bearer xyz", result.Token) + } + if result.Name != "widget" { + t.Errorf("expected Name=%q, got %q", "widget", result.Name) + } +} + +func TestHeaderFieldNotSetByBodyDecode(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil + }) + + // Body includes "Token" key that matches the header field name — it should be ignored + resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget","Token":"body-token"}`, map[string]string{ + "Authorization": "Bearer real", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndBodyResult](t, resp) + if result.Token != "Bearer real" { + t.Errorf("expected Token=%q from header, got %q", "Bearer real", result.Token) + } +} + +func TestGetWithHeaderAndQuery(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) { + return &HeaderAndQueryResult{Token: in.Token, Q: in.Q}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/search?q=hello", "", map[string]string{ + "Authorization": "Bearer abc", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndQueryResult](t, resp) + if result.Token != "Bearer abc" { + t.Errorf("expected Token=%q, got %q", "Bearer abc", result.Token) + } + if result.Q != "hello" { + t.Errorf("expected Q=%q, got %q", "hello", result.Q) + } +} + +func TestGetWithHeaderScalars(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) { + return &ScalarHeaderResult{Flag: in.Flag, Count: in.Count, Score: in.Score}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/scalars", "", map[string]string{ + "X-Flag": "true", + "X-Count": "42", + "X-Score": "3.14", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[ScalarHeaderResult](t, resp) + if !result.Flag { + t.Error("expected Flag=true") + } + if result.Count != 42 { + t.Errorf("expected Count=42, got %d", result.Count) + } + if result.Score != 3.14 { + t.Errorf("expected Score=3.14, got %f", result.Score) + } +} + +func TestGetWithHeaderInvalidBool(t *testing.T) { + api := newTestAPI(t) + type BoolHeader struct { + Flag bool `header:"X-Flag"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in BoolHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Flag": "notabool", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidUint(t *testing.T) { + api := newTestAPI(t) + type UintHeader struct { + Count uint `header:"X-Count"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in UintHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Count": "-1", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidFloat(t *testing.T) { + api := newTestAPI(t) + type FloatHeader struct { + Score float64 `header:"X-Score"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in FloatHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Score": "abc", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +// --- Header parameter OpenAPI spec tests --- + +func TestSpecHeaderParamsDocumented(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/auth").Get + + var found bool + for _, p := range op.Parameters { + if p.Value.Name == "Authorization" && p.Value.In == "header" { + found = true + break + } + } + if !found { + t.Error("expected Authorization header parameter documented in spec") + } +} + +func TestSpecHeaderParamTypes(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) { + return &ScalarHeaderResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/scalars").Get + + expected := map[string]string{ + "X-Flag": "boolean", + "X-Count": "integer", + "X-Score": "number", + } + for _, p := range op.Parameters { + if p.Value.In != "header" { + continue + } + want, ok := expected[p.Value.Name] + if !ok { + t.Errorf("unexpected header param %q", p.Value.Name) + continue + } + got := p.Value.Schema.Value.Type.Slice()[0] + if got != want { + t.Errorf("header %q: expected type %q, got %q", p.Value.Name, want, got) + } + } +} + +func TestSpecHeaderParamRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/auth").Get + + for _, p := range op.Parameters { + if p.Value.Name == "Authorization" && p.Value.In == "header" { + if !p.Value.Required { + t.Error("expected Authorization header to be required") + } + return + } + } + t.Error("Authorization header param not found") +} + +func TestSpecHeaderParamOptionalPointerNotRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) { + return &OptionalHeaderResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/optional").Get + + for _, p := range op.Parameters { + if p.Value.In == "header" && p.Value.Required { + t.Errorf("optional header %q should not be required", p.Value.Name) + } + } +} + +func TestSpecHeaderParamValidationConstraints(t *testing.T) { + api := newTestAPI(t) + type ConstrainedHeader struct { + Count int `header:"X-Count" validate:"min=1,max=100"` + } + shiftapi.Get(api, "/constrained", func(r *http.Request, in ConstrainedHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/constrained").Get + + for _, p := range op.Parameters { + if p.Value.Name == "X-Count" && p.Value.In == "header" { + s := p.Value.Schema.Value + if s.Min == nil || *s.Min != 1 { + t.Error("expected Min=1 on X-Count header param") + } + if s.Max == nil || *s.Max != 100 { + t.Error("expected Max=100 on X-Count header param") + } + return + } + } + t.Error("X-Count header param not found") +} + +func TestSpecBodySchemaExcludesHeaderFields(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{}, nil + }) + + spec := api.Spec() + // Find the body schema in components + for name, schemaRef := range spec.Components.Schemas { + if name == "HeaderAndBody" { + if _, has := schemaRef.Value.Properties["Token"]; has { + t.Error("body schema should not contain header field 'Token'") + } + if _, has := schemaRef.Value.Properties["name"]; !has { + t.Error("body schema should contain body field 'name'") + } + return + } + } + t.Error("HeaderAndBody schema not found in components") +} + +func TestSpecHeaderParamsCombinedWithQueryParams(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) { + return &HeaderAndQueryResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + var headerParams, queryParams int + for _, p := range op.Parameters { + switch p.Value.In { + case "header": + headerParams++ + case "query": + queryParams++ + } + } + if headerParams != 1 { + t.Errorf("expected 1 header param, got %d", headerParams) + } + if queryParams != 1 { + t.Errorf("expected 1 query param, got %d", queryParams) + } +} + +func TestSpecHeaderParamEnum(t *testing.T) { + api := newTestAPI(t) + type EnumHeader struct { + Format string `header:"Accept" validate:"oneof=json xml csv"` + } + shiftapi.Get(api, "/data", func(r *http.Request, in EnumHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/data").Get + + for _, p := range op.Parameters { + if p.Value.Name == "Accept" && p.Value.In == "header" { + if len(p.Value.Schema.Value.Enum) != 3 { + t.Errorf("expected 3 enum values, got %d", len(p.Value.Schema.Value.Enum)) + } + return + } + } + t.Error("Accept header param not found") +}