feat(redis): implement Pub/Sub support for usage tracking
- Added Redis Pub/Sub capability to broadcast usage updates to subscribed clients. - Enhanced `redisqueue` with subscriber management and message broadcasting. - Updated tests to validate Pub/Sub message handling, subscription behavior, and fallback to the queue after unsubscribing. - Integrated `project_id` parsing into auth-files logic to include project identifiers in metadata.
This commit is contained in:
@@ -333,6 +333,9 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
|||||||
emailValue := gjson.GetBytes(data, "email").String()
|
emailValue := gjson.GetBytes(data, "email").String()
|
||||||
fileData["type"] = typeValue
|
fileData["type"] = typeValue
|
||||||
fileData["email"] = emailValue
|
fileData["email"] = emailValue
|
||||||
|
if projectID := strings.TrimSpace(gjson.GetBytes(data, "project_id").String()); projectID != "" {
|
||||||
|
fileData["project_id"] = projectID
|
||||||
|
}
|
||||||
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
||||||
switch pv.Type {
|
switch pv.Type {
|
||||||
case gjson.Number:
|
case gjson.Number:
|
||||||
@@ -394,6 +397,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
|||||||
if email := authEmail(auth); email != "" {
|
if email := authEmail(auth); email != "" {
|
||||||
entry["email"] = email
|
entry["email"] = email
|
||||||
}
|
}
|
||||||
|
if projectID := authProjectID(auth); projectID != "" {
|
||||||
|
entry["project_id"] = projectID
|
||||||
|
}
|
||||||
if accountType, account := auth.AccountInfo(); accountType != "" || account != "" {
|
if accountType, account := auth.AccountInfo(); accountType != "" || account != "" {
|
||||||
if accountType != "" {
|
if accountType != "" {
|
||||||
entry["account_type"] = accountType
|
entry["account_type"] = accountType
|
||||||
@@ -468,6 +474,28 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
|||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func authProjectID(auth *coreauth.Auth) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
if v, ok := auth.Metadata["project_id"].(string); ok {
|
||||||
|
if projectID := strings.TrimSpace(v); projectID != "" {
|
||||||
|
return projectID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if projectID := strings.TrimSpace(auth.Attributes["project_id"]); projectID != "" {
|
||||||
|
return projectID
|
||||||
|
}
|
||||||
|
if projectID := strings.TrimSpace(auth.Attributes["gemini_virtual_project"]); projectID != "" {
|
||||||
|
return projectID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
||||||
if auth == nil || auth.Metadata == nil {
|
if auth == nil || auth.Metadata == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -0,0 +1,103 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestListAuthFiles_IncludesProjectIDFromManager(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "gemini-user@example.com-project-a.json"
|
||||||
|
filePath := filepath.Join(authDir, fileName)
|
||||||
|
if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
FileName: fileName,
|
||||||
|
Provider: "gemini-cli",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": filePath,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "gemini",
|
||||||
|
"email": "user@example.com",
|
||||||
|
"project_id": "project-a",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
entry := firstAuthFileEntry(t, h)
|
||||||
|
if got := entry["project_id"]; got != "project-a" {
|
||||||
|
t.Fatalf("expected project_id %q, got %#v", "project-a", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAuthFilesFromDisk_IncludesProjectID(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
filePath := filepath.Join(authDir, "gemini-user@example.com-project-a.json")
|
||||||
|
if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
entry := firstAuthFileEntry(t, h)
|
||||||
|
if got := entry["project_id"]; got != "project-a" {
|
||||||
|
t.Fatalf("expected project_id %q, got %#v", "project-a", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstAuthFileEntry(t *testing.T, h *Handler) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||||
|
|
||||||
|
h.ListAuthFiles(ginCtx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
|
||||||
|
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||||
|
}
|
||||||
|
filesRaw, ok := payload["files"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected files array, payload: %#v", payload)
|
||||||
|
}
|
||||||
|
if len(filesRaw) != 1 {
|
||||||
|
t.Fatalf("expected 1 auth entry, got %d", len(filesRaw))
|
||||||
|
}
|
||||||
|
fileEntry, ok := filesRaw[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected file entry object, got %#v", filesRaw[0])
|
||||||
|
}
|
||||||
|
return fileEntry
|
||||||
|
}
|
||||||
@@ -14,6 +14,13 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const redisUsageChannel = "usage"
|
||||||
|
|
||||||
|
type redisSubscriptionCommand struct {
|
||||||
|
args []string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
func isRedisRESPPrefix(prefix byte) bool {
|
func isRedisRESPPrefix(prefix byte) bool {
|
||||||
switch prefix {
|
switch prefix {
|
||||||
case '*', '$', '+', '-', ':':
|
case '*', '$', '+', '-', ':':
|
||||||
@@ -131,6 +138,41 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
|||||||
if !flush() {
|
if !flush() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "SUBSCRIBE":
|
||||||
|
if !authed {
|
||||||
|
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||||
|
if !flush() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
channel, ok := parseSubscribeChannel(args)
|
||||||
|
if !ok {
|
||||||
|
_ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command")
|
||||||
|
if !flush() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(channel, redisUsageChannel) {
|
||||||
|
_ = writeRedisError(writer, fmt.Sprintf("ERR unsupported channel '%s'", channel))
|
||||||
|
if !flush() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messages, unsubscribe := redisqueue.SubscribeUsage()
|
||||||
|
if errWrite := writeRedisPubSubSubscribe(writer, redisUsageChannel, 1); errWrite != nil {
|
||||||
|
unsubscribe()
|
||||||
|
log.Errorf("redis protocol subscribe response error: %v", errWrite)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !flush() {
|
||||||
|
unsubscribe()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe)
|
||||||
|
return
|
||||||
case "LPOP", "RPOP":
|
case "LPOP", "RPOP":
|
||||||
if !authed {
|
if !authed {
|
||||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||||
@@ -182,6 +224,101 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) streamRedisUsageSubscription(reader *bufio.Reader, writer *bufio.Writer, messages <-chan []byte, unsubscribe func()) {
|
||||||
|
if unsubscribe == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer unsubscribe()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
commands := make(chan redisSubscriptionCommand, 1)
|
||||||
|
go readRedisSubscriptionCommands(reader, commands, done)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg, ok := <-messages:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errWrite := writeRedisPubSubMessage(writer, redisUsageChannel, msg); errWrite != nil {
|
||||||
|
log.Errorf("redis protocol publish message error: %v", errWrite)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errFlush := writer.Flush(); errFlush != nil {
|
||||||
|
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case command, ok := <-commands:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keepOpen := handleRedisSubscriptionCommand(writer, command)
|
||||||
|
if errFlush := writer.Flush(); errFlush != nil {
|
||||||
|
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !keepOpen {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSubscriptionCommand, done <-chan struct{}) {
|
||||||
|
defer close(commands)
|
||||||
|
|
||||||
|
for {
|
||||||
|
args, err := readRESPArray(reader)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
select {
|
||||||
|
case commands <- redisSubscriptionCommand{err: err}:
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case commands <- redisSubscriptionCommand{args: args}:
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleRedisSubscriptionCommand(writer *bufio.Writer, command redisSubscriptionCommand) bool {
|
||||||
|
if command.err != nil {
|
||||||
|
_ = writeRedisError(writer, "ERR "+command.err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(command.args) == 0 {
|
||||||
|
_ = writeRedisError(writer, "ERR empty command")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := strings.ToUpper(strings.TrimSpace(command.args[0]))
|
||||||
|
switch cmd {
|
||||||
|
case "PING":
|
||||||
|
payload := []byte(nil)
|
||||||
|
if len(command.args) > 1 {
|
||||||
|
payload = []byte(command.args[1])
|
||||||
|
}
|
||||||
|
_ = writeRedisPubSubPong(writer, payload)
|
||||||
|
return true
|
||||||
|
case "UNSUBSCRIBE":
|
||||||
|
_ = writeRedisPubSubUnsubscribe(writer, redisUsageChannel, 0)
|
||||||
|
return false
|
||||||
|
case "QUIT":
|
||||||
|
_ = writeRedisSimpleString(writer, "OK")
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
return "", false
|
return "", false
|
||||||
@@ -232,6 +369,13 @@ func parseAuthPassword(args []string) (string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseSubscribeChannel(args []string) (string, bool) {
|
||||||
|
if len(args) != 2 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(args[1]), true
|
||||||
|
}
|
||||||
|
|
||||||
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
||||||
if len(args) != 2 && len(args) != 3 {
|
if len(args) != 2 && len(args) != 3 {
|
||||||
return 0, false, false
|
return 0, false, false
|
||||||
@@ -375,3 +519,68 @@ func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeRedisInteger(writer *bufio.Writer, value int) error {
|
||||||
|
if writer == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
_, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeRedisArrayHeader(writer *bufio.Writer, count int) error {
|
||||||
|
if writer == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
_, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error {
|
||||||
|
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeRedisInteger(writer, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error {
|
||||||
|
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeRedisInteger(writer, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error {
|
||||||
|
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte("message")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeRedisBulkString(writer, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error {
|
||||||
|
if err := writeRedisArrayHeader(writer, 2); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeRedisBulkString(writer, []byte("pong")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeRedisBulkString(writer, payload)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,10 +3,13 @@ package api
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -171,6 +174,105 @@ func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readTestRESPInteger(r *bufio.Reader) (int, error) {
|
||||||
|
prefix, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if prefix != ':' {
|
||||||
|
return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := readTestRESPLine(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
value, err := strconv.Atoi(line)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid integer %q: %v", line, err)
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTestRESPArrayHeader(r *bufio.Reader) (int, error) {
|
||||||
|
prefix, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if prefix != '*' {
|
||||||
|
return 0, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := readTestRESPLine(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count, err := strconv.Atoi(line)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||||
|
}
|
||||||
|
if count < 0 {
|
||||||
|
return 0, fmt.Errorf("invalid array length %d", count)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
|
||||||
|
count, err := readTestRESPArrayHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
if count != 3 {
|
||||||
|
return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
kind, err := readTestRESPBulkString(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
if string(kind) != "subscribe" {
|
||||||
|
return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := readTestRESPBulkString(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
subscriptions, err := readTestRESPInteger(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return string(channel), subscriptions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
|
||||||
|
count, err := readTestRESPArrayHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
if count != 3 {
|
||||||
|
return "", nil, fmt.Errorf("message array length = %d, want 3", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
kind, err := readTestRESPBulkString(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
if string(kind) != "message" {
|
||||||
|
return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := readTestRESPBulkString(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
payload, err := readTestRESPBulkString(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return string(channel), payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
redisqueue.SetEnabled(false)
|
redisqueue.SetEnabled(false)
|
||||||
@@ -352,6 +454,127 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) {
|
||||||
|
const managementPassword = "test-management-password"
|
||||||
|
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||||
|
redisqueue.SetEnabled(false)
|
||||||
|
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||||
|
|
||||||
|
server := newTestServer(t)
|
||||||
|
if !server.managementRoutesEnabled.Load() {
|
||||||
|
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, stop := startRedisMuxListener(t, server)
|
||||||
|
t.Cleanup(stop)
|
||||||
|
|
||||||
|
firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if errDialFirst != nil {
|
||||||
|
t.Fatalf("failed to dial first redis listener: %v", errDialFirst)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = firstConn.Close() })
|
||||||
|
firstReader := bufio.NewReader(firstConn)
|
||||||
|
_ = firstConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write first AUTH command: %v", errWrite)
|
||||||
|
}
|
||||||
|
if msg, err := readTestRESPSimpleString(firstReader); err != nil {
|
||||||
|
t.Fatalf("failed to read first AUTH response: %v", err)
|
||||||
|
} else if msg != "OK" {
|
||||||
|
t.Fatalf("unexpected first AUTH response: %q", msg)
|
||||||
|
}
|
||||||
|
if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite)
|
||||||
|
}
|
||||||
|
if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil {
|
||||||
|
t.Fatalf("failed to read first SUBSCRIBE response: %v", err)
|
||||||
|
} else if channel != "usage" || count != 1 {
|
||||||
|
t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if errDialSecond != nil {
|
||||||
|
t.Fatalf("failed to dial second redis listener: %v", errDialSecond)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = secondConn.Close() })
|
||||||
|
secondReader := bufio.NewReader(secondConn)
|
||||||
|
_ = secondConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write second AUTH command: %v", errWrite)
|
||||||
|
}
|
||||||
|
if msg, err := readTestRESPSimpleString(secondReader); err != nil {
|
||||||
|
t.Fatalf("failed to read second AUTH response: %v", err)
|
||||||
|
} else if msg != "OK" {
|
||||||
|
t.Fatalf("unexpected second AUTH response: %q", msg)
|
||||||
|
}
|
||||||
|
if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite)
|
||||||
|
}
|
||||||
|
if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil {
|
||||||
|
t.Fatalf("failed to read second SUBSCRIBE response: %v", err)
|
||||||
|
} else if channel != "usage" || count != 1 {
|
||||||
|
t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||||
|
|
||||||
|
if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil {
|
||||||
|
t.Fatalf("failed to read first pubsub message: %v", err)
|
||||||
|
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||||
|
t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload))
|
||||||
|
}
|
||||||
|
if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil {
|
||||||
|
t.Fatalf("failed to read second pubsub message: %v", err)
|
||||||
|
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||||
|
t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if errDialPop != nil {
|
||||||
|
t.Fatalf("failed to dial pop redis listener: %v", errDialPop)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = popConn.Close() })
|
||||||
|
popReader := bufio.NewReader(popConn)
|
||||||
|
_ = popConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write pop AUTH command: %v", errWrite)
|
||||||
|
}
|
||||||
|
if msg, err := readTestRESPSimpleString(popReader); err != nil {
|
||||||
|
t.Fatalf("failed to read pop AUTH response: %v", err)
|
||||||
|
} else if msg != "OK" {
|
||||||
|
t.Fatalf("unexpected pop AUTH response: %q", msg)
|
||||||
|
}
|
||||||
|
if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write pop LPOP command: %v", errWrite)
|
||||||
|
}
|
||||||
|
item, errItem := readTestRESPBulkString(popReader)
|
||||||
|
if errItem != nil {
|
||||||
|
t.Fatalf("failed to read pop LPOP response: %v", errItem)
|
||||||
|
}
|
||||||
|
if item != nil {
|
||||||
|
t.Fatalf("expected subscribed usage to skip queue, got %q", string(item))
|
||||||
|
}
|
||||||
|
|
||||||
|
managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil)
|
||||||
|
managementReq.Header.Set("Authorization", "Bearer "+managementPassword)
|
||||||
|
managementRR := httptest.NewRecorder()
|
||||||
|
server.engine.ServeHTTP(managementRR, managementReq)
|
||||||
|
if managementRR.Code != http.StatusOK {
|
||||||
|
t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String())
|
||||||
|
}
|
||||||
|
var managementPayload []json.RawMessage
|
||||||
|
if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil {
|
||||||
|
t.Fatalf("unmarshal management usage response: %v", errUnmarshal)
|
||||||
|
}
|
||||||
|
if len(managementPayload) != 0 {
|
||||||
|
t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
|
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
|
||||||
const managementPassword = "test-management-password"
|
const managementPassword = "test-management-password"
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
defaultRetentionSeconds int64 = 60
|
defaultRetentionSeconds int64 = 60
|
||||||
maxRetentionSeconds int64 = 3600
|
maxRetentionSeconds int64 = 3600
|
||||||
|
usageSubscriberBuffer = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
type queueItem struct {
|
type queueItem struct {
|
||||||
@@ -17,9 +18,11 @@ type queueItem struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type queue struct {
|
type queue struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
items []queueItem
|
items []queueItem
|
||||||
head int
|
head int
|
||||||
|
subscribers map[uint64]chan []byte
|
||||||
|
nextSubscriberID uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -60,6 +63,9 @@ func Enqueue(payload []byte) {
|
|||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if global.publishToSubscribers(payload) {
|
||||||
|
return
|
||||||
|
}
|
||||||
global.enqueue(payload)
|
global.enqueue(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,11 +79,25 @@ func PopOldest(count int) [][]byte {
|
|||||||
return global.popOldest(count)
|
return global.popOldest(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SubscribeUsage() (<-chan []byte, func()) {
|
||||||
|
return global.subscribeUsage()
|
||||||
|
}
|
||||||
|
|
||||||
func (q *queue) clear() {
|
func (q *queue) clear() {
|
||||||
q.mu.Lock()
|
q.mu.Lock()
|
||||||
defer q.mu.Unlock()
|
|
||||||
|
subscribers := make([]chan []byte, 0, len(q.subscribers))
|
||||||
|
for _, subscriber := range q.subscribers {
|
||||||
|
subscribers = append(subscribers, subscriber)
|
||||||
|
}
|
||||||
q.items = nil
|
q.items = nil
|
||||||
q.head = 0
|
q.head = 0
|
||||||
|
q.subscribers = nil
|
||||||
|
q.mu.Unlock()
|
||||||
|
|
||||||
|
for _, subscriber := range subscribers {
|
||||||
|
close(subscriber)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queue) enqueue(payload []byte) {
|
func (q *queue) enqueue(payload []byte) {
|
||||||
@@ -94,6 +114,61 @@ func (q *queue) enqueue(payload []byte) {
|
|||||||
q.maybeCompactLocked()
|
q.maybeCompactLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *queue) publishToSubscribers(payload []byte) bool {
|
||||||
|
q.mu.Lock()
|
||||||
|
defer q.mu.Unlock()
|
||||||
|
|
||||||
|
if len(q.subscribers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, subscriber := range q.subscribers {
|
||||||
|
cloned := append([]byte(nil), payload...)
|
||||||
|
select {
|
||||||
|
case subscriber <- cloned:
|
||||||
|
default:
|
||||||
|
delete(q.subscribers, id)
|
||||||
|
close(subscriber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) subscribeUsage() (<-chan []byte, func()) {
|
||||||
|
subscriber := make(chan []byte, usageSubscriberBuffer)
|
||||||
|
|
||||||
|
q.mu.Lock()
|
||||||
|
if q.subscribers == nil {
|
||||||
|
q.subscribers = make(map[uint64]chan []byte)
|
||||||
|
}
|
||||||
|
q.nextSubscriberID++
|
||||||
|
id := q.nextSubscriberID
|
||||||
|
q.subscribers[id] = subscriber
|
||||||
|
q.mu.Unlock()
|
||||||
|
|
||||||
|
var once sync.Once
|
||||||
|
unsubscribe := func() {
|
||||||
|
once.Do(func() {
|
||||||
|
q.unsubscribeUsage(id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return subscriber, unsubscribe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) unsubscribeUsage(id uint64) {
|
||||||
|
q.mu.Lock()
|
||||||
|
subscriber, ok := q.subscribers[id]
|
||||||
|
if ok {
|
||||||
|
delete(q.subscribers, id)
|
||||||
|
}
|
||||||
|
q.mu.Unlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
close(subscriber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (q *queue) popOldest(count int) [][]byte {
|
func (q *queue) popOldest(count int) [][]byte {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package redisqueue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnqueueBroadcastsToUsageSubscribersAndSkipsQueue(t *testing.T) {
|
||||||
|
withEnabledQueue(t, func() {
|
||||||
|
first, unsubscribeFirst := SubscribeUsage()
|
||||||
|
defer unsubscribeFirst()
|
||||||
|
second, unsubscribeSecond := SubscribeUsage()
|
||||||
|
defer unsubscribeSecond()
|
||||||
|
|
||||||
|
Enqueue([]byte("usage-record"))
|
||||||
|
|
||||||
|
requireUsageSubscriberPayload(t, first, "usage-record")
|
||||||
|
requireUsageSubscriberPayload(t, second, "usage-record")
|
||||||
|
|
||||||
|
if items := PopOldest(1); len(items) != 0 {
|
||||||
|
t.Fatalf("PopOldest() items = %q, want empty after subscriber broadcast", items)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsubscribeFirst()
|
||||||
|
unsubscribeSecond()
|
||||||
|
|
||||||
|
Enqueue([]byte("queued-record"))
|
||||||
|
items := PopOldest(1)
|
||||||
|
if len(items) != 1 || string(items[0]) != "queued-record" {
|
||||||
|
t.Fatalf("PopOldest() items = %q, want queued record after unsubscribe", items)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetEnabledFalseClosesUsageSubscribers(t *testing.T) {
|
||||||
|
withEnabledQueue(t, func() {
|
||||||
|
subscriber, unsubscribe := SubscribeUsage()
|
||||||
|
defer unsubscribe()
|
||||||
|
|
||||||
|
SetEnabled(false)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case _, ok := <-subscriber:
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("subscriber channel remained open after SetEnabled(false)")
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout waiting for subscriber close")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireUsageSubscriberPayload(t *testing.T, subscriber <-chan []byte, want string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got, ok := <-subscriber:
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("subscriber closed before receiving %q", want)
|
||||||
|
}
|
||||||
|
if string(got) != want {
|
||||||
|
t.Fatalf("subscriber payload = %q, want %q", string(got), want)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout waiting for subscriber payload %q", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user