fix(claude): enhance ensureModelMaxTokens to use registered max_completion_tokens and fall back to the default
This commit is contained in:
@@ -45,33 +45,14 @@ type ClaudeExecutor struct {
|
|||||||
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
||||||
const claudeToolPrefix = ""
|
const claudeToolPrefix = ""
|
||||||
|
|
||||||
// Anthropic-compatible upstreams may reject or even crash when dynamically
|
// Anthropic-compatible upstreams may reject or even crash when Claude models
|
||||||
// registered Claude models omit max_tokens. Use a conservative default.
|
// omit max_tokens. Prefer registered model metadata before using a fallback.
|
||||||
const defaultModelMaxTokens = 1024
|
const defaultModelMaxTokens = 1024
|
||||||
|
|
||||||
// NewClaudeExecutor returns a ClaudeExecutor whose cfg field is set to cfg.
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
// NewClaudeExecutor returns a ClaudeExecutor whose cfg field is set to cfg.
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||||
|
|
||||||
// Identifier returns the fixed string "claude".
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
// Identifier returns the fixed string "claude".
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
||||||
|
|
||||||
// ensureModelMaxTokens (pre-change version, removed by this commit) returns
// body with max_tokens set to defaultModelMaxTokens when the field is absent
// and one of the providers registered for modelID equals "claude"
// (case-insensitive). The body is returned unchanged when it is empty, is not
// valid JSON, already contains max_tokens, or no "claude" provider is found.
func ensureModelMaxTokens(body []byte, modelID string) []byte {
|
|
||||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
|
|
||||||
if strings.EqualFold(provider, "claude") {
|
|
||||||
// The error from sjson.SetBytes is discarded; whatever it returns replaces body.
body, _ = sjson.SetBytes(body, "max_tokens", defaultModelMaxTokens)
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrepareRequest injects Claude credentials into the outgoing HTTP request.
|
// PrepareRequest injects Claude credentials into the outgoing HTTP request.
|
||||||
func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
if req == nil {
|
if req == nil {
|
||||||
@@ -1906,3 +1887,26 @@ func injectSystemCacheControl(payload []byte) []byte {
|
|||||||
|
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureModelMaxTokens(body []byte, modelID string) []byte {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
|
||||||
|
if strings.EqualFold(provider, "claude") {
|
||||||
|
maxTokens := defaultModelMaxTokens
|
||||||
|
if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 {
|
||||||
|
maxTokens = info.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
body, _ = sjson.SetBytes(body, "max_tokens", maxTokens)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -1183,6 +1184,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-max-completion-tokens-client"
|
||||||
|
modelID := "test-claude-max-completion-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
MaxCompletionTokens: 4096,
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, 4096)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-default-max-tokens-client"
|
||||||
|
modelID := "test-claude-default-max-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-preserve-max-tokens-client"
|
||||||
|
modelID := "test-claude-preserve-max-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
MaxCompletionTokens: 4096,
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, 2048)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
|
||||||
|
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "max_tokens").Exists() {
|
||||||
|
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||||
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||||
// compressed SSE body that would silently break the line scanner.
|
// compressed SSE body that would silently break the line scanner.
|
||||||
|
|||||||
Reference in New Issue
Block a user