feat(translator): add token usage tracking and improve usage handling
- Introduced `claudeUsageTokens` struct for detailed token usage tracking.
- Replaced `calculateClaudeUsageTokens` with `Merge` and `OpenAIUsage` methods for better modularity.
- Enhanced integration of usage tokens into response processing, enabling more accurate reporting of token details.

Fixed: #2419
This commit is contained in:
@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
|
|||||||
CreatedAt int64
|
CreatedAt int64
|
||||||
ResponseID string
|
ResponseID string
|
||||||
FinishReason string
|
FinishReason string
|
||||||
|
Usage claudeUsageTokens
|
||||||
// Tool calls accumulator for streaming
|
// Tool calls accumulator for streaming
|
||||||
ToolCallsAccumulator map[int]*ToolCallAccumulator
|
ToolCallsAccumulator map[int]*ToolCallAccumulator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type claudeUsageTokens struct {
|
||||||
|
InputTokens int64
|
||||||
|
OutputTokens int64
|
||||||
|
CacheCreationInputTokens int64
|
||||||
|
CacheReadInputTokens int64
|
||||||
|
HasUsage bool
|
||||||
|
}
|
||||||
|
|
||||||
// ToolCallAccumulator holds the state for accumulating tool call data
|
// ToolCallAccumulator holds the state for accumulating tool call data
|
||||||
type ToolCallAccumulator struct {
|
type ToolCallAccumulator struct {
|
||||||
ID string
|
ID string
|
||||||
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
|
|||||||
Arguments strings.Builder
|
Arguments strings.Builder
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
func (u *claudeUsageTokens) Merge(usage gjson.Result) {
|
||||||
inputTokens := usage.Get("input_tokens").Int()
|
if !usage.Exists() {
|
||||||
completionTokens = usage.Get("output_tokens").Int()
|
return
|
||||||
cachedTokens = usage.Get("cache_read_input_tokens").Int()
|
}
|
||||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
u.HasUsage = true
|
||||||
|
if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
|
||||||
|
u.InputTokens = inputTokens.Int()
|
||||||
|
}
|
||||||
|
if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
|
||||||
|
u.OutputTokens = outputTokens.Int()
|
||||||
|
}
|
||||||
|
if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
|
||||||
|
u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
|
||||||
|
}
|
||||||
|
if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
|
||||||
|
u.CacheReadInputTokens = cacheReadInputTokens.Int()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
|
func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||||
|
cachedTokens = u.CacheReadInputTokens
|
||||||
|
promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
|
||||||
|
completionTokens = u.OutputTokens
|
||||||
totalTokens = promptTokens + completionTokens
|
totalTokens = promptTokens + completionTokens
|
||||||
|
|
||||||
return promptTokens, completionTokens, totalTokens, cachedTokens
|
return promptTokens, completionTokens, totalTokens, cachedTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
|||||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
|
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
|
||||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||||
}
|
}
|
||||||
|
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
|
||||||
}
|
}
|
||||||
return [][]byte{template}
|
return [][]byte{template}
|
||||||
|
|
||||||
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
|||||||
|
|
||||||
// Handle usage information for token counts
|
// Handle usage information for token counts
|
||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
|
||||||
|
promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
|
||||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
||||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
||||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
||||||
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
var stopReason string
|
var stopReason string
|
||||||
var contentParts []string
|
var contentParts []string
|
||||||
var reasoningParts []string
|
var reasoningParts []string
|
||||||
|
usageTokens := claudeUsageTokens{}
|
||||||
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
|
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
|
||||||
|
|
||||||
for _, chunk := range chunks {
|
for _, chunk := range chunks {
|
||||||
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
messageID = message.Get("id").String()
|
messageID = message.Get("id").String()
|
||||||
model = message.Get("model").String()
|
model = message.Get("model").String()
|
||||||
createdAt = time.Now().Unix()
|
createdAt = time.Now().Unix()
|
||||||
|
usageTokens.Merge(message.Get("usage"))
|
||||||
}
|
}
|
||||||
|
|
||||||
case "content_block_start":
|
case "content_block_start":
|
||||||
@@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
usageTokens.Merge(usage)
|
||||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
|
||||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
|
||||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
|
||||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if usageTokens.HasUsage {
|
||||||
|
promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
|
||||||
|
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||||
|
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||||
|
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||||
|
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
// Set basic response fields including message ID, creation time, and model
|
// Set basic response fields including message ID, creation time, and model
|
||||||
out, _ = sjson.SetBytes(out, "id", messageID)
|
out, _ = sjson.SetBytes(out, "id", messageID)
|
||||||
out, _ = sjson.SetBytes(out, "created", createdAt)
|
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||||
|
|||||||
@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
ConvertClaudeResponseToOpenAI(
|
||||||
|
ctx,
|
||||||
|
"claude-opus-4-6",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
|
||||||
|
&param,
|
||||||
|
)
|
||||||
|
out := ConvertClaudeResponseToOpenAI(
|
||||||
|
ctx,
|
||||||
|
"claude-opus-4-6",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
|
||||||
|
&param,
|
||||||
|
)
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||||
|
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||||
|
}
|
||||||
|
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||||
|
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||||
|
}
|
||||||
|
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||||
|
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||||
|
}
|
||||||
|
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||||
|
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
||||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
||||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
||||||
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
|
|||||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
|
||||||
|
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
|
||||||
|
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
|
||||||
|
|
||||||
|
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
|
||||||
|
|
||||||
|
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||||
|
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||||
|
}
|
||||||
|
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||||
|
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||||
|
}
|
||||||
|
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||||
|
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||||
|
}
|
||||||
|
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||||
|
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user