diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index ce6457bfe..a9bf2fe2a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -1276,9 +1276,15 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err) } - if response.Usage != nil { + // Some providers emit token usage multiple times during a stream, + // others only once, and some emit partial / zeroed usage snapshots. + // To be provider-agnostic and avoid usage being overwritten to zero, + // we capture the FIRST non-nil usage and treat it as immutable. + if response.Usage != nil && messageUsage == nil { + // Capture usage once per stream messageUsage = response.Usage + // Accumulate cost for the session using model pricing if m != nil && m.Cost != nil { cost := float64(response.Usage.InputTokens)*m.Cost.Input + float64(response.Usage.OutputTokens)*m.Cost.Output + @@ -1287,14 +1293,25 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre sess.Cost += cost / 1e6 } - sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens + response.Usage.CacheWriteTokens + // Persist token usage at the session level + // These values are used by the TUI to compute token usage % + sess.InputTokens = response.Usage.InputTokens + + response.Usage.CachedInputTokens + + response.Usage.CacheWriteTokens sess.OutputTokens = response.Usage.OutputTokens + // Emit telemetry once per stream to avoid duplicate usage records modelName := "unknown" if m != nil { modelName = m.Name } - telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.Cost) + telemetry.RecordTokenUsage( + ctx, + modelName, + sess.InputTokens, + sess.OutputTokens, + sess.Cost, + ) } if response.RateLimit != nil { diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 765c2476e..af0487fea 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -1364,3 +1364,42 @@ func TestToolRejectionWithoutReason(t *testing.T) { require.Equal(t, "The user rejected the tool call.", toolResponse.Response) require.NotContains(t, toolResponse.Response, "Reason:") } + +func TestStream_CapturesUsageOnlyOnce(t *testing.T) { + stream := newStreamBuilder(). + AddContent("Hello"). + AddStopWithUsage(10, 5). // first usage + AddStopWithUsage(0, 0). // provider emits usage again + Build() + + sess := session.New(session.WithUserMessage("Hi")) + + events := runSession(t, sess, stream) + + var usageEvents []Event + for _, ev := range events { + if reflect.DeepEqual( + ev, + TokenUsageWithMessage( + sess.ID, + "root", + 10, + 5, + 15, + 0, + 0, + &MessageUsage{ + Usage: chat.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "test/mock-model", + }, + ), + ) { + usageEvents = append(usageEvents, ev) + } + } + + require.Len(t, usageEvents, 1, "expected token usage to be emitted only once") +}