diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index d7e79897..d9ecefe5 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -333,6 +333,9 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) { emailValue := gjson.GetBytes(data, "email").String() fileData["type"] = typeValue fileData["email"] = emailValue + if projectID := strings.TrimSpace(gjson.GetBytes(data, "project_id").String()); projectID != "" { + fileData["project_id"] = projectID + } if pv := gjson.GetBytes(data, "priority"); pv.Exists() { switch pv.Type { case gjson.Number: @@ -394,6 +397,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { if email := authEmail(auth); email != "" { entry["email"] = email } + if projectID := authProjectID(auth); projectID != "" { + entry["project_id"] = projectID + } if accountType, account := auth.AccountInfo(); accountType != "" || account != "" { if accountType != "" { entry["account_type"] = accountType @@ -468,6 +474,28 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { return entry } +func authProjectID(auth *coreauth.Auth) string { + if auth == nil { + return "" + } + if auth.Metadata != nil { + if v, ok := auth.Metadata["project_id"].(string); ok { + if projectID := strings.TrimSpace(v); projectID != "" { + return projectID + } + } + } + if auth.Attributes != nil { + if projectID := strings.TrimSpace(auth.Attributes["project_id"]); projectID != "" { + return projectID + } + if projectID := strings.TrimSpace(auth.Attributes["gemini_virtual_project"]); projectID != "" { + return projectID + } + } + return "" +} + func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { if auth == nil || auth.Metadata == nil { return nil diff --git a/internal/api/handlers/management/auth_files_project_id_test.go b/internal/api/handlers/management/auth_files_project_id_test.go new file mode 100644 index 00000000..e9634f5a --- /dev/null +++ b/internal/api/handlers/management/auth_files_project_id_test.go @@ -0,0 +1,103 @@ +package management + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" +) + +func TestListAuthFiles_IncludesProjectIDFromManager(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + fileName := "gemini-user@example.com-project-a.json" + filePath := filepath.Join(authDir, fileName) + if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { + t.Fatalf("failed to write auth file: %v", errWrite) + } + + manager := coreauth.NewManager(nil, nil, nil) + record := &coreauth.Auth{ + ID: fileName, + FileName: fileName, + Provider: "gemini-cli", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "path": filePath, + }, + Metadata: map[string]any{ + "type": "gemini", + "email": "user@example.com", + "project_id": "project-a", + }, + } + if _, errRegister := manager.Register(context.Background(), record); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + + entry := firstAuthFileEntry(t, h) + if got := entry["project_id"]; got != "project-a" { + t.Fatalf("expected project_id %q, got %#v", "project-a", got) + } +} + +func TestListAuthFilesFromDisk_IncludesProjectID(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + filePath := filepath.Join(authDir, "gemini-user@example.com-project-a.json") + if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { + t.Fatalf("failed to write auth file: %v", errWrite) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil) + + entry := firstAuthFileEntry(t, h) + if got := entry["project_id"]; got != "project-a" { + t.Fatalf("expected project_id %q, got %#v", "project-a", got) + } +} + +func firstAuthFileEntry(t *testing.T, h *Handler) map[string]any { + t.Helper() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil) + + h.ListAuthFiles(ginCtx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var payload map[string]any + if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil { + t.Fatalf("failed to decode list payload: %v", errUnmarshal) + } + filesRaw, ok := payload["files"].([]any) + if !ok { + t.Fatalf("expected files array, payload: %#v", payload) + } + if len(filesRaw) != 1 { + t.Fatalf("expected 1 auth entry, got %d", len(filesRaw)) + } + fileEntry, ok := filesRaw[0].(map[string]any) + if !ok { + t.Fatalf("expected file entry object, got %#v", filesRaw[0]) + } + return fileEntry +} diff --git a/internal/api/redis_queue_protocol.go b/internal/api/redis_queue_protocol.go index 6f3622d7..f9d412d9 100644 --- a/internal/api/redis_queue_protocol.go +++ b/internal/api/redis_queue_protocol.go @@ -14,6 +14,13 @@ import ( log "github.com/sirupsen/logrus" ) +const redisUsageChannel = "usage" + +type redisSubscriptionCommand struct { + args []string + err error +} + func isRedisRESPPrefix(prefix byte) bool { switch prefix { case '*', '$', '+', '-', ':': @@ -131,6 +138,41 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { if !flush() { return } + case "SUBSCRIBE": + if !authed { + _ = writeRedisError(writer, "NOAUTH Authentication required.") + if !flush() { + return + } + continue + } + channel, ok := parseSubscribeChannel(args) + if !ok { + _ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command") + if !flush() { + return + } + continue + } + if !strings.EqualFold(channel, redisUsageChannel) { + _ = writeRedisError(writer, fmt.Sprintf("ERR unsupported channel '%s'", channel)) + if !flush() { + return + } + continue + } + messages, unsubscribe := redisqueue.SubscribeUsage() + if errWrite := writeRedisPubSubSubscribe(writer, redisUsageChannel, 1); errWrite != nil { + unsubscribe() + log.Errorf("redis protocol subscribe response error: %v", errWrite) + return + } + if !flush() { + unsubscribe() + return + } + s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe) + return case "LPOP", "RPOP": if !authed { _ = writeRedisError(writer, "NOAUTH Authentication required.") @@ -182,6 +224,101 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { } } +func (s *Server) streamRedisUsageSubscription(reader *bufio.Reader, writer *bufio.Writer, messages <-chan []byte, unsubscribe func()) { + if unsubscribe == nil { + return + } + defer unsubscribe() + + done := make(chan struct{}) + defer close(done) + + commands := make(chan redisSubscriptionCommand, 1) + go readRedisSubscriptionCommands(reader, commands, done) + + for { + select { + case msg, ok := <-messages: + if !ok { + return + } + if errWrite := writeRedisPubSubMessage(writer, redisUsageChannel, msg); errWrite != nil { + log.Errorf("redis protocol publish message error: %v", errWrite) + return + } + if errFlush := writer.Flush(); errFlush != nil { + log.Errorf("redis protocol flush error: %v", errFlush) + return + } + case command, ok := <-commands: + if !ok { + return + } + keepOpen := handleRedisSubscriptionCommand(writer, command) + if errFlush := writer.Flush(); errFlush != nil { + log.Errorf("redis protocol flush error: %v", errFlush) + return + } + if !keepOpen { + return + } + } + } +} + +func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSubscriptionCommand, done <-chan struct{}) { + defer close(commands) + + for { + args, err := readRESPArray(reader) + if err != nil { + if !errors.Is(err, io.EOF) { + select { + case commands <- redisSubscriptionCommand{err: err}: + case <-done: + } + } + return + } + select { + case commands <- redisSubscriptionCommand{args: args}: + case <-done: + return + } + } +} + +func handleRedisSubscriptionCommand(writer *bufio.Writer, command redisSubscriptionCommand) bool { + if command.err != nil { + _ = writeRedisError(writer, "ERR "+command.err.Error()) + return false + } + if len(command.args) == 0 { + _ = writeRedisError(writer, "ERR empty command") + return true + } + + cmd := strings.ToUpper(strings.TrimSpace(command.args[0])) + switch cmd { + case "PING": + payload := []byte(nil) + if len(command.args) > 1 { + payload = []byte(command.args[1]) + } + _ = writeRedisPubSubPong(writer, payload) + return true + case "UNSUBSCRIBE": + _ = writeRedisPubSubUnsubscribe(writer, redisUsageChannel, 0) + return false + case "QUIT": + _ = writeRedisSimpleString(writer, "OK") + return false + default: + _ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd))) + return true + } +} + func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) { if addr == nil { return "", false @@ -232,6 +369,13 @@ func parseAuthPassword(args []string) (string, bool) { } } +func parseSubscribeChannel(args []string) (string, bool) { + if len(args) != 2 { + return "", false + } + return strings.TrimSpace(args[1]), true +} + func parsePopCount(args []string) (count int, hasCount bool, ok bool) { if len(args) != 2 && len(args) != 3 { return 0, false, false @@ -375,3 +519,68 @@ func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error { } return nil } + +func writeRedisInteger(writer *bufio.Writer, value int) error { + if writer == nil { + return net.ErrClosed + } + _, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n") + return err +} + +func writeRedisArrayHeader(writer *bufio.Writer, count int) error { + if writer == nil { + return net.ErrClosed + } + _, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n") + return err +} + +func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error { + if err := writeRedisArrayHeader(writer, 3); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte(channel)); err != nil { + return err + } + return writeRedisInteger(writer, count) +} + +func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error { + if err := writeRedisArrayHeader(writer, 3); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte(channel)); err != nil { + return err + } + return writeRedisInteger(writer, count) +} + +func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error { + if err := writeRedisArrayHeader(writer, 3); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte("message")); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte(channel)); err != nil { + return err + } + return writeRedisBulkString(writer, payload) +} + +func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error { + if err := writeRedisArrayHeader(writer, 2); err != nil { + return err + } + if err := writeRedisBulkString(writer, []byte("pong")); err != nil { + return err + } + return writeRedisBulkString(writer, payload) +} diff --git a/internal/api/redis_queue_protocol_integration_test.go b/internal/api/redis_queue_protocol_integration_test.go index 1586d37c..8547e040 100644 --- a/internal/api/redis_queue_protocol_integration_test.go +++ b/internal/api/redis_queue_protocol_integration_test.go @@ -3,10 +3,13 @@ package api import ( "bufio" "bytes" + "encoding/json" "errors" "fmt" "io" "net" + "net/http" + "net/http/httptest" "strconv" "strings" "testing" @@ -171,6 +174,105 @@ func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) { return out, nil } +func readTestRESPInteger(r *bufio.Reader) (int, error) { + prefix, err := r.ReadByte() + if err != nil { + return 0, err + } + if prefix != ':' { + return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix) + } + + line, err := readTestRESPLine(r) + if err != nil { + return 0, err + } + value, err := strconv.Atoi(line) + if err != nil { + return 0, fmt.Errorf("invalid integer %q: %v", line, err) + } + return value, nil +} + +func readTestRESPArrayHeader(r *bufio.Reader) (int, error) { + prefix, err := r.ReadByte() + if err != nil { + return 0, err + } + if prefix != '*' { + return 0, fmt.Errorf("expected array prefix '*', got %q", prefix) + } + + line, err := readTestRESPLine(r) + if err != nil { + return 0, err + } + count, err := strconv.Atoi(line) + if err != nil { + return 0, fmt.Errorf("invalid array length %q: %v", line, err) + } + if count < 0 { + return 0, fmt.Errorf("invalid array length %d", count) + } + return count, nil +} + +func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) { + count, err := readTestRESPArrayHeader(r) + if err != nil { + return "", 0, err + } + if count != 3 { + return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count) + } + + kind, err := readTestRESPBulkString(r) + if err != nil { + return "", 0, err + } + if string(kind) != "subscribe" { + return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind)) + } + + channel, err := readTestRESPBulkString(r) + if err != nil { + return "", 0, err + } + subscriptions, err := readTestRESPInteger(r) + if err != nil { + return "", 0, err + } + return string(channel), subscriptions, nil +} + +func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) { + count, err := readTestRESPArrayHeader(r) + if err != nil { + return "", nil, err + } + if count != 3 { + return "", nil, fmt.Errorf("message array length = %d, want 3", count) + } + + kind, err := readTestRESPBulkString(r) + if err != nil { + return "", nil, err + } + if string(kind) != "message" { + return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind)) + } + + channel, err := readTestRESPBulkString(r) + if err != nil { + return "", nil, err + } + payload, err := readTestRESPBulkString(r) + if err != nil { + return "", nil, err + } + return string(channel), payload, nil +} + func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) { t.Setenv("MANAGEMENT_PASSWORD", "") redisqueue.SetEnabled(false) @@ -352,6 +454,127 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { } } +func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) { + const managementPassword = "test-management-password" + + t.Setenv("MANAGEMENT_PASSWORD", managementPassword) + redisqueue.SetEnabled(false) + t.Cleanup(func() { redisqueue.SetEnabled(false) }) + + server := newTestServer(t) + if !server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be true") + } + + addr, stop := startRedisMuxListener(t, server) + t.Cleanup(stop) + + firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second) + if errDialFirst != nil { + t.Fatalf("failed to dial first redis listener: %v", errDialFirst) + } + t.Cleanup(func() { _ = firstConn.Close() }) + firstReader := bufio.NewReader(firstConn) + _ = firstConn.SetDeadline(time.Now().Add(5 * time.Second)) + + if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write first AUTH command: %v", errWrite) + } + if msg, err := readTestRESPSimpleString(firstReader); err != nil { + t.Fatalf("failed to read first AUTH response: %v", err) + } else if msg != "OK" { + t.Fatalf("unexpected first AUTH response: %q", msg) + } + if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil { + t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite) + } + if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil { + t.Fatalf("failed to read first SUBSCRIBE response: %v", err) + } else if channel != "usage" || count != 1 { + t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count) + } + + secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second) + if errDialSecond != nil { + t.Fatalf("failed to dial second redis listener: %v", errDialSecond) + } + t.Cleanup(func() { _ = secondConn.Close() }) + secondReader := bufio.NewReader(secondConn) + _ = secondConn.SetDeadline(time.Now().Add(5 * time.Second)) + + if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write second AUTH command: %v", errWrite) + } + if msg, err := readTestRESPSimpleString(secondReader); err != nil { + t.Fatalf("failed to read second AUTH response: %v", err) + } else if msg != "OK" { + t.Fatalf("unexpected second AUTH response: %q", msg) + } + if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil { + t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite) + } + if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil { + t.Fatalf("failed to read second SUBSCRIBE response: %v", err) + } else if channel != "usage" || count != 1 { + t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count) + } + + redisqueue.Enqueue([]byte(`{"id":1}`)) + + if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil { + t.Fatalf("failed to read first pubsub message: %v", err) + } else if channel != "usage" || string(payload) != `{"id":1}` { + t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload)) + } + if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil { + t.Fatalf("failed to read second pubsub message: %v", err) + } else if channel != "usage" || string(payload) != `{"id":1}` { + t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload)) + } + + popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second) + if errDialPop != nil { + t.Fatalf("failed to dial pop redis listener: %v", errDialPop) + } + t.Cleanup(func() { _ = popConn.Close() }) + popReader := bufio.NewReader(popConn) + _ = popConn.SetDeadline(time.Now().Add(5 * time.Second)) + + if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write pop AUTH command: %v", errWrite) + } + if msg, err := readTestRESPSimpleString(popReader); err != nil { + t.Fatalf("failed to read pop AUTH response: %v", err) + } else if msg != "OK" { + t.Fatalf("unexpected pop AUTH response: %q", msg) + } + if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil { + t.Fatalf("failed to write pop LPOP command: %v", errWrite) + } + item, errItem := readTestRESPBulkString(popReader) + if errItem != nil { + t.Fatalf("failed to read pop LPOP response: %v", errItem) + } + if item != nil { + t.Fatalf("expected subscribed usage to skip queue, got %q", string(item)) + } + + managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil) + managementReq.Header.Set("Authorization", "Bearer "+managementPassword) + managementRR := httptest.NewRecorder() + server.engine.ServeHTTP(managementRR, managementReq) + if managementRR.Code != http.StatusOK { + t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String()) + } + var managementPayload []json.RawMessage + if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil { + t.Fatalf("unmarshal management usage response: %v", errUnmarshal) + } + if len(managementPayload) != 0 { + t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String()) + } +} + func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) { const managementPassword = "test-management-password" diff --git a/internal/redisqueue/queue.go b/internal/redisqueue/queue.go index 2fea5839..6a2a594e 100644 --- a/internal/redisqueue/queue.go +++ b/internal/redisqueue/queue.go @@ -9,6 +9,7 @@ import ( const ( defaultRetentionSeconds int64 = 60 maxRetentionSeconds int64 = 3600 + usageSubscriberBuffer = 256 ) type queueItem struct { @@ -17,9 +18,11 @@ type queueItem struct { } type queue struct { - mu sync.Mutex - items []queueItem - head int + mu sync.Mutex + items []queueItem + head int + subscribers map[uint64]chan []byte + nextSubscriberID uint64 } var ( @@ -60,6 +63,9 @@ func Enqueue(payload []byte) { if len(payload) == 0 { return } + if global.publishToSubscribers(payload) { + return + } global.enqueue(payload) } @@ -73,11 +79,25 @@ func PopOldest(count int) [][]byte { return global.popOldest(count) } +func SubscribeUsage() (<-chan []byte, func()) { + return global.subscribeUsage() +} + func (q *queue) clear() { q.mu.Lock() - defer q.mu.Unlock() + + subscribers := make([]chan []byte, 0, len(q.subscribers)) + for _, subscriber := range q.subscribers { + subscribers = append(subscribers, subscriber) + } q.items = nil q.head = 0 + q.subscribers = nil + q.mu.Unlock() + + for _, subscriber := range subscribers { + close(subscriber) + } } func (q *queue) enqueue(payload []byte) { @@ -94,6 +114,61 @@ func (q *queue) enqueue(payload []byte) { q.maybeCompactLocked() } +func (q *queue) publishToSubscribers(payload []byte) bool { + q.mu.Lock() + defer q.mu.Unlock() + + if len(q.subscribers) == 0 { + return false + } + + for id, subscriber := range q.subscribers { + cloned := append([]byte(nil), payload...) + select { + case subscriber <- cloned: + default: + delete(q.subscribers, id) + close(subscriber) + } + } + + return true +} + +func (q *queue) subscribeUsage() (<-chan []byte, func()) { + subscriber := make(chan []byte, usageSubscriberBuffer) + + q.mu.Lock() + if q.subscribers == nil { + q.subscribers = make(map[uint64]chan []byte) + } + q.nextSubscriberID++ + id := q.nextSubscriberID + q.subscribers[id] = subscriber + q.mu.Unlock() + + var once sync.Once + unsubscribe := func() { + once.Do(func() { + q.unsubscribeUsage(id) + }) + } + return subscriber, unsubscribe +} + +func (q *queue) unsubscribeUsage(id uint64) { + q.mu.Lock() + subscriber, ok := q.subscribers[id] + if ok { + delete(q.subscribers, id) + } + q.mu.Unlock() + + if ok { + close(subscriber) + } +} + func (q *queue) popOldest(count int) [][]byte { now := time.Now() diff --git a/internal/redisqueue/queue_test.go b/internal/redisqueue/queue_test.go new file mode 100644 index 00000000..f40c8826 --- /dev/null +++ b/internal/redisqueue/queue_test.go @@ -0,0 +1,67 @@ +package redisqueue + +import ( + "testing" + "time" +) + +func TestEnqueueBroadcastsToUsageSubscribersAndSkipsQueue(t *testing.T) { + withEnabledQueue(t, func() { + first, unsubscribeFirst := SubscribeUsage() + defer unsubscribeFirst() + second, unsubscribeSecond := SubscribeUsage() + defer unsubscribeSecond() + + Enqueue([]byte("usage-record")) + + requireUsageSubscriberPayload(t, first, "usage-record") + requireUsageSubscriberPayload(t, second, "usage-record") + + if items := PopOldest(1); len(items) != 0 { + t.Fatalf("PopOldest() items = %q, want empty after subscriber broadcast", items) + } + + unsubscribeFirst() + unsubscribeSecond() + + Enqueue([]byte("queued-record")) + items := PopOldest(1) + if len(items) != 1 || string(items[0]) != "queued-record" { + t.Fatalf("PopOldest() items = %q, want queued record after unsubscribe", items) + } + }) +} + +func TestSetEnabledFalseClosesUsageSubscribers(t *testing.T) { + withEnabledQueue(t, func() { + subscriber, unsubscribe := SubscribeUsage() + defer unsubscribe() + + SetEnabled(false) + + select { + case _, ok := <-subscriber: + if ok { + t.Fatalf("subscriber channel remained open after SetEnabled(false)") + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for subscriber close") + } + }) +} + +func requireUsageSubscriberPayload(t *testing.T, subscriber <-chan []byte, want string) { + t.Helper() + + select { + case got, ok := <-subscriber: + if !ok { + t.Fatalf("subscriber closed before receiving %q", want) + } + if string(got) != want { + t.Fatalf("subscriber payload = %q, want %q", string(got), want) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for subscriber payload %q", want) + } +}