Merge remote-tracking branch 'upstream/dev' into dev
# Conflicts: # internal/runtime/executor/antigravity_executor.go
This commit is contained in:
@@ -58,6 +58,7 @@ func main() {
|
|||||||
// Command-line flags to control the application's behavior.
|
// Command-line flags to control the application's behavior.
|
||||||
var login bool
|
var login bool
|
||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
var qwenLogin bool
|
||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
@@ -76,6 +77,7 @@ func main() {
|
|||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
@@ -467,6 +469,9 @@ func main() {
|
|||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
|
} else if codexDeviceLogin {
|
||||||
|
// Handle Codex device-code login
|
||||||
|
cmd.DoCodexDeviceLogin(cfg, options)
|
||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
|
|||||||
@@ -408,6 +408,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
|||||||
if !auth.LastRefreshedAt.IsZero() {
|
if !auth.LastRefreshedAt.IsZero() {
|
||||||
entry["last_refresh"] = auth.LastRefreshedAt
|
entry["last_refresh"] = auth.LastRefreshedAt
|
||||||
}
|
}
|
||||||
|
if !auth.NextRetryAfter.IsZero() {
|
||||||
|
entry["next_retry_after"] = auth.NextRetryAfter
|
||||||
|
}
|
||||||
if path != "" {
|
if path != "" {
|
||||||
entry["path"] = path
|
entry["path"] = path
|
||||||
entry["source"] = "file"
|
entry["source"] = "file"
|
||||||
@@ -947,11 +950,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
|||||||
if store == nil {
|
if store == nil {
|
||||||
return "", fmt.Errorf("token store unavailable")
|
return "", fmt.Errorf("token store unavailable")
|
||||||
}
|
}
|
||||||
|
if h.postAuthHook != nil {
|
||||||
|
if err := h.postAuthHook(ctx, record); err != nil {
|
||||||
|
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return store.Save(ctx, record)
|
return store.Save(ctx, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Claude authentication...")
|
fmt.Println("Initializing Claude authentication...")
|
||||||
|
|
||||||
@@ -1096,6 +1105,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||||
|
|
||||||
@@ -1354,6 +1364,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Codex authentication...")
|
fmt.Println("Initializing Codex authentication...")
|
||||||
|
|
||||||
@@ -1499,6 +1510,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
@@ -1663,6 +1675,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Qwen authentication...")
|
fmt.Println("Initializing Qwen authentication...")
|
||||||
|
|
||||||
@@ -1718,6 +1731,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Kimi authentication...")
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
@@ -1794,6 +1808,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing iFlow authentication...")
|
fmt.Println("Initializing iFlow authentication...")
|
||||||
|
|
||||||
@@ -2412,3 +2427,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateAuthContext extracts request info and adds it to the context
|
||||||
|
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||||
|
info := &coreauth.RequestInfo{
|
||||||
|
Query: c.Request.URL.Query(),
|
||||||
|
Headers: c.Request.Header,
|
||||||
|
}
|
||||||
|
return coreauth.WithRequestInfo(ctx, info)
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Handler struct {
|
|||||||
allowRemoteOverride bool
|
allowRemoteOverride bool
|
||||||
envSecret string
|
envSecret string
|
||||||
logDir string
|
logDir string
|
||||||
|
postAuthHook coreauth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
|||||||
h.logDir = dir
|
h.logDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||||
|
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||||
|
h.postAuthHook = hook
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware enforces access control for management endpoints.
|
// Middleware enforces access control for management endpoints.
|
||||||
// All requests (local and remote) require a valid management key.
|
// All requests (local and remote) require a valid management key.
|
||||||
// Additionally, remote access requires allow-remote-management=true.
|
// Additionally, remote access requires allow-remote-management=true.
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type serverOptionConfig struct {
|
|||||||
keepAliveEnabled bool
|
keepAliveEnabled bool
|
||||||
keepAliveTimeout time.Duration
|
keepAliveTimeout time.Duration
|
||||||
keepAliveOnTimeout func()
|
keepAliveOnTimeout func()
|
||||||
|
postAuthHook auth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||||
|
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
cfg.postAuthHook = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
logDir := logging.ResolveLogDirectory(cfg)
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
|
if optionState.postAuthHook != nil {
|
||||||
|
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||||
|
}
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
|
|||||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
|||||||
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode and write the token data as JSON
|
// Encode and write the token data as JSON
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||||
// authorization code and PKCE verifier.
|
// authorization code and PKCE verifier.
|
||||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
|
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||||
|
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||||
|
// login while preserving the existing token parsing and storage behavior.
|
||||||
|
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
if pkceCodes == nil {
|
if pkceCodes == nil {
|
||||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
if isNonRetryableRefreshErr(err) {
|
||||||
|
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
lastErr = err
|
lastErr = err
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNonRetryableRefreshErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(raw, "refresh_token_reused")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||||
// This is typically called after a successful token refresh to persist the new credentials.
|
// This is typically called after a successful token refresh to persist the new credentials.
|
||||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
auth := &CodexAuth{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for non-retryable refresh failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||||
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
|||||||
|
|
||||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
|||||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
misc.LogSavingCredentials(authFilePath)
|
misc.LogSavingCredentials(authFilePath)
|
||||||
ts.Type = "gemini"
|
ts.Type = "gemini"
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
|||||||
Scope string `json:"scope"`
|
Scope string `json:"scope"`
|
||||||
Cookie string `json:"cookie"`
|
Cookie string `json:"cookie"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serialises the token storage to disk.
|
// SaveTokenToFile serialises the token storage to disk.
|
||||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = f.Close() }()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
|||||||
Expired string `json:"expired,omitempty"`
|
Expired string `json:"expired,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
encoder := json.NewEncoder(f)
|
encoder := json.NewEncoder(f)
|
||||||
encoder.SetIndent("", " ")
|
encoder.SetIndent("", " ")
|
||||||
if err = encoder.Encode(ts); err != nil {
|
if err = encoder.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
codexLoginModeMetadataKey = "codex_login_mode"
|
||||||
|
codexLoginModeDevice = "device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||||
|
// existing codex-login OAuth callback flow intact.
|
||||||
|
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
if authErr.Type == codex.ErrPortInUse.Type {
|
||||||
|
os.Exit(codex.ErrPortInUse.Code)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("Codex device authentication successful!")
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
|||||||
func LogCredentialSeparator() {
|
func LogCredentialSeparator() {
|
||||||
log.Debug(credentialSeparator)
|
log.Debug(credentialSeparator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||||
|
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||||
|
var data map[string]any
|
||||||
|
|
||||||
|
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||||
|
if srcMap, ok := source.(map[string]any); ok {
|
||||||
|
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||||
|
for k, v := range srcMap {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||||
|
temp, err := json.Marshal(source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(temp, &data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge extra metadata
|
||||||
|
if metadata != nil {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range metadata {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -904,19 +904,12 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
Created int64
|
Created int64
|
||||||
Thinking *ThinkingSupport
|
Thinking *ThinkingSupport
|
||||||
}{
|
}{
|
||||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
|
||||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||||
@@ -925,11 +918,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
|
||||||
}
|
}
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
@@ -963,6 +952,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
|
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
|||||||
@@ -55,8 +55,78 @@ const (
|
|||||||
var (
|
var (
|
||||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
randSourceMutex sync.Mutex
|
randSourceMutex sync.Mutex
|
||||||
|
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||||
|
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||||
|
antigravityPrimaryModelsCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
models []*registry.ModelInfo
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*registry.ModelInfo, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, cloneAntigravityModelInfo(model))
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *model
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
thinkingClone := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
clone.Thinking = &thinkingClone
|
||||||
|
}
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||||
|
cloned := cloneAntigravityModels(models)
|
||||||
|
if len(cloned) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = cloned
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
antigravityPrimaryModelsCache.mu.RLock()
|
||||||
|
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||||
|
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
models := loadAntigravityPrimaryModels()
|
||||||
|
if len(models) > 0 {
|
||||||
|
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -1072,7 +1142,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
exec := &AntigravityExecutor{cfg: cfg}
|
exec := &AntigravityExecutor{cfg: cfg}
|
||||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||||
if errToken != nil || token == "" {
|
if errToken != nil || token == "" {
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
if updatedAuth != nil {
|
if updatedAuth != nil {
|
||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
@@ -1096,7 +1166,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
|
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
httpReq.Close = true
|
httpReq.Close = true
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
@@ -1109,14 +1179,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Errorf("antigravity executor: models request failed: %v", errDo)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -1128,21 +1197,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Errorf("antigravity executor: models read body failed: %v", errRead)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Errorf("antigravity executor: models request error status %d: %s", httpResp.StatusCode, string(bodyBytes))
|
if idx+1 < len(baseURLs) {
|
||||||
return nil
|
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
result := gjson.GetBytes(bodyBytes, "models")
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
if !result.Exists() {
|
if !result.Exists() {
|
||||||
return nil
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@@ -1210,9 +1285,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
}
|
}
|
||||||
models = append(models, modelInfo)
|
models = append(models, modelInfo)
|
||||||
}
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
|
}
|
||||||
|
storeAntigravityPrimaryModels(models)
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = nil
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
seed := []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||||
|
t.Fatal("expected non-empty model list to update primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||||
|
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||||
|
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||||
|
}
|
||||||
|
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||||
|
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||||
|
ID: "gpt-5",
|
||||||
|
DisplayName: "GPT-5",
|
||||||
|
SupportedGenerationMethods: []string{"generateContent"},
|
||||||
|
SupportedParameters: []string{"temperature"},
|
||||||
|
Thinking: ®istry.ThinkingSupport{
|
||||||
|
Levels: []string{"high"},
|
||||||
|
},
|
||||||
|
}}); !updated {
|
||||||
|
t.Fatal("expected model cache update")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected one cached model, got %d", len(got))
|
||||||
|
}
|
||||||
|
got[0].ID = "mutated-id"
|
||||||
|
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||||
|
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||||
|
}
|
||||||
|
if len(got[0].SupportedParameters) > 0 {
|
||||||
|
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||||
|
}
|
||||||
|
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||||
|
got[0].Thinking.Levels[0] = "mutated-level"
|
||||||
|
}
|
||||||
|
|
||||||
|
again := loadAntigravityPrimaryModels()
|
||||||
|
if len(again) != 1 {
|
||||||
|
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||||
|
}
|
||||||
|
if again[0].ID != "gpt-5" {
|
||||||
|
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||||
|
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||||
|
}
|
||||||
|
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||||
|
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
@@ -673,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||||
|
err := statusErr{code: statusCode, msg: string(body)}
|
||||||
|
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||||
|
err.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||||
|
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||||
|
resetAtTime := time.Unix(resetsAt, 0)
|
||||||
|
if resetAtTime.After(now) {
|
||||||
|
retryAfter := resetAtTime.Sub(now)
|
||||||
|
return &retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||||
|
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||||
|
return &retryAfter
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseCodexRetryAfter(t *testing.T) {
|
||||||
|
now := time.Unix(1_700_000_000, 0)
|
||||||
|
|
||||||
|
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 123*time.Second {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prefers resets_at", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(5 * time.Minute).Unix()
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 5*time.Minute {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 77*time.Second {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-429 status code", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||||
|
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||||
|
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||||
|
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||||
|
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func itoa(v int64) string {
|
||||||
|
return strconv.FormatInt(v, 10)
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
@@ -23,8 +24,150 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||||
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||||
|
var qwenBeijingLoc = func() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
if err != nil || loc == nil {
|
||||||
|
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||||
|
return time.FixedZone("CST", 8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}()
|
||||||
|
|
||||||
|
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||||
|
var qwenQuotaCodes = map[string]struct{}{
|
||||||
|
"insufficient_quota": {},
|
||||||
|
"quota_exceeded": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||||
|
// Qwen has a limit of 60 requests per minute per account.
|
||||||
|
var qwenRateLimiter = struct {
|
||||||
|
sync.Mutex
|
||||||
|
requests map[string][]time.Time // authID -> request timestamps
|
||||||
|
}{
|
||||||
|
requests: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||||
|
// Keeps a small prefix/suffix to allow correlation across events.
|
||||||
|
func redactAuthID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(id) <= 8 {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return id[:4] + "..." + id[len(id)-4:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||||
|
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||||
|
func checkQwenRateLimit(authID string) error {
|
||||||
|
if authID == "" {
|
||||||
|
// Empty authID should not bypass rate limiting in production
|
||||||
|
// Use debug level to avoid log spam for certain auth flows
|
||||||
|
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-qwenRateLimitWindow)
|
||||||
|
|
||||||
|
qwenRateLimiter.Lock()
|
||||||
|
defer qwenRateLimiter.Unlock()
|
||||||
|
|
||||||
|
// Get and filter timestamps within the window
|
||||||
|
timestamps := qwenRateLimiter.requests[authID]
|
||||||
|
var validTimestamps []time.Time
|
||||||
|
for _, ts := range timestamps {
|
||||||
|
if ts.After(windowStart) {
|
||||||
|
validTimestamps = append(validTimestamps, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prune expired entries to prevent memory leak
|
||||||
|
// Delete empty entries, otherwise update with pruned slice
|
||||||
|
if len(validTimestamps) == 0 {
|
||||||
|
delete(qwenRateLimiter.requests, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||||
|
// Calculate when the oldest request will expire
|
||||||
|
oldestInWindow := validTimestamps[0]
|
||||||
|
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||||
|
if retryAfter < time.Second {
|
||||||
|
retryAfter = time.Second
|
||||||
|
}
|
||||||
|
retryAfterSec := int(retryAfter.Seconds())
|
||||||
|
return statusErr{
|
||||||
|
code: http.StatusTooManyRequests,
|
||||||
|
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||||
|
retryAfter: &retryAfter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record this request and update the map with pruned timestamps
|
||||||
|
validTimestamps = append(validTimestamps, now)
|
||||||
|
qwenRateLimiter.requests[authID] = validTimestamps
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||||
|
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||||
|
func isQwenQuotaError(body []byte) bool {
|
||||||
|
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||||
|
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||||
|
|
||||||
|
// Primary check: exact match on error.code or error.type (most reliable)
|
||||||
|
if _, ok := qwenQuotaCodes[code]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check message only if code/type don't match (less reliable)
|
||||||
|
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||||
|
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||||
|
strings.Contains(msg, "free allocated quota exceeded") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||||
|
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||||
|
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||||
|
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||||
|
errCode = httpCode
|
||||||
|
// Only check quota errors for expected status codes to avoid false positives
|
||||||
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
|
cooldown := timeUntilNextDay()
|
||||||
|
retryAfter = &cooldown
|
||||||
|
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||||
|
}
|
||||||
|
return errCode, retryAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||||
|
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||||
|
func timeUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nowLocal := now.In(qwenBeijingLoc)
|
||||||
|
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||||
|
return tomorrow.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||||
type QwenExecutor struct {
|
type QwenExecutor struct {
|
||||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
|||||||
@@ -10,53 +10,10 @@ import (
|
|||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validReasoningEffortLevels contains the standard values accepted by the
|
|
||||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
|
||||||
// auto) are NOT in this set and must be clamped before use.
|
|
||||||
var validReasoningEffortLevels = map[string]struct{}{
|
|
||||||
"none": {},
|
|
||||||
"low": {},
|
|
||||||
"medium": {},
|
|
||||||
"high": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
|
||||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
|
||||||
// mapped to the nearest standard equivalent.
|
|
||||||
//
|
|
||||||
// Mapping rules:
|
|
||||||
// - none / low / medium / high → returned as-is (already valid)
|
|
||||||
// - xhigh → "high" (nearest lower standard level)
|
|
||||||
// - minimal → "low" (nearest higher standard level)
|
|
||||||
// - auto → "medium" (reasonable default)
|
|
||||||
// - anything else → "medium" (safe default)
|
|
||||||
func clampReasoningEffort(level string) string {
|
|
||||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
|
||||||
return level
|
|
||||||
}
|
|
||||||
var clamped string
|
|
||||||
switch level {
|
|
||||||
case string(thinking.LevelXHigh):
|
|
||||||
clamped = string(thinking.LevelHigh)
|
|
||||||
case string(thinking.LevelMinimal):
|
|
||||||
clamped = string(thinking.LevelLow)
|
|
||||||
case string(thinking.LevelAuto):
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
default:
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
}
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"original": level,
|
|
||||||
"clamped": clamped,
|
|
||||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
|
||||||
return clamped
|
|
||||||
}
|
|
||||||
|
|
||||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||||
//
|
//
|
||||||
// OpenAI-specific behavior:
|
// OpenAI-specific behavior:
|
||||||
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.Mode == thinking.ModeLevel {
|
if config.Mode == thinking.ModeLevel {
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||||
} else if functionResponseResult.IsArray() {
|
} else if functionResponseResult.IsArray() {
|
||||||
frResults := functionResponseResult.Array()
|
frResults := functionResponseResult.Array()
|
||||||
if len(frResults) == 1 {
|
nonImageCount := 0
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
lastNonImageRaw := ""
|
||||||
|
filteredJSON := "[]"
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
for _, fr := range frResults {
|
||||||
|
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := fr.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nonImageCount++
|
||||||
|
lastNonImageRaw = fr.Raw
|
||||||
|
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nonImageCount == 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||||
|
} else if nonImageCount > 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||||
} else {
|
} else {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Place image data inside functionResponse.parts as inlineData
|
||||||
|
// instead of as sibling parts in the outer content, to avoid
|
||||||
|
// base64 data bloating the text context.
|
||||||
|
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
|
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
} else {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
|
}
|
||||||
} else if functionResponseResult.Raw != "" {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
} else {
|
} else {
|
||||||
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if sourceResult.Get("type").String() == "base64" {
|
if sourceResult.Get("type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := `{}`
|
||||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := sourceResult.Get("data").String(); data != "" {
|
if data := sourceResult.Get("data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
|||||||
@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
|||||||
if !inlineData.Exists() {
|
if !inlineData.Exists() {
|
||||||
t.Error("inlineData should exist")
|
t.Error("inlineData should exist")
|
||||||
}
|
}
|
||||||
if inlineData.Get("mime_type").String() != "image/png" {
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
t.Error("mime_type mismatch")
|
t.Error("mimeType mismatch")
|
||||||
}
|
}
|
||||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
t.Error("data mismatch")
|
t.Error("data mismatch")
|
||||||
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||||
|
// tool_result with array content containing text + image should place
|
||||||
|
// image data inside functionResponse.parts as inlineData, not as a
|
||||||
|
// sibling part in the outer content (to avoid base64 context bloat).
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-123-456",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "File content here"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text content should be in response.result
|
||||||
|
resultText := funcResp.Get("response.result.text").String()
|
||||||
|
if resultText != "File content here" {
|
||||||
|
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
|
t.Error("data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||||
|
// tool_result with single image object as content should place
|
||||||
|
// image data inside functionResponse.parts, not as outer sibling part.
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-789-012",
|
||||||
|
"content": {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// response.result should be empty (image only)
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||||
|
// tool_result with array content: 2 text items + 2 images
|
||||||
|
// All images go into functionResponse.parts, texts into response.result array
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Multi-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "First text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Second text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple text items => response.result is an array
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images should be in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||||
|
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part (the functionResponse itself)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||||
|
// tool_result with only images (no text) — response.result should be empty string
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "ImgOnly-001",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No text => response.result should be empty string
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("first image mimeType mismatch")
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||||
|
t.Error("second image mimeType mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||||
|
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NotB64-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "some output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||||
|
// along with the text item. Since there are 2 non-image items, result is array.
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// No functionResponse.parts (no base64 images collected)
|
||||||
|
if funcResp.Get("parts").Exists() {
|
||||||
|
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing data field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoData-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image (type check passes),
|
||||||
|
// but data field is missing => inlineData has mimeType but no data
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("mimeType should still be set")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").Exists() {
|
||||||
|
t.Error("data should not exist when source.data is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing media_type field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoMime-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "data": "AAAA"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image,
|
||||||
|
// but media_type is missing => inlineData has data but no mimeType
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||||
|
t.Error("mimeType should not exist when media_type is missing")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Error("data should still be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||||
// When tools + thinking but no system instruction, should create one with hint
|
// When tools + thinking but no system instruction, should create one with hint
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
|
|||||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||||
|
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||||
|
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||||
|
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||||
|
// so extra fields like "parts" survive the pipeline.
|
||||||
|
input := `{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionCall": {"name": "screenshot", "args": {}}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionResponse": {
|
||||||
|
"id": "tool-001",
|
||||||
|
"name": "screenshot",
|
||||||
|
"response": {"result": "Screenshot taken"},
|
||||||
|
"parts": [
|
||||||
|
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the function response content (role=function)
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The functionResponse should be preserved with its parts field
|
||||||
|
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the parts field with inlineData is preserved
|
||||||
|
inlineParts := funcResp.Get("parts").Array()
|
||||||
|
if len(inlineParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||||
|
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response.result is also preserved
|
||||||
|
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||||
|
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+3
-3
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
ext = sp[len(sp)-1]
|
ext = sp[len(sp)-1]
|
||||||
}
|
}
|
||||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||||
p++
|
p++
|
||||||
} else {
|
} else {
|
||||||
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
|
|||||||
+2
-2
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "file":
|
||||||
|
fileData := part.Get("file.file_data").String()
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
semicolonIdx := strings.Index(fileData, ";")
|
||||||
|
commaIdx := strings.Index(fileData, ",")
|
||||||
|
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||||
|
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||||
|
data := fileData[commaIdx+1:]
|
||||||
|
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
var textAggregate strings.Builder
|
var textAggregate strings.Builder
|
||||||
var partsJSON []string
|
var partsJSON []string
|
||||||
hasImage := false
|
hasImage := false
|
||||||
|
hasFile := false
|
||||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
ptype := part.Get("type").String()
|
ptype := part.Get("type").String()
|
||||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
hasImage = true
|
hasImage = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "input_file":
|
||||||
|
fileData := part.Get("file_data").String()
|
||||||
|
if fileData != "" {
|
||||||
|
mediaType := "application/octet-stream"
|
||||||
|
data := fileData
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||||
|
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||||
|
if len(mediaAndData) == 2 {
|
||||||
|
if mediaAndData[0] != "" {
|
||||||
|
mediaType = mediaAndData[0]
|
||||||
|
}
|
||||||
|
data = mediaAndData[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||||
|
partsJSON = append(partsJSON, contentPart)
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
hasFile = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
if len(partsJSON) > 0 {
|
if len(partsJSON) > 0 {
|
||||||
msg := `{"role":"","content":[]}`
|
msg := `{"role":"","content":[]}`
|
||||||
msg, _ = sjson.Set(msg, "role", role)
|
msg, _ = sjson.Set(msg, "role", role)
|
||||||
if len(partsJSON) == 1 && !hasImage {
|
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||||
// Preserve legacy behavior for single text content
|
// Preserve legacy behavior for single text content
|
||||||
msg, _ = sjson.Delete(msg, "content")
|
msg, _ = sjson.Delete(msg, "content")
|
||||||
textPart := gjson.Parse(partsJSON[0])
|
textPart := gjson.Parse(partsJSON[0])
|
||||||
|
|||||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
}
|
}
|
||||||
case "file":
|
case "file":
|
||||||
// Files are not specified in examples; skip for now
|
if role == "user" {
|
||||||
|
fileData := it.Get("file.file_data").String()
|
||||||
|
filename := it.Get("file.filename").String()
|
||||||
|
if fileData != "" {
|
||||||
|
part := `{}`
|
||||||
|
part, _ = sjson.Set(part, "type", "input_file")
|
||||||
|
part, _ = sjson.Set(part, "file_data", fileData)
|
||||||
|
if filename != "" {
|
||||||
|
part, _ = sjson.Set(part, "filename", filename)
|
||||||
|
}
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||||
|
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||||
|
|
||||||
// Delete the user field as it is not supported by the Codex upstream.
|
// Delete the user field as it is not supported by the Codex upstream.
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
return rawJSON
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||||
|
// for Codex upstream compatibility.
|
||||||
|
//
|
||||||
|
// Codex /responses currently rejects context_management with:
|
||||||
|
// {"detail":"Unsupported parameter: context_management"}.
|
||||||
|
//
|
||||||
|
// Compatibility strategy:
|
||||||
|
// 1) Remove context_management before forwarding to Codex upstream.
|
||||||
|
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||||
|
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||||
// accept "system" role in the input array.
|
// accept "system" role in the input array.
|
||||||
|
|||||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
|||||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"context_management": [
|
||||||
|
{
|
||||||
|
"type": "compaction",
|
||||||
|
"compact_threshold": 12000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "context_management").Exists() {
|
||||||
|
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"truncation": "disabled",
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+1
-1
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
|
|||||||
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Payload) > 0 {
|
if len(chunk.Payload) > 0 {
|
||||||
|
if handlerType == "openai-response" {
|
||||||
|
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||||
|
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
sentPayload = true
|
sentPayload = true
|
||||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||||
return
|
return
|
||||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return dataChan, upstreamHeaders, errChan
|
return dataChan, upstreamHeaders, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSSEDataJSON(chunk []byte) error {
|
||||||
|
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[5:])
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if json.Valid(data) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const max = 512
|
||||||
|
preview := data
|
||||||
|
if len(preview) > max {
|
||||||
|
preview = preview[:max]
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func statusFromError(err error) int {
|
func statusFromError(err error) int {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
|||||||
authIDs []string
|
authIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type invalidJSONStreamExecutor struct{}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
|||||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||||
|
executor := &invalidJSONStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotErr := false
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||||
|
}
|
||||||
|
if msg.Error == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
gotErr = true
|
||||||
|
}
|
||||||
|
if !gotErr {
|
||||||
|
t.Fatalf("expected terminal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
|||||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
errText = errMsg.Error.Error()
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||||
},
|
},
|
||||||
WriteDone: func() {
|
WriteDone: func() {
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(chan []byte)
|
||||||
|
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||||
|
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||||
|
body := recorder.Body.String()
|
||||||
|
if !strings.Contains(body, `"type":"error"`) {
|
||||||
|
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, `"error":{`) {
|
||||||
|
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIResponsesStreamErrorChunk struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
SequenceNumber int `json:"sequence_number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIResponsesStreamErrorCode(status int) string {
|
||||||
|
switch status {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "invalid_api_key"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return "insufficient_quota"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return "rate_limit_exceeded"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return "model_not_found"
|
||||||
|
case http.StatusRequestTimeout:
|
||||||
|
return "request_timeout"
|
||||||
|
default:
|
||||||
|
if status >= http.StatusInternalServerError {
|
||||||
|
return "internal_server_error"
|
||||||
|
}
|
||||||
|
if status >= http.StatusBadRequest {
|
||||||
|
return "invalid_request_error"
|
||||||
|
}
|
||||||
|
return "unknown_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||||
|
//
|
||||||
|
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||||
|
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||||
|
// of chunks that requires a top-level `type` field.
|
||||||
|
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||||
|
if status <= 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
if sequenceNumber < 0 {
|
||||||
|
sequenceNumber = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
message := strings.TrimSpace(errText)
|
||||||
|
if message == "" {
|
||||||
|
message = http.StatusText(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := openAIResponsesStreamErrorCode(status)
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(errText)
|
||||||
|
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||||
|
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||||
|
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := payload["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||||
|
sequenceNumber = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := e["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(code) == "" {
|
||||||
|
code = "unknown_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extremely defensive fallback.
|
||||||
|
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: "internal_server_error",
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if len(data) > 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "unexpected EOF" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||||
|
}
|
||||||
|
if payload["sequence_number"] != float64(0) {
|
||||||
|
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "oops" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||||
|
}
|
||||||
|
}
|
||||||
+5
-37
@@ -2,8 +2,6 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldUseCodexDeviceFlow(opts) {
|
||||||
|
return a.loginWithDeviceFlow(ctx, cfg, opts)
|
||||||
|
}
|
||||||
|
|
||||||
callbackPort := a.CallbackPort
|
callbackPort := a.CallbackPort
|
||||||
if opts.CallbackPort > 0 {
|
if opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
@@ -186,39 +188,5 @@ waitForCallback:
|
|||||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
return a.buildAuthRecord(authSvc, authBundle)
|
||||||
|
|
||||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
|
||||||
return nil, fmt.Errorf("codex token storage missing account information")
|
|
||||||
}
|
|
||||||
|
|
||||||
planType := ""
|
|
||||||
hashAccountID := ""
|
|
||||||
if tokenStorage.IDToken != "" {
|
|
||||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
|
||||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
|
||||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
|
||||||
if accountID != "" {
|
|
||||||
digest := sha256.Sum256([]byte(accountID))
|
|
||||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
|
||||||
metadata := map[string]any{
|
|
||||||
"email": tokenStorage.Email,
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Codex authentication successful")
|
|
||||||
if authBundle.APIKey != "" {
|
|
||||||
fmt.Println("Codex API key obtained and stored")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &coreauth.Auth{
|
|
||||||
ID: fileName,
|
|
||||||
Provider: a.Provider(),
|
|
||||||
FileName: fileName,
|
|
||||||
Storage: tokenStorage,
|
|
||||||
Metadata: metadata,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,291 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
codexLoginModeMetadataKey = "codex_login_mode"
|
||||||
|
codexLoginModeDevice = "device"
|
||||||
|
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
|
||||||
|
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
|
||||||
|
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
|
||||||
|
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
|
||||||
|
codexDeviceTimeout = 15 * time.Minute
|
||||||
|
codexDeviceDefaultPollIntervalSeconds = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
type codexDeviceUserCodeRequest struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexDeviceUserCodeResponse struct {
|
||||||
|
DeviceAuthID string `json:"device_auth_id"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
UserCodeAlt string `json:"usercode"`
|
||||||
|
Interval json.RawMessage `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexDeviceTokenRequest struct {
|
||||||
|
DeviceAuthID string `json:"device_auth_id"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexDeviceTokenResponse struct {
|
||||||
|
AuthorizationCode string `json:"authorization_code"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
|
||||||
|
if opts == nil || opts.Metadata == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||||
|
|
||||||
|
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
|
||||||
|
if deviceCode == "" {
|
||||||
|
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
|
||||||
|
}
|
||||||
|
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
|
||||||
|
if deviceCode == "" || deviceAuthID == "" {
|
||||||
|
return nil, fmt.Errorf("codex device flow did not return required fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
|
||||||
|
|
||||||
|
fmt.Println("Starting Codex device authentication...")
|
||||||
|
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
|
||||||
|
fmt.Printf("Codex device code: %s\n", deviceCode)
|
||||||
|
|
||||||
|
if !opts.NoBrowser {
|
||||||
|
if !browser.IsAvailable() {
|
||||||
|
log.Warn("No browser available; please open the device URL manually")
|
||||||
|
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
|
||||||
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
|
||||||
|
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
|
||||||
|
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
|
||||||
|
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
|
||||||
|
return nil, fmt.Errorf("codex device flow token response missing required fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := codex.NewCodexAuth(cfg)
|
||||||
|
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
|
||||||
|
ctx,
|
||||||
|
authCode,
|
||||||
|
codexDeviceTokenExchangeRedirectURI,
|
||||||
|
&codex.PKCECodes{
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
CodeChallenge: codeChallenge,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.buildAuthRecord(authSvc, authBundle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
|
||||||
|
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create codex device request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to request codex device code: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
|
||||||
|
trimmed := strings.TrimSpace(string(respBody))
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if trimmed == "" {
|
||||||
|
trimmed = "empty response body"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed codexDeviceUserCodeResponse
|
||||||
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
|
||||||
|
deadline := time.Now().Add(codexDeviceTimeout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(codexDeviceTokenRequest{
|
||||||
|
DeviceAuthID: deviceAuthID,
|
||||||
|
UserCode: userCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case codexDeviceIsSuccessStatus(resp.StatusCode):
|
||||||
|
var parsed codexDeviceTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
|
||||||
|
}
|
||||||
|
return &parsed, nil
|
||||||
|
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(interval):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
trimmed := strings.TrimSpace(string(respBody))
|
||||||
|
if trimmed == "" {
|
||||||
|
trimmed = "empty response body"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
|
||||||
|
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return defaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
var asString string
|
||||||
|
if err := json.Unmarshal(raw, &asString); err == nil {
|
||||||
|
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var asInt int
|
||||||
|
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
|
||||||
|
return time.Duration(asInt) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
func codexDeviceIsSuccessStatus(code int) bool {
|
||||||
|
return code >= 200 && code < 300
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||||
|
return nil, fmt.Errorf("codex token storage missing account information")
|
||||||
|
}
|
||||||
|
|
||||||
|
planType := ""
|
||||||
|
hashAccountID := ""
|
||||||
|
if tokenStorage.IDToken != "" {
|
||||||
|
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||||
|
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||||
|
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||||
|
if accountID != "" {
|
||||||
|
digest := sha256.Sum256([]byte(accountID))
|
||||||
|
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||||
|
metadata := map[string]any{
|
||||||
|
"email": tokenStorage.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Codex authentication successful")
|
||||||
|
if authBundle.APIKey != "" {
|
||||||
|
fmt.Println("Codex API key obtained and stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: a.Provider(),
|
||||||
|
FileName: fileName,
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
|||||||
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
|
||||||
|
type metadataSetter interface {
|
||||||
|
SetMetadata(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case auth.Storage != nil:
|
case auth.Storage != nil:
|
||||||
|
if setter, ok := auth.Storage.(metadataSetter); ok {
|
||||||
|
setter.SetMetadata(auth.Metadata)
|
||||||
|
}
|
||||||
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ type RefreshEvaluator interface {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
refreshCheckInterval = 5 * time.Second
|
refreshCheckInterval = 5 * time.Second
|
||||||
|
refreshMaxConcurrency = 16
|
||||||
refreshPendingBackoff = time.Minute
|
refreshPendingBackoff = time.Minute
|
||||||
refreshFailureBackoff = 5 * time.Minute
|
refreshFailureBackoff = 5 * time.Minute
|
||||||
quotaBackoffBase = time.Second
|
quotaBackoffBase = time.Second
|
||||||
@@ -156,6 +157,7 @@ type Manager struct {
|
|||||||
|
|
||||||
// Auto refresh state
|
// Auto refresh state
|
||||||
refreshCancel context.CancelFunc
|
refreshCancel context.CancelFunc
|
||||||
|
refreshSemaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager constructs a manager with optional custom selector and hook.
|
// NewManager constructs a manager with optional custom selector and hook.
|
||||||
@@ -173,6 +175,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
|||||||
hook: hook,
|
hook: hook,
|
||||||
auths: make(map[string]*Auth),
|
auths: make(map[string]*Auth),
|
||||||
providerOffsets: make(map[string]int),
|
providerOffsets: make(map[string]int),
|
||||||
|
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||||
}
|
}
|
||||||
// atomic.Value requires non-nil initial value.
|
// atomic.Value requires non-nil initial value.
|
||||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||||
@@ -1828,9 +1831,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
|||||||
// every few seconds and triggers refresh operations when required.
|
// every few seconds and triggers refresh operations when required.
|
||||||
// Only one loop is kept alive; starting a new one cancels the previous run.
|
// Only one loop is kept alive; starting a new one cancels the previous run.
|
||||||
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
||||||
if interval <= 0 || interval > refreshCheckInterval {
|
if interval <= 0 {
|
||||||
interval = refreshCheckInterval
|
|
||||||
} else {
|
|
||||||
interval = refreshCheckInterval
|
interval = refreshCheckInterval
|
||||||
}
|
}
|
||||||
if m.refreshCancel != nil {
|
if m.refreshCancel != nil {
|
||||||
@@ -1880,11 +1881,25 @@ func (m *Manager) checkRefreshes(ctx context.Context) {
|
|||||||
if !m.markRefreshPending(a.ID, now) {
|
if !m.markRefreshPending(a.ID, now) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go m.refreshAuth(ctx, a.ID)
|
go m.refreshAuthWithLimit(ctx, a.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
|
||||||
|
if m.refreshSemaphore == nil {
|
||||||
|
m.refreshAuth(ctx, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case m.refreshSemaphore <- struct{}{}:
|
||||||
|
defer func() { <-m.refreshSemaphore }()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.refreshAuth(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) snapshotAuths() []*Auth {
|
func (m *Manager) snapshotAuths() []*Auth {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||||
|
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
||||||
|
// a two-level round-robin is used: first cycling across credential groups (parent
|
||||||
|
// accounts), then cycling within each group's project auths.
|
||||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = opts
|
_ = opts
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
|||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 4096
|
limit = 4096
|
||||||
}
|
}
|
||||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
|
||||||
s.cursors = make(map[string]int)
|
|
||||||
}
|
|
||||||
index := s.cursors[key]
|
|
||||||
|
|
||||||
|
// Check if any available auth has gemini_virtual_parent attribute,
|
||||||
|
// indicating gemini-cli virtual auths that should use credential-level polling.
|
||||||
|
groups, parentOrder := groupByVirtualParent(available)
|
||||||
|
if len(parentOrder) > 1 {
|
||||||
|
// Two-level round-robin: first select a credential group, then pick within it.
|
||||||
|
groupKey := key + "::group"
|
||||||
|
s.ensureCursorKey(groupKey, limit)
|
||||||
|
if _, exists := s.cursors[groupKey]; !exists {
|
||||||
|
// Seed with a random initial offset so the starting credential is randomized.
|
||||||
|
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
||||||
|
}
|
||||||
|
groupIndex := s.cursors[groupKey]
|
||||||
|
if groupIndex >= 2_147_483_640 {
|
||||||
|
groupIndex = 0
|
||||||
|
}
|
||||||
|
s.cursors[groupKey] = groupIndex + 1
|
||||||
|
|
||||||
|
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
||||||
|
group := groups[selectedParent]
|
||||||
|
|
||||||
|
// Second level: round-robin within the selected credential group.
|
||||||
|
innerKey := key + "::cred:" + selectedParent
|
||||||
|
s.ensureCursorKey(innerKey, limit)
|
||||||
|
innerIndex := s.cursors[innerKey]
|
||||||
|
if innerIndex >= 2_147_483_640 {
|
||||||
|
innerIndex = 0
|
||||||
|
}
|
||||||
|
s.cursors[innerKey] = innerIndex + 1
|
||||||
|
s.mu.Unlock()
|
||||||
|
return group[innerIndex%len(group)], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flat round-robin for non-grouped auths (original behavior).
|
||||||
|
s.ensureCursorKey(key, limit)
|
||||||
|
index := s.cursors[key]
|
||||||
if index >= 2_147_483_640 {
|
if index >= 2_147_483_640 {
|
||||||
index = 0
|
index = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cursors[key] = index + 1
|
s.cursors[key] = index + 1
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
|
|
||||||
return available[index%len(available)], nil
|
return available[index%len(available)], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
||||||
|
// Must be called with s.mu held.
|
||||||
|
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
||||||
|
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||||
|
s.cursors = make(map[string]int)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
||||||
|
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
||||||
|
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
||||||
|
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
||||||
|
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
||||||
|
if len(auths) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
groups := make(map[string][]*Auth)
|
||||||
|
for _, a := range auths {
|
||||||
|
parent := ""
|
||||||
|
if a.Attributes != nil {
|
||||||
|
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
||||||
|
}
|
||||||
|
if parent == "" {
|
||||||
|
// Non-virtual auth present; fall back to flat round-robin.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
groups[parent] = append(groups[parent], a)
|
||||||
|
}
|
||||||
|
// Collect parent IDs in sorted order for stable cursor indexing.
|
||||||
|
parentOrder := make([]string, 0, len(groups))
|
||||||
|
for p := range groups {
|
||||||
|
parentOrder = append(parentOrder, p)
|
||||||
|
}
|
||||||
|
sort.Strings(parentOrder)
|
||||||
|
return groups, parentOrder
|
||||||
|
}
|
||||||
|
|
||||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = opts
|
_ = opts
|
||||||
|
|||||||
@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
|||||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// Simulate two gemini-cli credentials, each with multiple projects:
|
||||||
|
// Credential A (parent = "cred-a.json") has 3 projects
|
||||||
|
// Credential B (parent = "cred-b.json") has 2 projects
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||||
|
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two-level round-robin: consecutive picks must alternate between credentials.
|
||||||
|
// Credential group order is randomized, but within each call the group cursor
|
||||||
|
// advances by 1, so consecutive picks should cycle through different parents.
|
||||||
|
picks := make([]string, 6)
|
||||||
|
parents := make([]string, 6)
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
picks[i] = got.ID
|
||||||
|
parents[i] = got.Attributes["gemini_virtual_parent"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify property: consecutive picks must alternate between credential groups.
|
||||||
|
for i := 1; i < len(parents); i++ {
|
||||||
|
if parents[i] == parents[i-1] {
|
||||||
|
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
|
||||||
|
i-1, i, parents[i], picks[i-1], picks[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify property: each credential's projects are picked in sequence (round-robin within group).
|
||||||
|
credPicks := map[string][]string{}
|
||||||
|
for i, id := range picks {
|
||||||
|
credPicks[parents[i]] = append(credPicks[parents[i]], id)
|
||||||
|
}
|
||||||
|
for parent, ids := range credPicks {
|
||||||
|
for i := 1; i < len(ids); i++ {
|
||||||
|
if ids[i] == ids[i-1] {
|
||||||
|
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// All auths from the same parent - should fall back to flat round-robin
|
||||||
|
// because there's only one credential group (no benefit from two-level).
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
|
||||||
|
// Sorted by ID: proj-a1, proj-a2, proj-a3
|
||||||
|
want := []string{
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
"cred-a.json::proj-a2",
|
||||||
|
"cred-a.json::proj-a3",
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expectedID := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != expectedID {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
|
||||||
|
// alongside virtual ones). Should fall back to flat round-robin.
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-regular.json"}, // no gemini_virtual_parent
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupByVirtualParent returns nil when any auth lacks the attribute,
|
||||||
|
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
|
||||||
|
want := []string{
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
"cred-regular.json",
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expectedID := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != expectedID {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -12,6 +15,33 @@ import (
|
|||||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PostAuthHook defines a function that is called after an Auth record is created
|
||||||
|
// but before it is persisted to storage. This allows for modification of the
|
||||||
|
// Auth record (e.g., injecting metadata) based on external context.
|
||||||
|
type PostAuthHook func(context.Context, *Auth) error
|
||||||
|
|
||||||
|
// RequestInfo holds information extracted from the HTTP request.
|
||||||
|
// It is injected into the context passed to PostAuthHook.
|
||||||
|
type RequestInfo struct {
|
||||||
|
Query url.Values
|
||||||
|
Headers http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestInfoKey struct{}
|
||||||
|
|
||||||
|
// WithRequestInfo returns a new context with the given RequestInfo attached.
|
||||||
|
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
|
||||||
|
return context.WithValue(ctx, requestInfoKey{}, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestInfo retrieves the RequestInfo from the context, if present.
|
||||||
|
func GetRequestInfo(ctx context.Context) *RequestInfo {
|
||||||
|
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
// ID uniquely identifies the auth record across restarts.
|
// ID uniquely identifies the auth record across restarts.
|
||||||
|
|||||||
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after an Auth record is created
|
||||||
|
// but before it is persisted to storage.
|
||||||
|
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
|
||||||
|
if hook == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
||||||
func (b *Builder) Build() (*Service, error) {
|
func (b *Builder) Build() (*Service, error) {
|
||||||
if b.cfg == nil {
|
if b.cfg == nil {
|
||||||
|
|||||||
@@ -925,6 +925,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
}
|
}
|
||||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||||
|
if provider == "antigravity" {
|
||||||
|
s.backfillAntigravityModels(a, models)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1069,6 +1072,56 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
|
|||||||
return cfg.OAuthExcludedModels[providerKey]
|
return cfg.OAuthExcludedModels[providerKey]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
|
||||||
|
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceID := ""
|
||||||
|
if source != nil {
|
||||||
|
sourceID = strings.TrimSpace(source.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
for _, candidate := range s.coreManager.List() {
|
||||||
|
if candidate == nil || candidate.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidateID := strings.TrimSpace(candidate.ID)
|
||||||
|
if candidateID == "" || candidateID == sourceID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(reg.GetModelsForClient(candidateID)) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
|
||||||
|
if authKind == "" {
|
||||||
|
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||||
|
authKind = "apikey"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
excluded := s.oauthExcludedModels("antigravity", authKind)
|
||||||
|
if candidate.Attributes != nil {
|
||||||
|
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||||
|
excluded = strings.Split(val, ",")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
models := applyExcludedModels(primaryModels, excluded)
|
||||||
|
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
|
||||||
|
if len(models) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||||
|
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||||
if len(models) == 0 || len(excluded) == 0 {
|
if len(models) == 0 || len(excluded) == 0 {
|
||||||
return models
|
return models
|
||||||
|
|||||||
@@ -0,0 +1,135 @@
|
|||||||
|
package cliproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
|
||||||
|
source := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-source",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
target := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-target",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||||
|
t.Fatalf("register source auth: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||||
|
t.Fatalf("register target auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
primary := []*ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||||
|
|
||||||
|
service.backfillAntigravityModels(source, primary)
|
||||||
|
|
||||||
|
got := reg.GetModelsForClient(target.ID)
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make(map[string]struct{}, len(got))
|
||||||
|
for _, model := range got {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
|
||||||
|
}
|
||||||
|
if _, ok := ids["claude-sonnet-4-5"]; !ok {
|
||||||
|
t.Fatal("expected backfilled model claude-sonnet-4-5")
|
||||||
|
}
|
||||||
|
if _, ok := ids["gemini-2.5-pro"]; !ok {
|
||||||
|
t.Fatal("expected backfilled model gemini-2.5-pro")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
|
||||||
|
source := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-source-excluded",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
target := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-target-excluded",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
"excluded_models": "gemini-2.5-pro",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||||
|
t.Fatalf("register source auth: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||||
|
t.Fatalf("register target auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
primary := []*ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||||
|
|
||||||
|
service.backfillAntigravityModels(source, primary)
|
||||||
|
|
||||||
|
got := reg.GetModelsForClient(target.ID)
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
|
||||||
|
}
|
||||||
|
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
|
||||||
|
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user