diff --git a/internal/api/buffered_conn.go b/internal/api/buffered_conn.go new file mode 100644 index 00000000..5eb55f96 --- /dev/null +++ b/internal/api/buffered_conn.go @@ -0,0 +1,32 @@ +package api + +import ( + "bufio" + "crypto/tls" + "net" +) + +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + if c == nil { + return 0, net.ErrClosed + } + if c.reader == nil { + return c.Conn.Read(p) + } + return c.reader.Read(p) +} + +func (c *bufferedConn) ConnectionState() tls.ConnectionState { + if c == nil || c.Conn == nil { + return tls.ConnectionState{} + } + if stater, ok := c.Conn.(interface{ ConnectionState() tls.ConnectionState }); ok { + return stater.ConnectionState() + } + return tls.ConnectionState{} +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 30cc9738..ee96ed79 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -152,9 +152,6 @@ func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) { // All requests (local and remote) require a valid management key. // Additionally, remote access requires allow-remote-management=true. func (h *Handler) Middleware() gin.HandlerFunc { - const maxFailures = 5 - const banDuration = 30 * time.Minute - return func(c *gin.Context) { c.Header("X-CPA-VERSION", buildinfo.Version) c.Header("X-CPA-COMMIT", buildinfo.Commit) @@ -162,64 +159,6 @@ func (h *Handler) Middleware() gin.HandlerFunc { clientIP := c.ClientIP() localClient := clientIP == "127.0.0.1" || clientIP == "::1" - cfg := h.cfg - var ( - allowRemote bool - secretHash string - ) - if cfg != nil { - allowRemote = cfg.RemoteManagement.AllowRemote - secretHash = cfg.RemoteManagement.SecretKey - } - if h.allowRemoteOverride { - allowRemote = true - } - envSecret := h.envSecret - - fail := func() {} - if !localClient { - h.attemptsMu.Lock() - ai := h.failedAttempts[clientIP] - if ai != nil { - if !ai.blockedUntil.IsZero() { - if time.Now().Before(ai.blockedUntil) { - remaining := time.Until(ai.blockedUntil).Round(time.Second) - h.attemptsMu.Unlock() - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)}) - return - } - // Ban expired, reset state - ai.blockedUntil = time.Time{} - ai.count = 0 - } - } - h.attemptsMu.Unlock() - - if !allowRemote { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"}) - return - } - - fail = func() { - h.attemptsMu.Lock() - aip := h.failedAttempts[clientIP] - if aip == nil { - aip = &attemptInfo{} - h.failedAttempts[clientIP] = aip - } - aip.count++ - aip.lastActivity = time.Now() - if aip.count >= maxFailures { - aip.blockedUntil = time.Now().Add(banDuration) - aip.count = 0 - } - h.attemptsMu.Unlock() - } - } - if secretHash == "" && envSecret == "" { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"}) - return - } // Accept either Authorization: Bearer or X-Management-Key var provided string @@ -235,44 +174,98 @@ func (h *Handler) Middleware() gin.HandlerFunc { provided = c.GetHeader("X-Management-Key") } - if provided == "" { - if !localClient { - fail() - } - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"}) + allowed, statusCode, errMsg := h.AuthenticateManagementKey(clientIP, localClient, provided) + if !allowed { + c.AbortWithStatusJSON(statusCode, gin.H{"error": errMsg}) return } + c.Next() + } +} - if localClient { - if lp := h.localPassword; lp != "" { - if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 { - c.Next() - return +// AuthenticateManagementKey verifies the provided management key for the given client. +// It mirrors the behaviour of Middleware() so non-HTTP callers can reuse the same logic. +func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, provided string) (bool, int, string) { + const maxFailures = 5 + const banDuration = 30 * time.Minute + + if h == nil { + return false, http.StatusForbidden, "remote management disabled" + } + + cfg := h.cfg + var ( + allowRemote bool + secretHash string + ) + if cfg != nil { + allowRemote = cfg.RemoteManagement.AllowRemote + secretHash = cfg.RemoteManagement.SecretKey + } + if h.allowRemoteOverride { + allowRemote = true + } + envSecret := h.envSecret + + fail := func() {} + if !localClient { + h.attemptsMu.Lock() + ai := h.failedAttempts[clientIP] + if ai != nil { + if !ai.blockedUntil.IsZero() { + if time.Now().Before(ai.blockedUntil) { + remaining := time.Until(ai.blockedUntil).Round(time.Second) + h.attemptsMu.Unlock() + return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining) } + // Ban expired, reset state + ai.blockedUntil = time.Time{} + ai.count = 0 } } + h.attemptsMu.Unlock() - if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 { - if !localClient { - h.attemptsMu.Lock() - if ai := h.failedAttempts[clientIP]; ai != nil { - ai.count = 0 - ai.blockedUntil = time.Time{} - } - h.attemptsMu.Unlock() - } - c.Next() - return + if !allowRemote { + return false, http.StatusForbidden, "remote management disabled" } - if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil { - if !localClient { - fail() + fail = func() { + h.attemptsMu.Lock() + aip := h.failedAttempts[clientIP] + if aip == nil { + aip = &attemptInfo{} + h.failedAttempts[clientIP] = aip } - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"}) - return + aip.count++ + aip.lastActivity = time.Now() + if aip.count >= maxFailures { + aip.blockedUntil = time.Now().Add(banDuration) + aip.count = 0 + } + h.attemptsMu.Unlock() } + } + if secretHash == "" && envSecret == "" { + return false, http.StatusForbidden, "remote management key not set" + } + + if provided == "" { + if !localClient { + fail() + } + return false, http.StatusUnauthorized, "missing management key" + } + + if localClient { + if lp := h.localPassword; lp != "" { + if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 { + return true, 0, "" + } + } + } + + if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 { if !localClient { h.attemptsMu.Lock() if ai := h.failedAttempts[clientIP]; ai != nil { @@ -281,9 +274,26 @@ func (h *Handler) Middleware() gin.HandlerFunc { } h.attemptsMu.Unlock() } - - c.Next() + return true, 0, "" } + + if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil { + if !localClient { + fail() + } + return false, http.StatusUnauthorized, "invalid management key" + } + + if !localClient { + h.attemptsMu.Lock() + if ai := h.failedAttempts[clientIP]; ai != nil { + ai.count = 0 + ai.blockedUntil = time.Time{} + } + h.attemptsMu.Unlock() + } + + return true, 0, "" } // persist saves the current in-memory config to disk. diff --git a/internal/api/mux_listener.go b/internal/api/mux_listener.go new file mode 100644 index 00000000..d9a0c9f4 --- /dev/null +++ b/internal/api/mux_listener.go @@ -0,0 +1,68 @@ +package api + +import ( + "net" + "sync" +) + +type muxListener struct { + addr net.Addr + connCh chan net.Conn + closeCh chan struct{} + once sync.Once +} + +func newMuxListener(addr net.Addr, buffer int) *muxListener { + if buffer <= 0 { + buffer = 1 + } + return &muxListener{ + addr: addr, + connCh: make(chan net.Conn, buffer), + closeCh: make(chan struct{}), + } +} + +func (l *muxListener) Put(conn net.Conn) error { + if conn == nil { + return nil + } + select { + case <-l.closeCh: + return net.ErrClosed + case l.connCh <- conn: + return nil + } +} + +func (l *muxListener) Accept() (net.Conn, error) { + select { + case <-l.closeCh: + return nil, net.ErrClosed + case conn := <-l.connCh: + if conn == nil { + return nil, net.ErrClosed + } + return conn, nil + } +} + +func (l *muxListener) Close() error { + if l == nil { + return nil + } + l.once.Do(func() { + close(l.closeCh) + }) + return nil +} + +func (l *muxListener) Addr() net.Addr { + if l == nil { + return &net.TCPAddr{} + } + if l.addr == nil { + return &net.TCPAddr{} + } + return l.addr +} diff --git a/internal/api/protocol_multiplexer.go b/internal/api/protocol_multiplexer.go new file mode 100644 index 00000000..14068dc5 --- /dev/null +++ b/internal/api/protocol_multiplexer.go @@ -0,0 +1,109 @@ +package api + +import ( + "bufio" + "crypto/tls" + "errors" + "net" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +func normalizeHTTPServeError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, net.ErrClosed) { + return nil + } + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} + +func normalizeListenerError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, net.ErrClosed) { + return nil + } + return err +} + +func (s *Server) acceptMuxConnections(listener net.Listener, httpListener *muxListener) error { + if s == nil || listener == nil { + return net.ErrClosed + } + + for { + conn, errAccept := listener.Accept() + if errAccept != nil { + return errAccept + } + if conn == nil { + continue + } + + tlsConn, ok := conn.(*tls.Conn) + if ok { + if errHandshake := tlsConn.Handshake(); errHandshake != nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection after TLS handshake error: %v", errClose) + } + continue + } + proto := strings.TrimSpace(tlsConn.ConnectionState().NegotiatedProtocol) + if proto == "h2" || proto == "http/1.1" { + if httpListener == nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection: %v", errClose) + } + continue + } + if errPut := httpListener.Put(tlsConn); errPut != nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection after HTTP routing failure: %v", errClose) + } + } + continue + } + } + + reader := bufio.NewReader(conn) + prefix, errPeek := reader.Peek(1) + if errPeek != nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection after protocol peek failure: %v", errClose) + } + continue + } + + if isRedisRESPPrefix(prefix[0]) { + if !s.managementRoutesEnabled.Load() { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close redis connection while management is disabled: %v", errClose) + } + continue + } + go s.handleRedisConnection(conn, reader) + continue + } + + if httpListener == nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection without HTTP listener: %v", errClose) + } + continue + } + + if errPut := httpListener.Put(&bufferedConn{Conn: conn, reader: reader}); errPut != nil { + if errClose := conn.Close(); errClose != nil { + log.Errorf("failed to close connection after HTTP routing failure: %v", errClose) + } + } + } +} diff --git a/internal/api/redis_queue_protocol.go b/internal/api/redis_queue_protocol.go new file mode 100644 index 00000000..053a99c7 --- /dev/null +++ b/internal/api/redis_queue_protocol.go @@ -0,0 +1,317 @@ +package api + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue" + log "github.com/sirupsen/logrus" +) + +func isRedisRESPPrefix(prefix byte) bool { + switch prefix { + case '*', '$', '+', '-', ':': + return true + default: + return false + } +} + +func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { + if s == nil || conn == nil || reader == nil { + return + } + + clientIP, localClient := resolveRemoteIP(conn.RemoteAddr()) + authed := false + writer := bufio.NewWriter(conn) + defer func() { + if errClose := conn.Close(); errClose != nil { + log.Errorf("redis connection close error: %v", errClose) + } + }() + + flush := func() bool { + if errFlush := writer.Flush(); errFlush != nil { + log.Errorf("redis protocol flush error: %v", errFlush) + return false + } + return true + } + + for { + if !s.managementRoutesEnabled.Load() { + return + } + + args, err := readRESPArray(reader) + if err != nil { + if !errors.Is(err, io.EOF) { + _ = writeRedisError(writer, "ERR "+err.Error()) + _ = writer.Flush() + } + return + } + if len(args) == 0 { + _ = writeRedisError(writer, "ERR empty command") + if !flush() { + return + } + continue + } + + cmd := strings.ToUpper(strings.TrimSpace(args[0])) + switch cmd { + case "AUTH": + password, ok := parseAuthPassword(args) + if !ok { + _ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command") + if !flush() { + return + } + continue + } + if s.mgmt == nil { + _ = writeRedisError(writer, "ERR remote management disabled") + if !flush() { + return + } + continue + } + allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password) + if !allowed { + _ = writeRedisError(writer, "ERR "+errMsg) + if !flush() { + return + } + continue + } + authed = true + _ = writeRedisSimpleString(writer, "OK") + if !flush() { + return + } + case "LPOP", "RPOP": + if !authed { + _ = writeRedisError(writer, "NOAUTH Authentication required.") + if !flush() { + return + } + continue + } + count, hasCount, ok := parsePopCount(args) + if !ok { + _ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command") + if !flush() { + return + } + continue + } + if count <= 0 { + _ = writeRedisError(writer, "ERR value is not an integer or out of range") + if !flush() { + return + } + continue + } + items := redisqueue.PopOldest(count) + if hasCount { + _ = writeRedisArrayOfBulkStrings(writer, items) + if !flush() { + return + } + continue + } + if len(items) == 0 { + _ = writeRedisNilBulkString(writer) + if !flush() { + return + } + continue + } + _ = writeRedisBulkString(writer, items[0]) + if !flush() { + return + } + default: + _ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd))) + if !flush() { + return + } + } + } +} + +func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) { + if addr == nil { + return "", false + } + host := addr.String() + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + host = strings.TrimSpace(host) + localClient = host == "127.0.0.1" || host == "::1" + return host, localClient +} + +func parseAuthPassword(args []string) (string, bool) { + switch len(args) { + case 2: + return args[1], true + case 3: + // Support AUTH by ignoring username for compatibility. + return args[2], true + default: + return "", false + } +} + +func parsePopCount(args []string) (count int, hasCount bool, ok bool) { + if len(args) != 2 && len(args) != 3 { + return 0, false, false + } + if len(args) == 2 { + return 1, false, true + } + parsed, err := strconv.Atoi(strings.TrimSpace(args[2])) + if err != nil { + return 0, true, true + } + return parsed, true, true +} + +func readRESPArray(reader *bufio.Reader) ([]string, error) { + prefix, err := reader.ReadByte() + if err != nil { + return nil, err + } + if prefix != '*' { + return nil, fmt.Errorf("protocol error") + } + line, err := readRESPLine(reader) + if err != nil { + return nil, err + } + count, err := strconv.Atoi(line) + if err != nil || count < 0 { + return nil, fmt.Errorf("protocol error") + } + args := make([]string, 0, count) + for i := 0; i < count; i++ { + value, err := readRESPString(reader) + if err != nil { + return nil, err + } + args = append(args, value) + } + return args, nil +} + +func readRESPString(reader *bufio.Reader) (string, error) { + prefix, err := reader.ReadByte() + if err != nil { + return "", err + } + switch prefix { + case '$': + return readRESPBulkString(reader) + case '+', ':': + return readRESPLine(reader) + default: + return "", fmt.Errorf("protocol error") + } +} + +func readRESPBulkString(reader *bufio.Reader) (string, error) { + line, err := readRESPLine(reader) + if err != nil { + return "", err + } + length, err := strconv.Atoi(line) + if err != nil { + return "", fmt.Errorf("protocol error") + } + if length < 0 { + return "", nil + } + buf := make([]byte, length+2) + if _, err := io.ReadFull(reader, buf); err != nil { + return "", err + } + if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' { + return "", fmt.Errorf("protocol error") + } + return string(buf[:length]), nil +} + +func readRESPLine(reader *bufio.Reader) (string, error) { + line, err := reader.ReadString('\n') + if err != nil { + return "", err + } + line = strings.TrimSuffix(line, "\n") + line = strings.TrimSuffix(line, "\r") + return line, nil +} + +func writeRedisSimpleString(writer *bufio.Writer, value string) error { + if writer == nil { + return net.ErrClosed + } + _, err := writer.WriteString("+" + value + "\r\n") + return err +} + +func writeRedisError(writer *bufio.Writer, message string) error { + if writer == nil { + return net.ErrClosed + } + _, err := writer.WriteString("-" + message + "\r\n") + return err +} + +func writeRedisNilBulkString(writer *bufio.Writer) error { + if writer == nil { + return net.ErrClosed + } + _, err := writer.WriteString("$-1\r\n") + return err +} + +func writeRedisBulkString(writer *bufio.Writer, payload []byte) error { + if writer == nil { + return net.ErrClosed + } + if payload == nil { + return writeRedisNilBulkString(writer) + } + if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil { + return err + } + if _, err := writer.Write(payload); err != nil { + return err + } + _, err := writer.WriteString("\r\n") + return err +} + +func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error { + if writer == nil { + return net.ErrClosed + } + if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil { + return err + } + for i := range items { + if err := writeRedisBulkString(writer, items[i]); err != nil { + return err + } + } + return nil +} diff --git a/internal/api/redis_queue_protocol_integration_test.go b/internal/api/redis_queue_protocol_integration_test.go new file mode 100644 index 00000000..18ab0279 --- /dev/null +++ b/internal/api/redis_queue_protocol_integration_test.go @@ -0,0 +1,304 @@ +package api + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue" +) + +func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) { + t.Helper() + + listener, errListen := net.Listen("tcp", "127.0.0.1:0") + if errListen != nil { + t.Fatalf("failed to listen: %v", errListen) + } + + errCh := make(chan error, 1) + go func() { + errCh <- server.acceptMuxConnections(listener, nil) + }() + + stop = func() { + _ = listener.Close() + select { + case err := <-errCh: + if err != nil && !errors.Is(err, net.ErrClosed) { + t.Errorf("accept loop returned unexpected error: %v", err) + } + case <-time.After(2 * time.Second): + t.Errorf("timeout waiting for accept loop to exit") + } + } + + return listener.Addr().String(), stop +} + +func writeTestRESPCommand(conn net.Conn, args ...string) error { + if conn == nil { + return net.ErrClosed + } + if len(args) == 0 { + return nil + } + + var buf bytes.Buffer + fmt.Fprintf(&buf, "*%d\r\n", len(args)) + for _, arg := range args { + fmt.Fprintf(&buf, "$%d\r\n%s\r\n", len(arg), arg) + } + _, err := conn.Write(buf.Bytes()) + return err +} + +func readTestRESPLine(r *bufio.Reader) (string, error) { + line, err := r.ReadString('\n') + if err != nil { + return "", err + } + if !strings.HasSuffix(line, "\r\n") { + return "", fmt.Errorf("invalid RESP line terminator: %q", line) + } + return strings.TrimSuffix(line, "\r\n"), nil +} + +func readTestRESPSimpleString(r *bufio.Reader) (string, error) { + prefix, err := r.ReadByte() + if err != nil { + return "", err + } + if prefix != '+' { + return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix) + } + return readTestRESPLine(r) +} + +func readTestRESPError(r *bufio.Reader) (string, error) { + prefix, err := r.ReadByte() + if err != nil { + return "", err + } + if prefix != '-' { + return "", fmt.Errorf("expected error prefix '-', got %q", prefix) + } + return readTestRESPLine(r) +} + +func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) { + prefix, err := r.ReadByte() + if err != nil { + return nil, err + } + if prefix != '$' { + return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix) + } + + line, err := readTestRESPLine(r) + if err != nil { + return nil, err + } + length, err := strconv.Atoi(line) + if err != nil { + return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err) + } + if length == -1 { + return nil, nil + } + if length < -1 { + return nil, fmt.Errorf("invalid bulk string length %d", length) + } + + payload := make([]byte, length+2) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, err + } + if payload[length] != '\r' || payload[length+1] != '\n' { + return nil, fmt.Errorf("invalid bulk string terminator") + } + return payload[:length], nil +} + +func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) { + prefix, err := r.ReadByte() + if err != nil { + return nil, err + } + if prefix != '*' { + return nil, fmt.Errorf("expected array prefix '*', got %q", prefix) + } + + line, err := readTestRESPLine(r) + if err != nil { + return nil, err + } + count, err := strconv.Atoi(line) + if err != nil { + return nil, fmt.Errorf("invalid array length %q: %v", line, err) + } + if count < 0 { + return nil, fmt.Errorf("invalid array length %d", count) + } + + out := make([][]byte, 0, count) + for i := 0; i < count; i++ { + item, err := readTestRESPBulkString(r) + if err != nil { + return nil, err + } + out = append(out, item) + } + return out, nil +} + +func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + redisqueue.SetEnabled(false) + + server := newTestServer(t) + if server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be false") + } + + addr, stop := startRedisMuxListener(t, server) + t.Cleanup(stop) + + conn, errDial := net.DialTimeout("tcp", addr, time.Second) + if errDial != nil { + t.Fatalf("failed to dial redis listener: %v", errDial) + } + t.Cleanup(func() { _ = conn.Close() }) + + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + if errWrite := writeTestRESPCommand(conn, "PING"); errWrite != nil { + t.Fatalf("failed to write RESP command: %v", errWrite) + } + + buf := make([]byte, 1) + _, errRead := conn.Read(buf) + if errRead == nil { + t.Fatalf("expected connection to be closed when management is disabled") + } + if ne, ok := errRead.(net.Error); ok && ne.Timeout() { + t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead) + } +} + +func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { + const managementPassword = "test-management-password" + + t.Setenv("MANAGEMENT_PASSWORD", managementPassword) + redisqueue.SetEnabled(false) + t.Cleanup(func() { redisqueue.SetEnabled(false) }) + + server := newTestServer(t) + if !server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be true") + } + + addr, stop := startRedisMuxListener(t, server) + t.Cleanup(stop) + + conn, errDial := net.DialTimeout("tcp", addr, time.Second) + if errDial != nil { + t.Fatalf("failed to dial redis listener: %v", errDial) + } + t.Cleanup(func() { _ = conn.Close() }) + + reader := bufio.NewReader(conn) + + _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) + + if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil { + t.Fatalf("failed to write AUTH command: %v", errWrite) + } + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read AUTH error: %v", err) + } else if msg != "ERR invalid management key" { + t.Fatalf("unexpected AUTH error: %q", msg) + } + + if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write LPOP command: %v", errWrite) + } + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read LPOP NOAUTH error: %v", err) + } else if msg != "NOAUTH Authentication required." { + t.Fatalf("unexpected LPOP NOAUTH error: %q", msg) + } + + if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write AUTH command: %v", errWrite) + } + if msg, err := readTestRESPSimpleString(reader); err != nil { + t.Fatalf("failed to read AUTH response: %v", err) + } else if msg != "OK" { + t.Fatalf("unexpected AUTH response: %q", msg) + } + + if !redisqueue.Enabled() { + t.Fatalf("expected redisqueue to be enabled") + } + redisqueue.Enqueue([]byte("a")) + redisqueue.Enqueue([]byte("b")) + redisqueue.Enqueue([]byte("c")) + + if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write RPOP command: %v", errWrite) + } + if item, err := readTestRESPBulkString(reader); err != nil { + t.Fatalf("failed to read RPOP response: %v", err) + } else if string(item) != "a" { + t.Fatalf("unexpected RPOP item: %q", string(item)) + } + + if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write LPOP command: %v", errWrite) + } + if item, err := readTestRESPBulkString(reader); err != nil { + t.Fatalf("failed to read LPOP response: %v", err) + } else if string(item) != "b" { + t.Fatalf("unexpected LPOP item: %q", string(item)) + } + + if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil { + t.Fatalf("failed to write RPOP count command: %v", errWrite) + } + items, errItems := readRESPArrayOfBulkStrings(reader) + if errItems != nil { + t.Fatalf("failed to read RPOP count response: %v", errItems) + } + if len(items) != 1 || string(items[0]) != "c" { + t.Fatalf("unexpected RPOP count items: %#v", items) + } + + if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write LPOP empty command: %v", errWrite) + } + item, errItem := readTestRESPBulkString(reader) + if errItem != nil { + t.Fatalf("failed to read LPOP empty response: %v", errItem) + } + if item != nil { + t.Fatalf("expected nil bulk string for empty queue, got %q", string(item)) + } + + if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil { + t.Fatalf("failed to write RPOP empty count command: %v", errWrite) + } + emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader) + if errEmpty != nil { + t.Fatalf("failed to read RPOP empty count response: %v", errEmpty) + } + if len(emptyItems) != 0 { + t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 32ae3164..e70883b0 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,8 +7,10 @@ package api import ( "context" "crypto/subtle" + "crypto/tls" "errors" "fmt" + "net" "net/http" "os" "path/filepath" @@ -28,6 +30,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" + "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue" "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" @@ -38,6 +41,7 @@ import ( sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" "gopkg.in/yaml.v3" ) @@ -127,6 +131,12 @@ type Server struct { // server is the underlying HTTP server. server *http.Server + // muxBaseListener is the shared TCP listener used to serve both HTTP and Redis protocol traffic. + muxBaseListener net.Listener + + // muxHTTPListener receives HTTP connections selected by the multiplexer. + muxHTTPListener *muxListener + // handlers contains the API handlers for processing requests. handlers *handlers.BaseAPIHandler @@ -299,6 +309,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // or when a local management password is provided (e.g. TUI mode). hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != "" s.managementRoutesEnabled.Store(hasManagementSecret) + redisqueue.SetEnabled(hasManagementSecret) if hasManagementSecret { s.registerManagementRoutes() } @@ -797,26 +808,98 @@ func (s *Server) Start() error { return fmt.Errorf("failed to start HTTP server: server not initialized") } + addr := s.server.Addr + listener, errListen := net.Listen("tcp", addr) + if errListen != nil { + return fmt.Errorf("failed to start HTTP server: %v", errListen) + } + useTLS := s.cfg != nil && s.cfg.TLS.Enable if useTLS { - cert := strings.TrimSpace(s.cfg.TLS.Cert) - key := strings.TrimSpace(s.cfg.TLS.Key) - if cert == "" || key == "" { + certPath := strings.TrimSpace(s.cfg.TLS.Cert) + keyPath := strings.TrimSpace(s.cfg.TLS.Key) + if certPath == "" || keyPath == "" { + if errClose := listener.Close(); errClose != nil { + log.Errorf("failed to close listener after TLS validation failure: %v", errClose) + } return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty") } - log.Debugf("Starting API server on %s with TLS", s.server.Addr) - if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) { - return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS) + certPair, errLoad := tls.LoadX509KeyPair(certPath, keyPath) + if errLoad != nil { + if errClose := listener.Close(); errClose != nil { + log.Errorf("failed to close listener after TLS key pair load failure: %v", errClose) + } + return fmt.Errorf("failed to start HTTPS server: %v", errLoad) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{certPair}, + NextProtos: []string{"h2", "http/1.1"}, + } + s.server.TLSConfig = tlsConfig + if errHTTP2 := http2.ConfigureServer(s.server, &http2.Server{}); errHTTP2 != nil { + log.Warnf("failed to configure HTTP/2: %v", errHTTP2) + } + listener = tls.NewListener(listener, tlsConfig) + log.Debugf("Starting API server on %s with TLS", addr) + } else { + log.Debugf("Starting API server on %s", addr) + } + + httpListener := newMuxListener(listener.Addr(), 1024) + s.muxBaseListener = listener + s.muxHTTPListener = httpListener + + httpErrCh := make(chan error, 1) + acceptErrCh := make(chan error, 1) + + go func() { + httpErrCh <- s.server.Serve(httpListener) + }() + go func() { + acceptErrCh <- s.acceptMuxConnections(listener, httpListener) + }() + + select { + case errServe := <-httpErrCh: + if s.muxBaseListener != nil { + if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) { + log.Debugf("failed to close shared listener after HTTP serve exit: %v", errClose) + } + } + if s.muxHTTPListener != nil { + _ = s.muxHTTPListener.Close() + } + errAccept := <-acceptErrCh + errServe = normalizeHTTPServeError(errServe) + errAccept = normalizeListenerError(errAccept) + if errServe != nil { + return fmt.Errorf("failed to start HTTP server: %v", errServe) + } + if errAccept != nil { + return fmt.Errorf("failed to start HTTP server: %v", errAccept) + } + return nil + case errAccept := <-acceptErrCh: + if s.muxHTTPListener != nil { + _ = s.muxHTTPListener.Close() + } + if s.muxBaseListener != nil { + if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) { + log.Debugf("failed to close shared listener after accept loop exit: %v", errClose) + } + } + errServe := <-httpErrCh + errServe = normalizeHTTPServeError(errServe) + errAccept = normalizeListenerError(errAccept) + if errAccept != nil { + return fmt.Errorf("failed to start HTTP server: %v", errAccept) + } + if errServe != nil { + return fmt.Errorf("failed to start HTTP server: %v", errServe) } return nil } - - log.Debugf("Starting API server on %s", s.server.Addr) - if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) { - return fmt.Errorf("failed to start HTTP server: %v", errServe) - } - - return nil } // Stop gracefully shuts down the API server without interrupting any @@ -837,6 +920,15 @@ func (s *Server) Stop(ctx context.Context) error { } } + if s.muxHTTPListener != nil { + _ = s.muxHTTPListener.Close() + } + if s.muxBaseListener != nil { + if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) { + log.Debugf("failed to close shared listener: %v", errClose) + } + } + // Shutdown the HTTP server. if err := s.server.Shutdown(ctx); err != nil { return fmt.Errorf("failed to shutdown HTTP server: %v", err) @@ -963,6 +1055,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.managementRoutesEnabled.Store(!newSecretEmpty) } } + redisqueue.SetEnabled(s.managementRoutesEnabled.Load()) s.applyAccessConfig(oldCfg, cfg) s.cfg = cfg diff --git a/internal/redisqueue/plugin.go b/internal/redisqueue/plugin.go new file mode 100644 index 00000000..a805e5da --- /dev/null +++ b/internal/redisqueue/plugin.go @@ -0,0 +1,145 @@ +package redisqueue + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" + coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +func init() { + coreusage.RegisterPlugin(&usageQueuePlugin{}) +} + +type usageQueuePlugin struct{} + +func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Record) { + if p == nil { + return + } + if !Enabled() || !internalusage.StatisticsEnabled() { + return + } + + timestamp := record.RequestedAt + if timestamp.IsZero() { + timestamp = time.Now() + } + + modelName := strings.TrimSpace(record.Model) + if modelName == "" { + modelName = "unknown" + } + provider := strings.TrimSpace(record.Provider) + if provider == "" { + provider = "unknown" + } + authType := strings.TrimSpace(record.AuthType) + if authType == "" { + authType = "unknown" + } + apiKey := strings.TrimSpace(record.APIKey) + requestID := strings.TrimSpace(internallogging.GetRequestID(ctx)) + if requestID == "" { + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil { + requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx)) + } + } + + tokens := internalusage.TokenStats{ + InputTokens: record.Detail.InputTokens, + OutputTokens: record.Detail.OutputTokens, + ReasoningTokens: record.Detail.ReasoningTokens, + CachedTokens: record.Detail.CachedTokens, + TotalTokens: record.Detail.TotalTokens, + } + if tokens.TotalTokens == 0 { + tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + } + if tokens.TotalTokens == 0 { + tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens + } + + failed := record.Failed + if !failed { + failed = !resolveSuccess(ctx) + } + + detail := internalusage.RequestDetail{ + Timestamp: timestamp, + LatencyMs: record.Latency.Milliseconds(), + Source: record.Source, + AuthIndex: record.AuthIndex, + Tokens: tokens, + Failed: failed, + } + + payload, err := json.Marshal(queuedUsageDetail{ + RequestDetail: detail, + Provider: provider, + Model: modelName, + Endpoint: resolveEndpoint(ctx), + AuthType: authType, + APIKey: apiKey, + RequestID: requestID, + }) + if err != nil { + return + } + Enqueue(payload) +} + +type queuedUsageDetail struct { + internalusage.RequestDetail + Provider string `json:"provider"` + Model string `json:"model"` + Endpoint string `json:"endpoint"` + AuthType string `json:"auth_type"` + APIKey string `json:"api_key"` + RequestID string `json:"request_id"` +} + +func resolveSuccess(ctx context.Context) bool { + if ctx == nil { + return true + } + ginCtx, ok := ctx.Value("gin").(*gin.Context) + if !ok || ginCtx == nil { + return true + } + status := ginCtx.Writer.Status() + if status == 0 { + return true + } + return status < http.StatusBadRequest +} + +func resolveEndpoint(ctx context.Context) string { + if ctx == nil { + return "" + } + ginCtx, ok := ctx.Value("gin").(*gin.Context) + if !ok || ginCtx == nil || ginCtx.Request == nil { + return "" + } + + path := strings.TrimSpace(ginCtx.FullPath()) + if path == "" && ginCtx.Request.URL != nil { + path = strings.TrimSpace(ginCtx.Request.URL.Path) + } + if path == "" { + return "" + } + + method := strings.TrimSpace(ginCtx.Request.Method) + if method == "" { + return path + } + return method + " " + path +} diff --git a/internal/redisqueue/plugin_test.go b/internal/redisqueue/plugin_test.go new file mode 100644 index 00000000..907b8aee --- /dev/null +++ b/internal/redisqueue/plugin_test.go @@ -0,0 +1,160 @@ +package redisqueue + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" + coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) { + withEnabledQueue(t, func() { + ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK) + internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored") + ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx) + + plugin := &usageQueuePlugin{} + plugin.HandleUsage(ctx, coreusage.Record{ + Provider: "openai", + Model: "gpt-5.4", + APIKey: "test-key", + AuthIndex: "0", + AuthType: "apikey", + Source: "user@example.com", + RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC), + Latency: 1500 * time.Millisecond, + Detail: coreusage.Detail{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }) + + payload := popSinglePayload(t) + requireStringField(t, payload, "provider", "openai") + requireStringField(t, payload, "model", "gpt-5.4") + requireStringField(t, payload, "endpoint", "POST /v1/chat/completions") + requireStringField(t, payload, "auth_type", "apikey") + requireStringField(t, payload, "request_id", "ctx-request-id") + requireBoolField(t, payload, "failed", false) + }) +} + +func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) { + withEnabledQueue(t, func() { + ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError) + internallogging.SetGinRequestID(ginCtx, "gin-request-id") + ctx := context.WithValue(context.Background(), "gin", ginCtx) + + plugin := &usageQueuePlugin{} + plugin.HandleUsage(ctx, coreusage.Record{ + Provider: "openai", + Model: "gpt-5.4-mini", + APIKey: "test-key", + AuthIndex: "0", + AuthType: "apikey", + Source: "user@example.com", + RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC), + Latency: 2500 * time.Millisecond, + Detail: coreusage.Detail{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }) + + payload := popSinglePayload(t) + requireStringField(t, payload, "provider", "openai") + requireStringField(t, payload, "model", "gpt-5.4-mini") + requireStringField(t, payload, "endpoint", "GET /v1/responses") + requireStringField(t, payload, "auth_type", "apikey") + requireStringField(t, payload, "request_id", "gin-request-id") + requireBoolField(t, payload, "failed", true) + }) +} + +func withEnabledQueue(t *testing.T, fn func()) { + t.Helper() + + prevQueueEnabled := Enabled() + prevStatsEnabled := internalusage.StatisticsEnabled() + + SetEnabled(false) + SetEnabled(true) + internalusage.SetStatisticsEnabled(true) + + defer func() { + SetEnabled(false) + SetEnabled(prevQueueEnabled) + internalusage.SetStatisticsEnabled(prevStatsEnabled) + }() + + fn() +} + +func newTestGinContext(t *testing.T, method, path string, status int) *gin.Context { + t.Helper() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(method, "http://example.com"+path, nil) + if status != 0 { + ginCtx.Status(status) + } + return ginCtx +} + +func popSinglePayload(t *testing.T) map[string]json.RawMessage { + t.Helper() + + items := PopOldest(10) + if len(items) != 1 { + t.Fatalf("PopOldest() items = %d, want 1", len(items)) + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(items[0], &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + return payload +} + +func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) { + t.Helper() + + raw, ok := payload[key] + if !ok { + t.Fatalf("payload missing %q", key) + } + var got string + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", key, err) + } + if got != want { + t.Fatalf("%s = %q, want %q", key, got, want) + } +} + +func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) { + t.Helper() + + raw, ok := payload[key] + if !ok { + t.Fatalf("payload missing %q", key) + } + var got bool + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", key, err) + } + if got != want { + t.Fatalf("%s = %t, want %t", key, got, want) + } +} diff --git a/internal/redisqueue/queue.go b/internal/redisqueue/queue.go new file mode 100644 index 00000000..8a4b6742 --- /dev/null +++ b/internal/redisqueue/queue.go @@ -0,0 +1,133 @@ +package redisqueue + +import ( + "sync" + "sync/atomic" + "time" +) + +const retentionWindow = time.Minute + +type queueItem struct { + enqueuedAt time.Time + payload []byte +} + +type queue struct { + mu sync.Mutex + items []queueItem + head int +} + +var ( + enabled atomic.Bool + global queue +) + +func SetEnabled(value bool) { + enabled.Store(value) + if !value { + global.clear() + } +} + +func Enabled() bool { + return enabled.Load() +} + +func Enqueue(payload []byte) { + if !Enabled() { + return + } + if len(payload) == 0 { + return + } + global.enqueue(payload) +} + +func PopOldest(count int) [][]byte { + if !Enabled() { + return nil + } + if count <= 0 { + return nil + } + return global.popOldest(count) +} + +func (q *queue) clear() { + q.mu.Lock() + defer q.mu.Unlock() + q.items = nil + q.head = 0 +} + +func (q *queue) enqueue(payload []byte) { + now := time.Now() + + q.mu.Lock() + defer q.mu.Unlock() + + q.pruneLocked(now) + q.items = append(q.items, queueItem{ + enqueuedAt: now, + payload: append([]byte(nil), payload...), + }) + q.maybeCompactLocked() +} + +func (q *queue) popOldest(count int) [][]byte { + now := time.Now() + + q.mu.Lock() + defer q.mu.Unlock() + + q.pruneLocked(now) + available := len(q.items) - q.head + if available <= 0 { + q.items = nil + q.head = 0 + return nil + } + if count > available { + count = available + } + + out := make([][]byte, 0, count) + for i := 0; i < count; i++ { + item := q.items[q.head+i] + out = append(out, item.payload) + } + q.head += count + q.maybeCompactLocked() + return out +} + +func (q *queue) pruneLocked(now time.Time) { + if q.head >= len(q.items) { + q.items = nil + q.head = 0 + return + } + + cutoff := now.Add(-retentionWindow) + for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) { + q.head++ + } +} + +func (q *queue) maybeCompactLocked() { + if q.head == 0 { + return + } + if q.head >= len(q.items) { + q.items = nil + q.head = 0 + return + } + if q.head < 1024 && q.head*2 < len(q.items) { + return + } + q.items = append([]queueItem(nil), q.items[q.head:]...) + q.head = 0 +} diff --git a/internal/runtime/executor/helps/usage_helpers.go b/internal/runtime/executor/helps/usage_helpers.go index 8da8fd1e..97c1c611 100644 --- a/internal/runtime/executor/helps/usage_helpers.go +++ b/internal/runtime/executor/helps/usage_helpers.go @@ -20,6 +20,7 @@ type UsageReporter struct { model string authID string authIndex string + authType string apiKey string source string requestedAt time.Time @@ -34,6 +35,7 @@ func NewUsageReporter(ctx context.Context, provider, model string, auth *cliprox requestedAt: time.Now(), apiKey: apiKey, source: resolveUsageSource(auth, apiKey), + authType: resolveUsageAuthType(auth), } if auth != nil { reporter.authID = auth.ID @@ -98,6 +100,7 @@ func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco APIKey: r.apiKey, AuthID: r.authID, AuthIndex: r.authIndex, + AuthType: r.authType, RequestedAt: r.requestedAt, Latency: r.latency(), Failed: failed, @@ -181,6 +184,18 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string { return "" } +func resolveUsageAuthType(auth *cliproxyauth.Auth) string { + if auth == nil { + return "" + } + kind, _ := auth.AccountInfo() + kind = strings.TrimSpace(kind) + if kind == "api_key" { + return "apikey" + } + return kind +} + func ParseCodexUsage(data []byte) (usage.Detail, bool) { usageNode := gjson.ParseBytes(data).Get("response.usage") if !usageNode.Exists() { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index fa0d8a0a..c5458b48 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -13,6 +13,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" diff --git a/sdk/cliproxy/usage/manager.go b/sdk/cliproxy/usage/manager.go index 8d24f51f..c3d95f66 100644 --- a/sdk/cliproxy/usage/manager.go +++ b/sdk/cliproxy/usage/manager.go @@ -15,6 +15,7 @@ type Record struct { APIKey string AuthID string AuthIndex string + AuthType string Source string RequestedAt time.Time Latency time.Duration