Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ func (a *AgentConfig) GetFallbackCooldown() time.Duration {

// ModelConfig represents the configuration for a model
type ModelConfig struct {
// Name is the manifest model name (map key), populated at runtime.
// Not serialized — set by teamloader/model_switcher when resolving models.
Name string `json:"-"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Expand Down
8 changes: 8 additions & 0 deletions pkg/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ func WithModel(model string) Opt {
}
}

func WithModelName(name string) Opt {
return func(o *HTTPOptions) {
if name != "" {
o.Header.Set("X-Cagent-Model-Name", name)
}
}
}

func WithQuery(query url.Values) Opt {
return func(o *HTTPOptions) {
o.Query = query
Expand Down
97 changes: 97 additions & 0 deletions pkg/httpclient/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package httpclient

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWithModelName(t *testing.T) {
t.Parallel()

tests := []struct {
name string
modelName string
wantSet bool
}{
{
name: "sets header when name is provided",
modelName: "my-fast-model",
wantSet: true,
},
{
name: "skips header when name is empty",
modelName: "",
wantSet: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithModelName(tt.modelName))
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

if tt.wantSet {
assert.Equal(t, tt.modelName, capturedHeaders.Get("X-Cagent-Model-Name"))
} else {
assert.Empty(t, capturedHeaders.Get("X-Cagent-Model-Name"))
}
})
}
}

func TestWithModel(t *testing.T) {
t.Parallel()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithModel("gpt-4o"))
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

assert.Equal(t, "gpt-4o", capturedHeaders.Get("X-Cagent-Model"))
}

func TestWithProvider(t *testing.T) {
t.Parallel()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(WithProvider("openai"))
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

assert.Equal(t, "openai", capturedHeaders.Get("X-Cagent-Provider"))
}
1 change: 1 addition & 0 deletions pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.anthropic.com/")),
httpclient.WithProvider(cfg.Provider),
httpclient.WithModel(cfg.Model),
httpclient.WithModelName(cfg.Name),
httpclient.WithQuery(url.Query()),
}
if globalOptions.GeneratingTitle() {
Expand Down
1 change: 1 addition & 0 deletions pkg/model/provider/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://generativelanguage.googleapis.com/")),
httpclient.WithProvider(cfg.Provider),
httpclient.WithModel(cfg.Model),
httpclient.WithModelName(cfg.Name),
httpclient.WithQuery(url.Query()),
}
if globalOptions.GeneratingTitle() {
Expand Down
1 change: 1 addition & 0 deletions pkg/model/provider/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.openai.com/v1")),
httpclient.WithProvider(cfg.Provider),
httpclient.WithModel(cfg.Model),
httpclient.WithModelName(cfg.Name),
httpclient.WithQuery(url.Query()),
}
if globalOptions.GeneratingTitle() {
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime/model_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st

// Check if modelRef is a named model from config
if modelConfig, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
modelConfig.Name = modelRef
// Check if this is an alloy model (no provider, comma-separated models)
if isAlloyModelConfig(modelConfig) {
providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig)
Expand Down Expand Up @@ -175,6 +176,7 @@ func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, model

// Check if this part exists as a named model in config
if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists {
modelCfg.Name = part
prov, err := r.createProviderFromConfig(ctx, &modelCfg)
if err != nil {
return nil, fmt.Errorf("failed to create provider for %q: %w", part, err)
Expand Down Expand Up @@ -219,6 +221,7 @@ func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloy

// Check if this model reference exists in the config
if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
modelCfg.Name = modelRef
prov, err := r.createProviderFromConfig(ctx, &modelCfg)
if err != nil {
return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err)
Expand Down
2 changes: 2 additions & 0 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC
return nil, false, fmt.Errorf("model '%s' not found in configuration", name)
}
}
modelCfg.Name = name

// Check if thinking_budget was explicitly configured BEFORE provider defaults are applied.
// This is used to initialize session thinking state - thinking is only enabled by default
Expand Down Expand Up @@ -371,6 +372,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates
Model: modelName,
}
}
modelCfg.Name = name

// Use max_tokens from config if specified, otherwise look up from models.dev
maxTokens := &defaultMaxTokens
Expand Down