Merge remote-tracking branch 'upstream/dev' into dev
# Conflicts: # internal/runtime/executor/antigravity_executor.go
This commit is contained in:
@@ -408,6 +408,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if !auth.LastRefreshedAt.IsZero() {
|
||||
entry["last_refresh"] = auth.LastRefreshedAt
|
||||
}
|
||||
if !auth.NextRetryAfter.IsZero() {
|
||||
entry["next_retry_after"] = auth.NextRetryAfter
|
||||
}
|
||||
if path != "" {
|
||||
entry["path"] = path
|
||||
entry["source"] = "file"
|
||||
@@ -947,11 +950,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
if h.postAuthHook != nil {
|
||||
if err := h.postAuthHook(ctx, record); err != nil {
|
||||
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Claude authentication...")
|
||||
|
||||
@@ -1096,6 +1105,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||
|
||||
@@ -1354,6 +1364,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Codex authentication...")
|
||||
|
||||
@@ -1499,6 +1510,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
@@ -1663,6 +1675,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
@@ -1718,6 +1731,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
@@ -1794,6 +1808,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing iFlow authentication...")
|
||||
|
||||
@@ -2412,3 +2427,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
// PopulateAuthContext extracts request info and adds it to the context
|
||||
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
info := &coreauth.RequestInfo{
|
||||
Query: c.Request.URL.Query(),
|
||||
Headers: c.Request.Header,
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
|
||||
@@ -51,6 +51,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
if isNonRetryableRefreshErr(err) {
|
||||
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableRefreshErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "refresh_token_reused")
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||
var calls int32
|
||||
auth := &CodexAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-retryable refresh failure")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
)
|
||||
|
||||
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||
// existing codex-login OAuth callback flow intact.
|
||||
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("Codex device authentication successful!")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
||||
func LogCredentialSeparator() {
|
||||
log.Debug(credentialSeparator)
|
||||
}
|
||||
|
||||
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||
var data map[string]any
|
||||
|
||||
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||
if srcMap, ok := source.(map[string]any); ok {
|
||||
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||
for k, v := range srcMap {
|
||||
data[k] = v
|
||||
}
|
||||
} else {
|
||||
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||
temp, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(temp, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra metadata
|
||||
if metadata != nil {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -904,19 +904,12 @@ func GetIFlowModels() []*ModelInfo {
|
||||
Created int64
|
||||
Thinking *ThinkingSupport
|
||||
}{
|
||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
@@ -925,11 +918,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -963,6 +952,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
|
||||
@@ -55,8 +55,78 @@ const (
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -1072,7 +1142,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
@@ -1096,7 +1166,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
@@ -1109,14 +1179,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models request failed: %v", errDo)
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
@@ -1128,21 +1197,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models read body failed: %v", errRead)
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models request error status %d: %s", httpResp.StatusCode, string(bodyBytes))
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -1210,9 +1285,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
@@ -673,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||
return nil
|
||||
}
|
||||
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||
resetAtTime := time.Unix(resetsAt, 0)
|
||||
if resetAtTime.After(now) {
|
||||
retryAfter := resetAtTime.Sub(now)
|
||||
return &retryAfter
|
||||
}
|
||||
}
|
||||
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||
return &retryAfter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCodexRetryAfter(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0)
|
||||
|
||||
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 123*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers resets_at", func(t *testing.T) {
|
||||
resetAt := now.Add(5 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 5*time.Minute {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 77*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-429 status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
@@ -22,9 +23,151 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||
)
|
||||
|
||||
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||
var qwenBeijingLoc = func() *time.Location {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil || loc == nil {
|
||||
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||
return time.FixedZone("CST", 8*3600)
|
||||
}
|
||||
return loc
|
||||
}()
|
||||
|
||||
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||
var qwenQuotaCodes = map[string]struct{}{
|
||||
"insufficient_quota": {},
|
||||
"quota_exceeded": {},
|
||||
}
|
||||
|
||||
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||
// Qwen has a limit of 60 requests per minute per account.
|
||||
var qwenRateLimiter = struct {
|
||||
sync.Mutex
|
||||
requests map[string][]time.Time // authID -> request timestamps
|
||||
}{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||
// Keeps a small prefix/suffix to allow correlation across events.
|
||||
func redactAuthID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
if len(id) <= 8 {
|
||||
return id
|
||||
}
|
||||
return id[:4] + "..." + id[len(id)-4:]
|
||||
}
|
||||
|
||||
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||
func checkQwenRateLimit(authID string) error {
|
||||
if authID == "" {
|
||||
// Empty authID should not bypass rate limiting in production
|
||||
// Use debug level to avoid log spam for certain auth flows
|
||||
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-qwenRateLimitWindow)
|
||||
|
||||
qwenRateLimiter.Lock()
|
||||
defer qwenRateLimiter.Unlock()
|
||||
|
||||
// Get and filter timestamps within the window
|
||||
timestamps := qwenRateLimiter.requests[authID]
|
||||
var validTimestamps []time.Time
|
||||
for _, ts := range timestamps {
|
||||
if ts.After(windowStart) {
|
||||
validTimestamps = append(validTimestamps, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// Always prune expired entries to prevent memory leak
|
||||
// Delete empty entries, otherwise update with pruned slice
|
||||
if len(validTimestamps) == 0 {
|
||||
delete(qwenRateLimiter.requests, authID)
|
||||
}
|
||||
|
||||
// Check if rate limit exceeded
|
||||
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||
// Calculate when the oldest request will expire
|
||||
oldestInWindow := validTimestamps[0]
|
||||
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||
if retryAfter < time.Second {
|
||||
retryAfter = time.Second
|
||||
}
|
||||
retryAfterSec := int(retryAfter.Seconds())
|
||||
return statusErr{
|
||||
code: http.StatusTooManyRequests,
|
||||
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||
retryAfter: &retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// Record this request and update the map with pruned timestamps
|
||||
validTimestamps = append(validTimestamps, now)
|
||||
qwenRateLimiter.requests[authID] = validTimestamps
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||
func isQwenQuotaError(body []byte) bool {
|
||||
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||
|
||||
// Primary check: exact match on error.code or error.type (most reliable)
|
||||
if _, ok := qwenQuotaCodes[code]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback: check message only if code/type don't match (less reliable)
|
||||
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||
strings.Contains(msg, "free allocated quota exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||
errCode = httpCode
|
||||
// Only check quota errors for expected status codes to avoid false positives
|
||||
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||
cooldown := timeUntilNextDay()
|
||||
retryAfter = &cooldown
|
||||
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||
}
|
||||
return errCode, retryAfter
|
||||
}
|
||||
|
||||
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||
func timeUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nowLocal := now.In(qwenBeijingLoc)
|
||||
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||
return tomorrow.Sub(now)
|
||||
}
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||
type QwenExecutor struct {
|
||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if opts.Alt == "responses/compact" {
|
||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
@@ -10,53 +10,10 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
nonImageCount := 0
|
||||
lastNonImageRaw := ""
|
||||
filteredJSON := "[]"
|
||||
imagePartsJSON := "[]"
|
||||
for _, fr := range frResults {
|
||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := fr.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
nonImageCount++
|
||||
lastNonImageRaw = fr.Raw
|
||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||
}
|
||||
|
||||
if nonImageCount == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||
} else if nonImageCount > 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
// Place image data inside functionResponse.parts as inlineData
|
||||
// instead of as sibling parts in the outer content, to avoid
|
||||
// base64 data bloating the text context.
|
||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := "[]"
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
|
||||
@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
||||
if !inlineData.Exists() {
|
||||
t.Error("inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mime_type").String() != "image/png" {
|
||||
t.Error("mime_type mismatch")
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Error("mimeType mismatch")
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||
// tool_result with array content containing text + image should place
|
||||
// image data inside functionResponse.parts as inlineData, not as a
|
||||
// sibling part in the outer content (to avoid base64 context bloat).
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-123-456",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "File content here"
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Text content should be in response.result
|
||||
resultText := funcResp.Get("response.result.text").String()
|
||||
if resultText != "File content here" {
|
||||
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||
// tool_result with single image object as content should place
|
||||
// image data inside functionResponse.parts, not as outer sibling part.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-789-012",
|
||||
"content": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// response.result should be empty (image only)
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||
// tool_result with array content: 2 text items + 2 images
|
||||
// All images go into functionResponse.parts, texts into response.result array
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Multi-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "First text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||
},
|
||||
{"type": "text", "text": "Second text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Multiple text items => response.result is an array
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// Both images should be in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Only 1 outer part (the functionResponse itself)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||
// tool_result with only images (no text) — response.result should be empty string
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "ImgOnly-001",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// No text => response.result should be empty string
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Both images in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("first image mimeType mismatch")
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||
t.Error("second image mimeType mismatch")
|
||||
}
|
||||
|
||||
// Only 1 outer part
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NotB64-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "some output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||
// along with the text item. Since there are 2 non-image items, result is array.
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// No functionResponse.parts (no base64 images collected)
|
||||
if funcResp.Get("parts").Exists() {
|
||||
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||
// image with source.type=base64 but missing data field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoData-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image (type check passes),
|
||||
// but data field is missing => inlineData has mimeType but no data
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("mimeType should still be set")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").Exists() {
|
||||
t.Error("data should not exist when source.data is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||
// image with source.type=base64 but missing media_type field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoMime-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "data": "AAAA"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image,
|
||||
// but media_type is missing => inlineData has data but no mimeType
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||
t.Error("mimeType should not exist when media_type is missing")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Error("data should still be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||
// so extra fields like "parts" survive the pipeline.
|
||||
input := `{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"functionCall": {"name": "screenshot", "args": {}}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"id": "tool-001",
|
||||
"name": "screenshot",
|
||||
"response": {"result": "Screenshot taken"},
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
// Find the function response content (role=function)
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
// The functionResponse should be preserved with its parts field
|
||||
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist in output")
|
||||
}
|
||||
|
||||
// Verify the parts field with inlineData is preserved
|
||||
inlineParts := funcResp.Get("parts").Array()
|
||||
if len(inlineParts) != 1 {
|
||||
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Verify response.result is also preserved
|
||||
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
}
|
||||
|
||||
+3
-3
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
ext = sp[len(sp)-1]
|
||||
}
|
||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||
p++
|
||||
} else {
|
||||
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
|
||||
+2
-2
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var textAggregate strings.Builder
|
||||
var partsJSON []string
|
||||
hasImage := false
|
||||
hasFile := false
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
ptype := part.Get("type").String()
|
||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
hasImage = true
|
||||
}
|
||||
}
|
||||
case "input_file":
|
||||
fileData := part.Get("file_data").String()
|
||||
if fileData != "" {
|
||||
mediaType := "application/octet-stream"
|
||||
data := fileData
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
if len(mediaAndData) == 2 {
|
||||
if mediaAndData[0] != "" {
|
||||
mediaType = mediaAndData[0]
|
||||
}
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
hasFile = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage {
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
|
||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
// Files are not specified in examples; skip for now
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||
// for Codex upstream compatibility.
|
||||
//
|
||||
// Codex /responses currently rejects context_management with:
|
||||
// {"detail":"Unsupported parameter: context_management"}.
|
||||
//
|
||||
// Compatibility strategy:
|
||||
// 1) Remove context_management before forwarding to Codex upstream.
|
||||
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||
// accept "system" role in the input array.
|
||||
|
||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"context_management": [
|
||||
{
|
||||
"type": "compaction",
|
||||
"compact_threshold": 12000
|
||||
}
|
||||
],
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "context_management").Exists() {
|
||||
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||
}
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"truncation": "disabled",
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
|
||||
Reference in New Issue
Block a user