diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go index 18d79a8f..1fd3f2ae 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go @@ -36,6 +36,18 @@ type ToolCallAccumulator struct { Arguments strings.Builder } +func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) { + inputTokens := usage.Get("input_tokens").Int() + completionTokens = usage.Get("output_tokens").Int() + cachedTokens = usage.Get("cache_read_input_tokens").Int() + cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() + + promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens + totalTokens = promptTokens + completionTokens + + return promptTokens, completionTokens, totalTokens, cachedTokens +} + // ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format. // This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses. // It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match @@ -203,14 +215,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original // Handle usage information for token counts if usage := root.Get("usage"); usage.Exists() { - inputTokens := usage.Get("input_tokens").Int() - outputTokens := usage.Get("output_tokens").Int() - cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int() - cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() - template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens) - template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokens) - template, _ = sjson.SetBytes(template, "usage.total_tokens", inputTokens+outputTokens) - template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens) + promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage) + template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens) + template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens) + template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens) + template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens) } return [][]byte{template} @@ -362,14 +371,11 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina } } if usage := root.Get("usage"); usage.Exists() { - inputTokens := usage.Get("input_tokens").Int() - outputTokens := usage.Get("output_tokens").Int() - cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int() - cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() - out, _ = sjson.SetBytes(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens) - out, _ = sjson.SetBytes(out, "usage.completion_tokens", outputTokens) - out, _ = sjson.SetBytes(out, "usage.total_tokens", inputTokens+outputTokens) - out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens) + promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage) + out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens) + out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens) + out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens) + out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens) } } } diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go new file mode 100644 index 00000000..7bd6eb1f --- /dev/null +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go @@ -0,0 +1,58 @@ +package chat_completions + +import ( + "context" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) { + ctx := context.Background() + var param any + + out := ConvertClaudeResponseToOpenAI( + ctx, + "claude-opus-4-6", + nil, + nil, + []byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`), + ¶m, + ) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 { + t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens) + } + if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 { + t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens) + } + if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 { + t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens) + } + if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 { + t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens) + } +} + +func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) { + rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" + + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n") + + out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil) + + if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 { + t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens) + } + if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 { + t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens) + } + if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 { + t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens) + } + if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 { + t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens) + } +}