refactor(runtime): move executor utilities to helps package and update references
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CodexCache struct {
|
||||
ID string
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
||||
// Protected by codexCacheMu. Entries expire after 1 hour.
|
||||
var (
|
||||
codexCacheMap = make(map[string]CodexCache)
|
||||
codexCacheMu sync.RWMutex
|
||||
)
|
||||
|
||||
// codexCacheCleanupInterval controls how often expired entries are purged.
|
||||
const codexCacheCleanupInterval = 15 * time.Minute
|
||||
|
||||
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
|
||||
var codexCacheCleanupOnce sync.Once
|
||||
|
||||
// startCodexCacheCleanup launches a background goroutine that periodically
|
||||
// removes expired entries from codexCacheMap to prevent memory leaks.
|
||||
func startCodexCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(codexCacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredCodexCache()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredCodexCache removes entries that have expired.
|
||||
func purgeExpiredCodexCache() {
|
||||
now := time.Now()
|
||||
codexCacheMu.Lock()
|
||||
defer codexCacheMu.Unlock()
|
||||
for key, cache := range codexCacheMap {
|
||||
if cache.Expire.Before(now) {
|
||||
delete(codexCacheMap, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||
func GetCodexCache(key string) (CodexCache, bool) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.RLock()
|
||||
cache, ok := codexCacheMap[key]
|
||||
codexCacheMu.RUnlock()
|
||||
if !ok || cache.Expire.Before(time.Now()) {
|
||||
return CodexCache{}, false
|
||||
}
|
||||
return cache, true
|
||||
}
|
||||
|
||||
// SetCodexCache stores a cache entry.
|
||||
func SetCodexCache(key string, cache CodexCache) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.Lock()
|
||||
codexCacheMap[key] = cache
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,389 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
|
||||
defaultClaudeFingerprintPackageVersion = "0.74.0"
|
||||
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
|
||||
defaultClaudeFingerprintOS = "MacOS"
|
||||
defaultClaudeFingerprintArch = "arm64"
|
||||
claudeDeviceProfileTTL = 7 * 24 * time.Hour
|
||||
claudeDeviceProfileCleanupPeriod = time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
|
||||
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu sync.RWMutex
|
||||
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||
|
||||
ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
|
||||
)
|
||||
|
||||
type claudeCLIVersion struct {
|
||||
major int
|
||||
minor int
|
||||
patch int
|
||||
}
|
||||
|
||||
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
||||
switch {
|
||||
case v.major != other.major:
|
||||
if v.major > other.major {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.minor != other.minor:
|
||||
if v.minor > other.minor {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.patch != other.patch:
|
||||
if v.patch > other.patch {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
type ClaudeDeviceProfile struct {
|
||||
UserAgent string
|
||||
PackageVersion string
|
||||
RuntimeVersion string
|
||||
OS string
|
||||
Arch string
|
||||
version claudeCLIVersion
|
||||
hasVersion bool
|
||||
}
|
||||
|
||||
type claudeDeviceProfileCacheEntry struct {
|
||||
profile ClaudeDeviceProfile
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
return false
|
||||
}
|
||||
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||
}
|
||||
|
||||
func ResetClaudeDeviceProfileCache() {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if strings.TrimSpace(cfgVal) != "" {
|
||||
return strings.TrimSpace(cfgVal)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
profile := ClaudeDeviceProfile{
|
||||
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
|
||||
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||
}
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
profile.version = version
|
||||
profile.hasVersion = true
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
||||
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
|
||||
if len(matches) != 4 {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
major, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
minor, err := strconv.Atoi(matches[2])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
patch, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||
}
|
||||
|
||||
func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
|
||||
if candidate.UserAgent == "" || !candidate.hasVersion {
|
||||
return false
|
||||
}
|
||||
if current.UserAgent == "" || !current.hasVersion {
|
||||
return true
|
||||
}
|
||||
return candidate.version.Compare(current.version) > 0
|
||||
}
|
||||
|
||||
func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||
profile.OS = baseline.OS
|
||||
profile.Arch = baseline.Arch
|
||||
return profile
|
||||
}
|
||||
|
||||
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||
func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||
if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||
profile.UserAgent = baseline.UserAgent
|
||||
profile.PackageVersion = baseline.PackageVersion
|
||||
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||
profile.version = baseline.version
|
||||
profile.hasVersion = baseline.hasVersion
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
|
||||
if headers == nil {
|
||||
return ClaudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||
version, ok := parseClaudeCLIVersion(userAgent)
|
||||
if !ok {
|
||||
return ClaudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
profile := ClaudeDeviceProfile{
|
||||
UserAgent: userAgent,
|
||||
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||
version: version,
|
||||
hasVersion: true,
|
||||
}
|
||||
return profile, true
|
||||
}
|
||||
|
||||
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
|
||||
if headers == nil {
|
||||
return fallback
|
||||
}
|
||||
if value := strings.TrimSpace(headers.Get(name)); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
switch {
|
||||
case auth != nil && strings.TrimSpace(auth.ID) != "":
|
||||
return "auth:" + strings.TrimSpace(auth.ID)
|
||||
case strings.TrimSpace(apiKey) != "":
|
||||
return "api_key:" + strings.TrimSpace(apiKey)
|
||||
default:
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func startClaudeDeviceProfileCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredClaudeDeviceProfiles()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredClaudeDeviceProfiles() {
|
||||
now := time.Now()
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
for key, entry := range claudeDeviceProfileCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(claudeDeviceProfileCache, key)
|
||||
}
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
|
||||
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||
|
||||
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||
now := time.Now()
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
|
||||
if hasCandidate {
|
||||
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
|
||||
}
|
||||
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
|
||||
hasCandidate = false
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.RLock()
|
||||
entry, hasCached := claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
claudeDeviceProfileCacheMu.RUnlock()
|
||||
|
||||
if hasCandidate {
|
||||
if ClaudeDeviceProfileBeforeCandidateStore != nil {
|
||||
ClaudeDeviceProfileBeforeCandidateStore(candidate)
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry, hasCached = claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
if cachedValid {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
}
|
||||
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
|
||||
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
|
||||
profile: candidate,
|
||||
expire: now.Add(claudeDeviceProfileTTL),
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return candidate
|
||||
}
|
||||
|
||||
if cachedValid {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry = claudeDeviceProfileCache[cacheKey]
|
||||
if entry.expire.After(now) && entry.profile.UserAgent != "" {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
return baseline
|
||||
}
|
||||
|
||||
func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
for _, headerName := range []string{
|
||||
"User-Agent",
|
||||
"X-Stainless-Package-Version",
|
||||
"X-Stainless-Runtime-Version",
|
||||
"X-Stainless-Os",
|
||||
"X-Stainless-Arch",
|
||||
} {
|
||||
r.Header.Del(headerName)
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
r.Header.Set("X-Stainless-Os", profile.OS)
|
||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||
}
|
||||
|
||||
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
profile := defaultClaudeDeviceProfile(cfg)
|
||||
miscEnsure := func(name, fallback string) {
|
||||
if strings.TrimSpace(r.Header.Get(name)) != "" {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
|
||||
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
|
||||
return
|
||||
}
|
||||
r.Header.Set(name, fallback)
|
||||
}
|
||||
|
||||
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
miscEnsure("X-Stainless-Os", mapStainlessOS())
|
||||
miscEnsure("X-Stainless-Arch", mapStainlessArch())
|
||||
|
||||
// Legacy mode preserves per-auth custom header overrides. By the time we get
|
||||
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
|
||||
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
return
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// zeroWidthSpace is the Unicode zero-width space character used for obfuscation.
|
||||
const zeroWidthSpace = "\u200B"
|
||||
|
||||
// SensitiveWordMatcher holds the compiled regex for matching sensitive words.
|
||||
type SensitiveWordMatcher struct {
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
// BuildSensitiveWordMatcher compiles a regex from the word list.
|
||||
// Words are sorted by length (longest first) for proper matching.
|
||||
func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter and normalize words
|
||||
var validWords []string
|
||||
for _, w := range words {
|
||||
w = strings.TrimSpace(w)
|
||||
if utf8.RuneCountInString(w) >= 2 && !strings.Contains(w, zeroWidthSpace) {
|
||||
validWords = append(validWords, w)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validWords) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by length (longest first) for proper matching
|
||||
sort.Slice(validWords, func(i, j int) bool {
|
||||
return len(validWords[i]) > len(validWords[j])
|
||||
})
|
||||
|
||||
// Escape and join
|
||||
escaped := make([]string, len(validWords))
|
||||
for i, w := range validWords {
|
||||
escaped[i] = regexp.QuoteMeta(w)
|
||||
}
|
||||
|
||||
pattern := "(?i)" + strings.Join(escaped, "|")
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SensitiveWordMatcher{regex: re}
|
||||
}
|
||||
|
||||
// obfuscateWord inserts a zero-width space after the first grapheme.
|
||||
func obfuscateWord(word string) string {
|
||||
if strings.Contains(word, zeroWidthSpace) {
|
||||
return word
|
||||
}
|
||||
|
||||
// Get first rune
|
||||
r, size := utf8.DecodeRuneInString(word)
|
||||
if r == utf8.RuneError || size >= len(word) {
|
||||
return word
|
||||
}
|
||||
|
||||
return string(r) + zeroWidthSpace + word[size:]
|
||||
}
|
||||
|
||||
// obfuscateText replaces all sensitive words in the text.
|
||||
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
|
||||
if m == nil || m.regex == nil {
|
||||
return text
|
||||
}
|
||||
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
||||
}
|
||||
|
||||
// ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||
// in system blocks and message content.
|
||||
func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
if matcher == nil || matcher.regex == nil {
|
||||
return payload
|
||||
}
|
||||
|
||||
// Obfuscate in system blocks
|
||||
payload = obfuscateSystemBlocks(payload, matcher)
|
||||
|
||||
// Obfuscate in messages
|
||||
payload = obfuscateMessages(payload, matcher)
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// obfuscateSystemBlocks obfuscates sensitive words in system blocks.
|
||||
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if !system.Exists() {
|
||||
return payload
|
||||
}
|
||||
|
||||
if system.IsArray() {
|
||||
modified := false
|
||||
system.ForEach(func(key, value gjson.Result) bool {
|
||||
if value.Get("type").String() == "text" {
|
||||
text := value.Get("text").String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
path := "system." + key.String() + ".text"
|
||||
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if modified {
|
||||
return payload
|
||||
}
|
||||
} else if system.Type == gjson.String {
|
||||
text := system.String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
payload, _ = sjson.SetBytes(payload, "system", obfuscated)
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// obfuscateMessages obfuscates sensitive words in message content.
|
||||
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return payload
|
||||
}
|
||||
|
||||
messages.ForEach(func(msgKey, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() {
|
||||
return true
|
||||
}
|
||||
|
||||
msgPath := "messages." + msgKey.String()
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Simple string content
|
||||
text := content.String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
payload, _ = sjson.SetBytes(payload, msgPath+".content", obfuscated)
|
||||
}
|
||||
} else if content.IsArray() {
|
||||
// Array of content blocks
|
||||
content.ForEach(func(blockKey, block gjson.Result) bool {
|
||||
if block.Get("type").String() == "text" {
|
||||
text := block.Get("text").String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
path := msgPath + ".content." + blockKey.String() + ".text"
|
||||
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
|
||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
|
||||
func generateFakeUserID() string {
|
||||
hexBytes := make([]byte, 32)
|
||||
_, _ = rand.Read(hexBytes)
|
||||
hexPart := hex.EncodeToString(hexBytes)
|
||||
accountUUID := uuid.New().String()
|
||||
sessionUUID := uuid.New().String()
|
||||
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
|
||||
}
|
||||
|
||||
// isValidUserID checks if a user ID matches Claude Code format.
|
||||
func isValidUserID(userID string) bool {
|
||||
return userIDPattern.MatchString(userID)
|
||||
}
|
||||
|
||||
func GenerateFakeUserID() string {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
func IsValidUserID(userID string) bool {
|
||||
return isValidUserID(userID)
|
||||
}
|
||||
|
||||
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||
// Returns true if cloaking should be applied.
|
||||
func ShouldCloak(cloakMode string, userAgent string) bool {
|
||||
switch strings.ToLower(cloakMode) {
|
||||
case "always":
|
||||
return true
|
||||
case "never":
|
||||
return false
|
||||
default: // "auto" or empty
|
||||
// If client is Claude Code, don't cloak
|
||||
return !strings.HasPrefix(userAgent, "claude-cli")
|
||||
}
|
||||
}
|
||||
|
||||
// isClaudeCodeClient checks if the User-Agent indicates a Claude Code client.
|
||||
func isClaudeCodeClient(userAgent string) bool {
|
||||
return strings.HasPrefix(userAgent, "claude-cli")
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||
apiRequestKey = "API_REQUEST"
|
||||
apiResponseKey = "API_RESPONSE"
|
||||
)
|
||||
|
||||
// UpstreamRequestLog captures the outbound upstream request details for logging.
|
||||
type UpstreamRequestLog struct {
|
||||
URL string
|
||||
Method string
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
Provider string
|
||||
AuthID string
|
||||
AuthLabel string
|
||||
AuthType string
|
||||
AuthValue string
|
||||
}
|
||||
|
||||
type upstreamAttempt struct {
|
||||
index int
|
||||
request string
|
||||
response *strings.Builder
|
||||
responseIntroWritten bool
|
||||
statusWritten bool
|
||||
headersWritten bool
|
||||
bodyStarted bool
|
||||
bodyHasContent bool
|
||||
errorWritten bool
|
||||
}
|
||||
|
||||
// RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
|
||||
func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attempts := getAttempts(ginCtx)
|
||||
index := len(attempts) + 1
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("=== API REQUEST %d ===\n", index))
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
if info.URL != "" {
|
||||
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
|
||||
} else {
|
||||
builder.WriteString("Upstream URL: <unknown>\n")
|
||||
}
|
||||
if info.Method != "" {
|
||||
builder.WriteString(fmt.Sprintf("HTTP Method: %s\n", info.Method))
|
||||
}
|
||||
if auth := formatAuthInfo(info); auth != "" {
|
||||
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
|
||||
}
|
||||
builder.WriteString("\nHeaders:\n")
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.WriteString(string(info.Body))
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
builder.WriteString("\n\n")
|
||||
|
||||
attempt := &upstreamAttempt{
|
||||
index: index,
|
||||
request: builder.String(),
|
||||
response: &strings.Builder{},
|
||||
}
|
||||
attempts = append(attempts, attempt)
|
||||
ginCtx.Set(apiAttemptsKey, attempts)
|
||||
updateAggregatedRequest(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
|
||||
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
attempts, attempt := ensureAttempt(ginCtx)
|
||||
ensureResponseIntro(attempt)
|
||||
|
||||
if status > 0 && !attempt.statusWritten {
|
||||
attempt.response.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
attempt.statusWritten = true
|
||||
}
|
||||
if !attempt.headersWritten {
|
||||
attempt.response.WriteString("Headers:\n")
|
||||
writeHeaders(attempt.response, headers)
|
||||
attempt.headersWritten = true
|
||||
attempt.response.WriteString("\n")
|
||||
}
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
|
||||
func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
attempts, attempt := ensureAttempt(ginCtx)
|
||||
ensureResponseIntro(attempt)
|
||||
|
||||
if attempt.bodyStarted && !attempt.bodyHasContent {
|
||||
// Ensure body does not stay empty marker if error arrives first.
|
||||
attempt.bodyStarted = false
|
||||
}
|
||||
if attempt.errorWritten {
|
||||
attempt.response.WriteString("\n")
|
||||
}
|
||||
attempt.response.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
|
||||
attempt.errorWritten = true
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
|
||||
func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
attempts, attempt := ensureAttempt(ginCtx)
|
||||
ensureResponseIntro(attempt)
|
||||
|
||||
if !attempt.headersWritten {
|
||||
attempt.response.WriteString("Headers:\n")
|
||||
writeHeaders(attempt.response, nil)
|
||||
attempt.headersWritten = true
|
||||
attempt.response.WriteString("\n")
|
||||
}
|
||||
if !attempt.bodyStarted {
|
||||
attempt.response.WriteString("Body:\n")
|
||||
attempt.bodyStarted = true
|
||||
}
|
||||
if attempt.bodyHasContent {
|
||||
attempt.response.WriteString("\n\n")
|
||||
}
|
||||
attempt.response.WriteString(string(data))
|
||||
attempt.bodyHasContent = true
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
func ginContextFrom(ctx context.Context) *gin.Context {
|
||||
ginCtx, _ := ctx.Value("gin").(*gin.Context)
|
||||
return ginCtx
|
||||
}
|
||||
|
||||
func getAttempts(ginCtx *gin.Context) []*upstreamAttempt {
|
||||
if ginCtx == nil {
|
||||
return nil
|
||||
}
|
||||
if value, exists := ginCtx.Get(apiAttemptsKey); exists {
|
||||
if attempts, ok := value.([]*upstreamAttempt); ok {
|
||||
return attempts
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAttempt(ginCtx *gin.Context) ([]*upstreamAttempt, *upstreamAttempt) {
|
||||
attempts := getAttempts(ginCtx)
|
||||
if len(attempts) == 0 {
|
||||
attempt := &upstreamAttempt{
|
||||
index: 1,
|
||||
request: "=== API REQUEST 1 ===\n<missing>\n\n",
|
||||
response: &strings.Builder{},
|
||||
}
|
||||
attempts = []*upstreamAttempt{attempt}
|
||||
ginCtx.Set(apiAttemptsKey, attempts)
|
||||
updateAggregatedRequest(ginCtx, attempts)
|
||||
}
|
||||
return attempts, attempts[len(attempts)-1]
|
||||
}
|
||||
|
||||
func ensureResponseIntro(attempt *upstreamAttempt) {
|
||||
if attempt == nil || attempt.response == nil || attempt.responseIntroWritten {
|
||||
return
|
||||
}
|
||||
attempt.response.WriteString(fmt.Sprintf("=== API RESPONSE %d ===\n", attempt.index))
|
||||
attempt.response.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
attempt.response.WriteString("\n")
|
||||
attempt.responseIntroWritten = true
|
||||
}
|
||||
|
||||
func updateAggregatedRequest(ginCtx *gin.Context, attempts []*upstreamAttempt) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, attempt := range attempts {
|
||||
builder.WriteString(attempt.request)
|
||||
}
|
||||
ginCtx.Set(apiRequestKey, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
var builder strings.Builder
|
||||
for idx, attempt := range attempts {
|
||||
if attempt == nil || attempt.response == nil {
|
||||
continue
|
||||
}
|
||||
responseText := attempt.response.String()
|
||||
if responseText == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString(responseText)
|
||||
if !strings.HasSuffix(responseText, "\n") {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
if idx < len(attempts)-1 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
ginCtx.Set(apiResponseKey, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func writeHeaders(builder *strings.Builder, headers http.Header) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
builder.WriteString("<none>\n")
|
||||
return
|
||||
}
|
||||
keys := make([]string, 0, len(headers))
|
||||
for key := range headers {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
values := headers[key]
|
||||
if len(values) == 0 {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", key))
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
masked := util.MaskSensitiveHeaderValue(key, value)
|
||||
builder.WriteString(fmt.Sprintf("%s: %s\n", key, masked))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatAuthInfo(info UpstreamRequestLog) string {
|
||||
var parts []string
|
||||
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
|
||||
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
|
||||
}
|
||||
if trimmed := strings.TrimSpace(info.AuthID); trimmed != "" {
|
||||
parts = append(parts, fmt.Sprintf("auth_id=%s", trimmed))
|
||||
}
|
||||
if trimmed := strings.TrimSpace(info.AuthLabel); trimmed != "" {
|
||||
parts = append(parts, fmt.Sprintf("label=%s", trimmed))
|
||||
}
|
||||
|
||||
authType := strings.ToLower(strings.TrimSpace(info.AuthType))
|
||||
authValue := strings.TrimSpace(info.AuthValue)
|
||||
switch authType {
|
||||
case "api_key":
|
||||
if authValue != "" {
|
||||
parts = append(parts, fmt.Sprintf("type=api_key value=%s", util.HideAPIKey(authValue)))
|
||||
} else {
|
||||
parts = append(parts, "type=api_key")
|
||||
}
|
||||
case "oauth":
|
||||
parts = append(parts, "type=oauth")
|
||||
default:
|
||||
if authType != "" {
|
||||
if authValue != "" {
|
||||
parts = append(parts, fmt.Sprintf("type=%s value=%s", authType, authValue))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("type=%s", authType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
func SummarizeErrorBody(contentType string, body []byte) string {
|
||||
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
|
||||
if !isHTML {
|
||||
trimmed := bytes.TrimSpace(bytes.ToLower(body))
|
||||
if bytes.HasPrefix(trimmed, []byte("<!doctype html")) || bytes.HasPrefix(trimmed, []byte("<html")) {
|
||||
isHTML = true
|
||||
}
|
||||
}
|
||||
if isHTML {
|
||||
if title := extractHTMLTitle(body); title != "" {
|
||||
return title
|
||||
}
|
||||
return "[html body omitted]"
|
||||
}
|
||||
|
||||
// Try to extract error message from JSON response
|
||||
if message := extractJSONErrorMessage(body); message != "" {
|
||||
return message
|
||||
}
|
||||
|
||||
return string(body)
|
||||
}
|
||||
|
||||
func extractHTMLTitle(body []byte) string {
|
||||
lower := bytes.ToLower(body)
|
||||
start := bytes.Index(lower, []byte("<title"))
|
||||
if start == -1 {
|
||||
return ""
|
||||
}
|
||||
gt := bytes.IndexByte(lower[start:], '>')
|
||||
if gt == -1 {
|
||||
return ""
|
||||
}
|
||||
start += gt + 1
|
||||
end := bytes.Index(lower[start:], []byte("</title>"))
|
||||
if end == -1 {
|
||||
return ""
|
||||
}
|
||||
title := string(body[start : start+end])
|
||||
title = html.UnescapeString(title)
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(title), " ")
|
||||
}
|
||||
|
||||
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
|
||||
func extractJSONErrorMessage(body []byte) string {
|
||||
result := gjson.GetBytes(body, "error.message")
|
||||
if result.Exists() && result.String() != "" {
|
||||
return result.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||
// If no request ID is found in context, it returns the standard logger.
|
||||
func LogWithRequestID(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
requestID := logging.GetRequestID(ctx)
|
||||
if requestID == "" {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
return log.WithField("request_id", requestID)
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||
// against the original payload when provided. requestedModel carries the client-visible
|
||||
// model name before alias resolution so payload rules can target aliases precisely.
|
||||
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
rules := cfg.Payload
|
||||
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
||||
if len(rules) == 0 || len(models) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, model := range models {
|
||||
for _, entry := range rules {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
||||
continue
|
||||
}
|
||||
if matchModelPattern(name, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadModelCandidates(model, requestedModel string) []string {
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
candidates := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
addCandidate := func(value string) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(value)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
candidates = append(candidates, value)
|
||||
}
|
||||
if model != "" {
|
||||
addCandidate(model)
|
||||
}
|
||||
if requestedModel != "" {
|
||||
parsed := thinking.ParseSuffix(requestedModel)
|
||||
base := strings.TrimSpace(parsed.ModelName)
|
||||
if base != "" {
|
||||
addCandidate(base)
|
||||
}
|
||||
if parsed.HasSuffix {
|
||||
addCandidate(requestedModel)
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
// buildPayloadPath combines an optional root path with a relative parameter path.
|
||||
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
||||
// the parameter path is treated as relative to root.
|
||||
func buildPayloadPath(root, path string) string {
|
||||
r := strings.TrimSpace(root)
|
||||
p := strings.TrimSpace(path)
|
||||
if r == "" {
|
||||
return p
|
||||
}
|
||||
if p == "" {
|
||||
return r
|
||||
}
|
||||
if strings.HasPrefix(p, ".") {
|
||||
p = p[1:]
|
||||
}
|
||||
return r + "." + p
|
||||
}
|
||||
|
||||
func payloadRawValue(value any) ([]byte, bool) {
|
||||
if value == nil {
|
||||
return nil, false
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return []byte(typed), true
|
||||
case []byte:
|
||||
return typed, true
|
||||
default:
|
||||
raw, errMarshal := json.Marshal(typed)
|
||||
if errMarshal != nil {
|
||||
return nil, false
|
||||
}
|
||||
return raw, true
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
fallback = strings.TrimSpace(fallback)
|
||||
if len(opts.Metadata) == 0 {
|
||||
return fallback
|
||||
}
|
||||
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return fallback
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return fallback
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(v))
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
return trimmed
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||
// Examples:
|
||||
//
|
||||
// "*-5" matches "gpt-5"
|
||||
// "gpt-*" matches "gpt-5" and "gpt-4"
|
||||
// "gemini-*-pro" matches "gemini-2.5-pro" and "gemini-3-pro".
|
||||
func matchModelPattern(pattern, model string) bool {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
model = strings.TrimSpace(model)
|
||||
if pattern == "" {
|
||||
return false
|
||||
}
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Iterative glob-style matcher supporting only '*' wildcard.
|
||||
pi, si := 0, 0
|
||||
starIdx := -1
|
||||
matchIdx := 0
|
||||
for si < len(model) {
|
||||
if pi < len(pattern) && (pattern[pi] == model[si]) {
|
||||
pi++
|
||||
si++
|
||||
continue
|
||||
}
|
||||
if pi < len(pattern) && pattern[pi] == '*' {
|
||||
starIdx = pi
|
||||
matchIdx = si
|
||||
pi++
|
||||
continue
|
||||
}
|
||||
if starIdx != -1 {
|
||||
pi = starIdx + 1
|
||||
matchIdx++
|
||||
si = matchIdx
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
for pi < len(pattern) && pattern[pi] == '*' {
|
||||
pi++
|
||||
}
|
||||
return pi == len(pattern)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||
// 3. Use RoundTripper from context if neither are configured
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context containing optional RoundTripper
|
||||
// - cfg: The application configuration
|
||||
// - auth: The authentication information
|
||||
// - timeout: The client timeout (0 means no timeout)
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client with configured proxy or transport
|
||||
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
httpClient := &http.Client{}
|
||||
if timeout > 0 {
|
||||
httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// Priority 1: Use auth.ProxyURL if configured
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
||||
}
|
||||
|
||||
// Priority 2: Use cfg.ProxyURL if auth proxy is not configured
|
||||
if proxyURL == "" && cfg != nil {
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
// If we have a proxy URL configured, set up the transport
|
||||
if proxyURL != "" {
|
||||
transport := buildProxyTransport(proxyURL)
|
||||
if transport != nil {
|
||||
httpClient.Transport = transport
|
||||
return httpClient
|
||||
}
|
||||
// If proxy setup failed, log and fall through to context RoundTripper
|
||||
log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL)
|
||||
}
|
||||
|
||||
// Priority 3: Use RoundTripper from context (typically from RoundTripperFor)
|
||||
if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
|
||||
httpClient.Transport = rt
|
||||
}
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// buildProxyTransport creates an HTTP transport configured for the given proxy URL.
|
||||
// It supports SOCKS5, HTTP, and HTTPS proxy protocols.
|
||||
//
|
||||
// Parameters:
|
||||
// - proxyURL: The proxy URL string (e.g., "socks5://user:pass@host:port", "http://host:port")
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
||||
func buildProxyTransport(proxyURL string) *http.Transport {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
return nil
|
||||
}
|
||||
return transport
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := NewProxyAwareHTTPClient(
|
||||
context.Background(),
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
0,
|
||||
)
|
||||
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/codex"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||
)
|
||||
@@ -0,0 +1,236 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func TokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
}
|
||||
|
||||
// CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
collectOpenAIMessages(root.Get("messages"), &segments)
|
||||
collectOpenAITools(root.Get("tools"), &segments)
|
||||
collectOpenAIFunctions(root.Get("functions"), &segments)
|
||||
collectOpenAIToolChoice(root.Get("tool_choice"), &segments)
|
||||
collectOpenAIResponseFormat(root.Get("response_format"), &segments)
|
||||
addIfNotEmpty(&segments, root.Get("input").String())
|
||||
addIfNotEmpty(&segments, root.Get("prompt").String())
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
func BuildOpenAIUsageJSON(count int64) []byte {
|
||||
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
||||
}
|
||||
|
||||
func collectOpenAIMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
addIfNotEmpty(segments, message.Get("name").String())
|
||||
collectOpenAIContent(message.Get("content"), segments)
|
||||
collectOpenAIToolCalls(message.Get("tool_calls"), segments)
|
||||
collectOpenAIFunctionCall(message.Get("function_call"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text", "input_text", "output_text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image_url":
|
||||
addIfNotEmpty(segments, part.Get("image_url.url").String())
|
||||
case "input_audio", "output_audio", "audio":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
collectOpenAIContent(part.Get("content"), segments)
|
||||
default:
|
||||
if part.IsArray() {
|
||||
collectOpenAIContent(part, segments)
|
||||
return true
|
||||
}
|
||||
if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
return true
|
||||
}
|
||||
addIfNotEmpty(segments, part.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, content.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
||||
if !calls.Exists() || !calls.IsArray() {
|
||||
return
|
||||
}
|
||||
calls.ForEach(func(_, call gjson.Result) bool {
|
||||
addIfNotEmpty(segments, call.Get("id").String())
|
||||
addIfNotEmpty(segments, call.Get("type").String())
|
||||
function := call.Get("function")
|
||||
if function.Exists() {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
addIfNotEmpty(segments, function.Get("arguments").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIFunctionCall(call gjson.Result, segments *[]string) {
|
||||
if !call.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, call.Get("name").String())
|
||||
addIfNotEmpty(segments, call.Get("arguments").String())
|
||||
}
|
||||
|
||||
func collectOpenAITools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() {
|
||||
return
|
||||
}
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
appendToolPayload(tool, segments)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
appendToolPayload(tools, segments)
|
||||
}
|
||||
|
||||
func collectOpenAIFunctions(functions gjson.Result, segments *[]string) {
|
||||
if !functions.Exists() || !functions.IsArray() {
|
||||
return
|
||||
}
|
||||
functions.ForEach(func(_, function gjson.Result) bool {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIToolChoice(choice gjson.Result, segments *[]string) {
|
||||
if !choice.Exists() {
|
||||
return
|
||||
}
|
||||
if choice.Type == gjson.String {
|
||||
addIfNotEmpty(segments, choice.String())
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, choice.Raw)
|
||||
}
|
||||
|
||||
func collectOpenAIResponseFormat(format gjson.Result, segments *[]string) {
|
||||
if !format.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, format.Get("type").String())
|
||||
addIfNotEmpty(segments, format.Get("name").String())
|
||||
if schema := format.Get("json_schema"); schema.Exists() {
|
||||
addIfNotEmpty(segments, schema.Raw)
|
||||
}
|
||||
if schema := format.Get("schema"); schema.Exists() {
|
||||
addIfNotEmpty(segments, schema.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func appendToolPayload(tool gjson.Result, segments *[]string) {
|
||||
if !tool.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, tool.Get("type").String())
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if function := tool.Get("function"); function.Exists() {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addIfNotEmpty(segments *[]string, value string) {
|
||||
if segments == nil {
|
||||
return
|
||||
}
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
*segments = append(*segments, trimmed)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,577 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type UsageReporter struct {
|
||||
provider string
|
||||
model string
|
||||
authID string
|
||||
authIndex string
|
||||
apiKey string
|
||||
source string
|
||||
requestedAt time.Time
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
|
||||
apiKey := APIKeyFromContext(ctx)
|
||||
reporter := &UsageReporter{
|
||||
provider: provider,
|
||||
model: model,
|
||||
requestedAt: time.Now(),
|
||||
apiKey: apiKey,
|
||||
source: resolveUsageSource(auth, apiKey),
|
||||
}
|
||||
if auth != nil {
|
||||
reporter.authID = auth.ID
|
||||
reporter.authIndex = auth.EnsureIndex()
|
||||
}
|
||||
return reporter
|
||||
}
|
||||
|
||||
func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
|
||||
r.publishWithOutcome(ctx, detail, false)
|
||||
}
|
||||
|
||||
func (r *UsageReporter) PublishFailure(ctx context.Context) {
|
||||
r.publishWithOutcome(ctx, usage.Detail{}, true)
|
||||
}
|
||||
|
||||
func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
|
||||
if r == nil || errPtr == nil {
|
||||
return
|
||||
}
|
||||
if *errPtr != nil {
|
||||
r.PublishFailure(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
total := detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
if total > 0 {
|
||||
detail.TotalTokens = total
|
||||
}
|
||||
}
|
||||
if detail.InputTokens == 0 && detail.OutputTokens == 0 && detail.ReasoningTokens == 0 && detail.CachedTokens == 0 && detail.TotalTokens == 0 && !failed {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
||||
})
|
||||
}
|
||||
|
||||
// ensurePublished guarantees that a usage record is emitted exactly once.
|
||||
// It is safe to call multiple times; only the first call wins due to once.Do.
|
||||
// This is used to ensure request counting even when upstream responses do not
|
||||
// include any usage fields (tokens), especially for streaming paths.
|
||||
func (r *UsageReporter) EnsurePublished(ctx context.Context) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||
if r == nil {
|
||||
return usage.Record{Detail: detail, Failed: failed}
|
||||
}
|
||||
return usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Latency: r.latency(),
|
||||
Failed: failed,
|
||||
Detail: detail,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UsageReporter) latency() time.Duration {
|
||||
if r == nil || r.requestedAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
latency := time.Since(r.requestedAt)
|
||||
if latency < 0 {
|
||||
return 0
|
||||
}
|
||||
return latency
|
||||
}
|
||||
|
||||
func APIKeyFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, exists := ginCtx.Get("apiKey"); exists {
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
return value
|
||||
case fmt.Stringer:
|
||||
return value.String()
|
||||
default:
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
|
||||
if auth != nil {
|
||||
provider := strings.TrimSpace(auth.Provider)
|
||||
if strings.EqualFold(provider, "gemini-cli") {
|
||||
if id := strings.TrimSpace(auth.ID); id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(provider, "vertex") {
|
||||
if auth.Metadata != nil {
|
||||
if projectID, ok := auth.Metadata["project_id"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(projectID); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
if project, ok := auth.Metadata["project"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(project); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, value := auth.AccountInfo(); value != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if email, ok := auth.Metadata["email"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(email); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if key := strings.TrimSpace(auth.Attributes["api_key"]); key != "" {
|
||||
return key
|
||||
}
|
||||
}
|
||||
}
|
||||
if trimmed := strings.TrimSpace(ctxAPIKey); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||
usageNode := gjson.ParseBytes(data).Get("response.usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func ParseOpenAIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
inputNode := usageNode.Get("prompt_tokens")
|
||||
if !inputNode.Exists() {
|
||||
inputNode = usageNode.Get("input_tokens")
|
||||
}
|
||||
outputNode := usageNode.Get("completion_tokens")
|
||||
if !outputNode.Exists() {
|
||||
outputNode = usageNode.Get("output_tokens")
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: inputNode.Int(),
|
||||
OutputTokens: outputNode.Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
|
||||
if !cached.Exists() {
|
||||
cached = usageNode.Get("input_tokens_details.cached_tokens")
|
||||
}
|
||||
if cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
|
||||
if !reasoning.Exists() {
|
||||
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
|
||||
}
|
||||
if reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
usageNode := gjson.GetBytes(payload, "usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("prompt_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("completion_tokens").Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func ParseClaudeUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
||||
CachedTokens: usageNode.Get("cache_read_input_tokens").Int(),
|
||||
}
|
||||
if detail.CachedTokens == 0 {
|
||||
// fall back to creation tokens when read tokens are absent
|
||||
detail.CachedTokens = usageNode.Get("cache_creation_input_tokens").Int()
|
||||
}
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
|
||||
return detail
|
||||
}
|
||||
|
||||
func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
usageNode := gjson.GetBytes(payload, "usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
||||
CachedTokens: usageNode.Get("cache_read_input_tokens").Int(),
|
||||
}
|
||||
if detail.CachedTokens == 0 {
|
||||
detail.CachedTokens = usageNode.Get("cache_creation_input_tokens").Int()
|
||||
}
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
CachedTokens: node.Get("cachedContentTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func ParseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = usageNode.Get("response.usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func ParseGeminiUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = usageNode.Get("usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
node := gjson.GetBytes(payload, "usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = gjson.GetBytes(payload, "usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
node := gjson.GetBytes(payload, "response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = gjson.GetBytes(payload, "usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func ParseAntigravityUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = usageNode.Get("usageMetadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
node = usageNode.Get("usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
node := gjson.GetBytes(payload, "response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
node = gjson.GetBytes(payload, "usageMetadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
node = gjson.GetBytes(payload, "usage_metadata")
|
||||
}
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
var stopChunkWithoutUsage sync.Map
|
||||
|
||||
func rememberStopWithoutUsage(traceID string) {
|
||||
stopChunkWithoutUsage.Store(traceID, struct{}{})
|
||||
time.AfterFunc(10*time.Minute, func() { stopChunkWithoutUsage.Delete(traceID) })
|
||||
}
|
||||
|
||||
// FilterSSEUsageMetadata removes usageMetadata from SSE events that are not
|
||||
// terminal (finishReason != "stop"). Stop chunks are left untouched. This
|
||||
// function is shared between aistudio and antigravity executors.
|
||||
func FilterSSEUsageMetadata(payload []byte) []byte {
|
||||
if len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
|
||||
lines := bytes.Split(payload, []byte("\n"))
|
||||
modified := false
|
||||
foundData := false
|
||||
for idx, line := range lines {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
foundData = true
|
||||
dataIdx := bytes.Index(line, []byte("data:"))
|
||||
if dataIdx < 0 {
|
||||
continue
|
||||
}
|
||||
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
|
||||
traceID := gjson.GetBytes(rawJSON, "traceId").String()
|
||||
if isStopChunkWithoutUsage(rawJSON) && traceID != "" {
|
||||
rememberStopWithoutUsage(traceID)
|
||||
continue
|
||||
}
|
||||
if traceID != "" {
|
||||
if _, ok := stopChunkWithoutUsage.Load(traceID); ok && hasUsageMetadata(rawJSON) {
|
||||
stopChunkWithoutUsage.Delete(traceID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
cleaned, changed := StripUsageMetadataFromJSON(rawJSON)
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
var rebuilt []byte
|
||||
rebuilt = append(rebuilt, line[:dataIdx]...)
|
||||
rebuilt = append(rebuilt, []byte("data:")...)
|
||||
if len(cleaned) > 0 {
|
||||
rebuilt = append(rebuilt, ' ')
|
||||
rebuilt = append(rebuilt, cleaned...)
|
||||
}
|
||||
lines[idx] = rebuilt
|
||||
modified = true
|
||||
}
|
||||
if !modified {
|
||||
if !foundData {
|
||||
// Handle payloads that are raw JSON without SSE data: prefix.
|
||||
trimmed := bytes.TrimSpace(payload)
|
||||
cleaned, changed := StripUsageMetadataFromJSON(trimmed)
|
||||
if !changed {
|
||||
return payload
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
return payload
|
||||
}
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
|
||||
// StripUsageMetadataFromJSON drops usageMetadata unless finishReason is present (terminal).
|
||||
// It handles both formats:
|
||||
// - Aistudio: candidates.0.finishReason
|
||||
// - Antigravity: response.candidates.0.finishReason
|
||||
func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
|
||||
jsonBytes := bytes.TrimSpace(rawJSON)
|
||||
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
|
||||
return rawJSON, false
|
||||
}
|
||||
|
||||
// Check for finishReason in both aistudio and antigravity formats
|
||||
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
|
||||
if !finishReason.Exists() {
|
||||
finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
|
||||
}
|
||||
terminalReason := finishReason.Exists() && strings.TrimSpace(finishReason.String()) != ""
|
||||
|
||||
usageMetadata := gjson.GetBytes(jsonBytes, "usageMetadata")
|
||||
if !usageMetadata.Exists() {
|
||||
usageMetadata = gjson.GetBytes(jsonBytes, "response.usageMetadata")
|
||||
}
|
||||
|
||||
// Terminal chunk: keep as-is.
|
||||
if terminalReason {
|
||||
return rawJSON, false
|
||||
}
|
||||
|
||||
// Nothing to strip
|
||||
if !usageMetadata.Exists() {
|
||||
return rawJSON, false
|
||||
}
|
||||
|
||||
// Remove usageMetadata from both possible locations
|
||||
cleaned := jsonBytes
|
||||
var changed bool
|
||||
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
return cleaned, changed
|
||||
}
|
||||
|
||||
func hasUsageMetadata(jsonBytes []byte) bool {
|
||||
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
|
||||
return false
|
||||
}
|
||||
if gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
|
||||
return true
|
||||
}
|
||||
if gjson.GetBytes(jsonBytes, "response.usageMetadata").Exists() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isStopChunkWithoutUsage(jsonBytes []byte) bool {
|
||||
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
|
||||
return false
|
||||
}
|
||||
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
|
||||
if !finishReason.Exists() {
|
||||
finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
|
||||
}
|
||||
trimmed := strings.TrimSpace(finishReason.String())
|
||||
if !finishReason.Exists() || trimmed == "" {
|
||||
return false
|
||||
}
|
||||
return !hasUsageMetadata(jsonBytes)
|
||||
}
|
||||
|
||||
func JSONPayload(line []byte) []byte {
|
||||
return jsonPayload(line)
|
||||
}
|
||||
|
||||
func jsonPayload(line []byte) []byte {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 {
|
||||
return nil
|
||||
}
|
||||
if bytes.Equal(trimmed, []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
if bytes.HasPrefix(trimmed, []byte("event:")) {
|
||||
return nil
|
||||
}
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||
}
|
||||
if len(trimmed) == 0 || trimmed[0] != '{' {
|
||||
return nil
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||
detail := ParseOpenAIUsage(data)
|
||||
if detail.InputTokens != 1 {
|
||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
|
||||
}
|
||||
if detail.OutputTokens != 2 {
|
||||
t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 2)
|
||||
}
|
||||
if detail.TotalTokens != 3 {
|
||||
t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 3)
|
||||
}
|
||||
if detail.CachedTokens != 4 {
|
||||
t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 4)
|
||||
}
|
||||
if detail.ReasoningTokens != 5 {
|
||||
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 5)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
|
||||
detail := ParseOpenAIUsage(data)
|
||||
if detail.InputTokens != 10 {
|
||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
|
||||
}
|
||||
if detail.OutputTokens != 20 {
|
||||
t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 20)
|
||||
}
|
||||
if detail.TotalTokens != 30 {
|
||||
t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 30)
|
||||
}
|
||||
if detail.CachedTokens != 7 {
|
||||
t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 7)
|
||||
}
|
||||
if detail.ReasoningTokens != 9 {
|
||||
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||
reporter := &UsageReporter{
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||
}
|
||||
|
||||
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||
if record.Latency < time.Second {
|
||||
t.Fatalf("latency = %v, want >= 1s", record.Latency)
|
||||
}
|
||||
if record.Latency > 3*time.Second {
|
||||
t.Fatalf("latency = %v, want <= 3s", record.Latency)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type userIDCacheEntry struct {
|
||||
value string
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu sync.RWMutex
|
||||
userIDCacheCleanupOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
userIDTTL = time.Hour
|
||||
userIDCacheCleanupPeriod = 15 * time.Minute
|
||||
)
|
||||
|
||||
func startUserIDCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(userIDCacheCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredUserIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredUserIDs() {
|
||||
now := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
for key, entry := range userIDCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(userIDCache, key)
|
||||
}
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func userIDCacheKey(apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(apiKey))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func CachedUserID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
|
||||
|
||||
key := userIDCacheKey(apiKey)
|
||||
now := time.Now()
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry, ok := userIDCache[key]
|
||||
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
|
||||
userIDCacheMu.RUnlock()
|
||||
if valid {
|
||||
userIDCacheMu.Lock()
|
||||
entry = userIDCache[key]
|
||||
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
newID := generateFakeUserID()
|
||||
|
||||
userIDCacheMu.Lock()
|
||||
entry, ok = userIDCache[key]
|
||||
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
|
||||
entry.value = newID
|
||||
}
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func resetUserIDCache() {
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := CachedUserID("api-key-1")
|
||||
second := CachedUserID("api-key-1")
|
||||
|
||||
if first == "" {
|
||||
t.Fatal("expected generated user_id to be non-empty")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
expiredID := CachedUserID("api-key-expired")
|
||||
cacheKey := userIDCacheKey("api-key-expired")
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: expiredID,
|
||||
expire: time.Now().Add(-time.Minute),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
newID := CachedUserID("api-key-expired")
|
||||
if newID == expiredID {
|
||||
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||
}
|
||||
if newID == "" {
|
||||
t.Fatal("expected regenerated user_id to be non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := CachedUserID("api-key-1")
|
||||
second := CachedUserID("api-key-2")
|
||||
|
||||
if first == second {
|
||||
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
key := "api-key-renew"
|
||||
id := CachedUserID(key)
|
||||
cacheKey := userIDCacheKey(key)
|
||||
|
||||
soon := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: id,
|
||||
expire: soon.Add(2 * time.Second),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
if refreshed := CachedUserID(key); refreshed != id {
|
||||
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||
}
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry := userIDCache[cacheKey]
|
||||
userIDCacheMu.RUnlock()
|
||||
|
||||
if entry.expire.Sub(soon) < 30*time.Minute {
|
||||
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user