diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go
index 0e046a158..49eb35222 100644
--- a/pkg/config/latest/types.go
+++ b/pkg/config/latest/types.go
@@ -259,6 +259,9 @@ func (a *AgentConfig) GetFallbackCooldown() time.Duration {
 
 // ModelConfig represents the configuration for a model
 type ModelConfig struct {
+	// Name is the manifest model name (map key), populated at runtime.
+	// Not serialized — set by teamloader/model_switcher when resolving models.
+	Name        string   `json:"-"`
 	Provider    string   `json:"provider,omitempty"`
 	Model       string   `json:"model,omitempty"`
 	Temperature *float64 `json:"temperature,omitempty"`
diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go
index 345ab5e01..4e3720df1 100644
--- a/pkg/httpclient/client.go
+++ b/pkg/httpclient/client.go
@@ -76,6 +76,14 @@ func WithModel(model string) Opt {
 	}
 }
 
+func WithModelName(name string) Opt {
+	return func(o *HTTPOptions) {
+		if name != "" {
+			o.Header.Set("X-Cagent-Model-Name", name)
+		}
+	}
+}
+
 func WithQuery(query url.Values) Opt {
 	return func(o *HTTPOptions) {
 		o.Query = query
diff --git a/pkg/httpclient/client_test.go b/pkg/httpclient/client_test.go
new file mode 100644
index 000000000..949d52c51
--- /dev/null
+++ b/pkg/httpclient/client_test.go
@@ -0,0 +1,97 @@
+package httpclient
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestWithModelName(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name      string
+		modelName string
+		wantSet   bool
+	}{
+		{
+			name:      "sets header when name is provided",
+			modelName: "my-fast-model",
+			wantSet:   true,
+		},
+		{
+			name:      "skips header when name is empty",
+			modelName: "",
+			wantSet:   false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			var capturedHeaders http.Header
+			srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+				capturedHeaders = r.Header
+			}))
+			defer srv.Close()
+
+			client := NewHTTPClient(WithModelName(tt.modelName))
+			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
+			require.NoError(t, err)
+
+			resp, err := client.Do(req)
+			require.NoError(t, err)
+			defer func() { _ = resp.Body.Close() }()
+
+			if tt.wantSet {
+				assert.Equal(t, tt.modelName, capturedHeaders.Get("X-Cagent-Model-Name"))
+			} else {
+				assert.Empty(t, capturedHeaders.Get("X-Cagent-Model-Name"))
+			}
+		})
+	}
+}
+
+func TestWithModel(t *testing.T) {
+	t.Parallel()
+
+	var capturedHeaders http.Header
+	srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+		capturedHeaders = r.Header
+	}))
+	defer srv.Close()
+
+	client := NewHTTPClient(WithModel("gpt-4o"))
+	req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
+	require.NoError(t, err)
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	defer func() { _ = resp.Body.Close() }()
+
+	assert.Equal(t, "gpt-4o", capturedHeaders.Get("X-Cagent-Model"))
+}
+
+func TestWithProvider(t *testing.T) {
+	t.Parallel()
+
+	var capturedHeaders http.Header
+	srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+		capturedHeaders = r.Header
+	}))
+	defer srv.Close()
+
+	client := NewHTTPClient(WithProvider("openai"))
+	req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
+	require.NoError(t, err)
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	defer func() { _ = resp.Body.Close() }()
+
+	assert.Equal(t, "openai", capturedHeaders.Get("X-Cagent-Provider"))
+}
diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go
index fcd10c205..ca0c4d737 100644
--- a/pkg/model/provider/anthropic/client.go
+++ b/pkg/model/provider/anthropic/client.go
@@ -184,6 +184,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
 		httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.anthropic.com/")),
 		httpclient.WithProvider(cfg.Provider),
 		httpclient.WithModel(cfg.Model),
+		httpclient.WithModelName(cfg.Name),
 		httpclient.WithQuery(url.Query()),
 	}
 	if globalOptions.GeneratingTitle() {
diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go
index 02e6b0d55..7b4ac3501 100644
--- a/pkg/model/provider/gemini/client.go
+++ b/pkg/model/provider/gemini/client.go
@@ -130,6 +130,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
 		httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://generativelanguage.googleapis.com/")),
 		httpclient.WithProvider(cfg.Provider),
 		httpclient.WithModel(cfg.Model),
+		httpclient.WithModelName(cfg.Name),
 		httpclient.WithQuery(url.Query()),
 	}
 	if globalOptions.GeneratingTitle() {
diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go
index ad9ba77d9..f13e7a2a9 100644
--- a/pkg/model/provider/openai/client.go
+++ b/pkg/model/provider/openai/client.go
@@ -118,6 +118,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
 		httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.openai.com/v1")),
 		httpclient.WithProvider(cfg.Provider),
 		httpclient.WithModel(cfg.Model),
+		httpclient.WithModelName(cfg.Name),
 		httpclient.WithQuery(url.Query()),
 	}
 	if globalOptions.GeneratingTitle() {
diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go
index 9b49c5a0d..6d95573e9 100644
--- a/pkg/runtime/model_switcher.go
+++ b/pkg/runtime/model_switcher.go
@@ -84,6 +84,7 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st
 
 	// Check if modelRef is a named model from config
 	if modelConfig, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
+		modelConfig.Name = modelRef
 		// Check if this is an alloy model (no provider, comma-separated models)
 		if isAlloyModelConfig(modelConfig) {
 			providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig)
@@ -175,6 +176,7 @@ func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, model
 
 		// Check if this part exists as a named model in config
 		if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists {
+			modelCfg.Name = part
 			prov, err := r.createProviderFromConfig(ctx, &modelCfg)
 			if err != nil {
 				return nil, fmt.Errorf("failed to create provider for %q: %w", part, err)
@@ -219,6 +221,7 @@ func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloy
 
 		// Check if this model reference exists in the config
 		if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
+			modelCfg.Name = modelRef
 			prov, err := r.createProviderFromConfig(ctx, &modelCfg)
 			if err != nil {
 				return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err)
diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go
index 8efb2846c..8cd0e3c9b 100644
--- a/pkg/teamloader/teamloader.go
+++ b/pkg/teamloader/teamloader.go
@@ -301,6 +301,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC
 			return nil, false, fmt.Errorf("model '%s' not found in configuration", name)
 		}
 	}
+	modelCfg.Name = name
 
 	// Check if thinking_budget was explicitly configured BEFORE provider defaults are applied.
 	// This is used to initialize session thinking state - thinking is only enabled by default
@@ -371,6 +372,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates
 			Model: modelName,
 		}
 	}
+	modelCfg.Name = name
 
 	// Use max_tokens from config if specified, otherwise look up from models.dev
 	maxTokens := &defaultMaxTokens