test: remove unused Redis protocol tests and helpers

- Removed obsolete Redis protocol test cases and helper functions that were no longer relevant due to recent architecture changes.
- Streamlined remaining test files to align with updated Redis handling and connection management logic.
This commit is contained in:
Luis Pater
2026-05-19 23:12:57 +08:00
parent b9589e8ed6
commit 99fa530967
14 changed files with 72 additions and 1429 deletions
+4
View File
@@ -218,6 +218,10 @@ OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼
一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。 一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
这是一个使用 Tauri 2 + Vue 3 构建的工具,用于管理多个 OpenAI Codex 桌面账户。它可以在已保存的 ChatGPT/Codex 认证配置之间切换,实时查看 5 小时和每周配额使用情况,验证 token 健康状态,查看当前账户详情,并在无需手动复制的情况下导入或保存 auth.json 文件。
> [!NOTE] > [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。 > 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
+4
View File
@@ -217,6 +217,10 @@ OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:
上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。 上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
Tauri 2 + Vue 3で構築された、複数のOpenAI Codexデスクトップアカウントを管理するためのツールです。保存済みのChatGPT/Codex認証プロファイルを切り替え、5時間および週次クォータ使用量をリアルタイムで確認し、tokenの状態を検証し、現在のアカウント詳細を表示し、手動コピーなしでauth.jsonファイルをインポートまたは保存できます。
> [!NOTE] > [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。 > CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
-77
View File
@@ -1,77 +0,0 @@
package main
import "testing"
func TestParseHomeFlagConfigHostPort(t *testing.T) {
cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if !cfg.Enabled {
t.Fatal("Enabled = false, want true")
}
if cfg.Host != "home.example.com" {
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
}
if cfg.Port != 8327 {
t.Fatalf("Port = %d, want 8327", cfg.Port)
}
if cfg.Password != "secret" {
t.Fatalf("Password = %q, want secret", cfg.Password)
}
if cfg.TLS.Enable {
t.Fatal("TLS.Enable = true, want false")
}
}
func TestParseHomeFlagConfigRediss(t *testing.T) {
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if cfg.Host != "home.example.com" {
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
}
if cfg.Port != 444 {
t.Fatalf("Port = %d, want 444", cfg.Port)
}
if cfg.Password != "url-secret" {
t.Fatalf("Password = %q, want url-secret", cfg.Password)
}
if !cfg.TLS.Enable {
t.Fatal("TLS.Enable = false, want true")
}
if cfg.TLS.ServerName != "home.example.com" {
t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName)
}
if !cfg.TLS.InsecureSkipVerify {
t.Fatal("TLS.InsecureSkipVerify = false, want true")
}
if cfg.TLS.CACert != "C:/certs/ca.pem" {
t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert)
}
}
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if cfg.Password != "flag-secret" {
t.Fatalf("Password = %q, want flag-secret", cfg.Password)
}
}
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if !cfg.DisableClusterDiscovery {
t.Fatal("DisableClusterDiscovery = false, want true")
}
}
+1 -179
View File
@@ -10,11 +10,9 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time" "time"
@@ -53,120 +51,6 @@ func init() {
buildinfo.BuildDate = BuildDate buildinfo.BuildDate = BuildDate
} }
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
rawAddr = strings.TrimSpace(rawAddr)
if rawAddr == "" {
return config.HomeConfig{}, fmt.Errorf("address is empty")
}
if strings.Contains(rawAddr, "://") {
return parseHomeURLConfig(rawAddr, password)
}
host, portStr, errSplit := net.SplitHostPort(rawAddr)
if errSplit != nil {
return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
}
host = strings.TrimSpace(host)
if host == "" {
return config.HomeConfig{}, fmt.Errorf("host is empty")
}
port, errPort := parseHomePort(portStr)
if errPort != nil {
return config.HomeConfig{}, errPort
}
return config.HomeConfig{
Enabled: true,
Host: host,
Port: port,
Password: password,
}, nil
}
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
parsed, errParse := url.Parse(rawAddr)
if errParse != nil {
return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse)
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "redis" && scheme != "rediss" {
return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme)
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return config.HomeConfig{}, fmt.Errorf("host is empty")
}
port, errPort := parseHomePort(parsed.Port())
if errPort != nil {
return config.HomeConfig{}, errPort
}
if password == "" && parsed.User != nil {
if urlPassword, ok := parsed.User.Password(); ok {
password = urlPassword
}
}
homeCfg := config.HomeConfig{
Enabled: true,
Host: host,
Port: port,
Password: password,
}
query := parsed.Query()
homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery")
if scheme == "rediss" {
homeCfg.TLS.Enable = true
homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name"))
homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert"))
}
return homeCfg, nil
}
func parseHomePort(rawPort string) (int, error) {
rawPort = strings.TrimSpace(rawPort)
if rawPort == "" {
return 0, fmt.Errorf("port is empty")
}
port, errPort := strconv.Atoi(rawPort)
if errPort != nil || port <= 0 || port > 65535 {
return 0, fmt.Errorf("invalid port %q", rawPort)
}
return port, nil
}
func firstHomeQueryValue(values url.Values, keys ...string) string {
for _, key := range keys {
if value := values.Get(key); value != "" {
return value
}
}
return ""
}
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
for _, key := range keys {
value := strings.TrimSpace(values.Get(key))
if value == "" {
continue
}
parsed, errParse := strconv.ParseBool(value)
return errParse == nil && parsed
}
return false
}
// main is the entry point of the application. // main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate // It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode). // service based on the provided flags (login, codex-login, or server mode).
@@ -188,8 +72,6 @@ func main() {
var vertexImportPrefix string var vertexImportPrefix string
var configPath string var configPath string
var password string var password string
var homeAddr string
var homePassword string
var homeJWT string var homeJWT string
var homeDisableClusterDiscovery bool var homeDisableClusterDiscovery bool
var tuiMode bool var tuiMode bool
@@ -211,10 +93,8 @@ func main() {
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)") flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
flag.StringVar(&password, "password", "", "") flag.StringVar(&password, "password", "", "")
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection") flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address") flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home-jwt address")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching") flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
@@ -302,17 +182,6 @@ func main() {
} }
writableBase := util.WritablePath() writableBase := util.WritablePath()
// Allow env var fallback for home flags so they can be configured without command args.
if strings.TrimSpace(homeAddr) == "" {
if v, ok := lookupEnv("HOME_ADDR", "home_addr"); ok {
homeAddr = v
}
}
if strings.TrimSpace(homePassword) == "" {
if v, ok := lookupEnv("HOME_PASSWORD", "home_password"); ok {
homePassword = v
}
}
if strings.TrimSpace(homeJWT) == "" { if strings.TrimSpace(homeJWT) == "" {
if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok { if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
homeJWT = v homeJWT = v
@@ -426,53 +295,6 @@ func main() {
configFilePath = filepath.Join(wd, "config.yaml") configFilePath = filepath.Join(wd, "config.yaml")
} }
// Local stores are intentionally disabled when config is loaded from home.
usePostgresStore = false
useObjectStore = false
useGitStore = false
} else if strings.TrimSpace(homeAddr) != "" {
configLoadedFromHome = true
trimmedHomePassword := strings.TrimSpace(homePassword)
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
if errHomeCfg != nil {
log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
return
}
if homeDisableClusterDiscovery {
homeCfg.DisableClusterDiscovery = true
}
homeClient := home.New(homeCfg)
defer homeClient.Close()
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
raw, errGetConfig := homeClient.GetConfig(ctxHome)
cancelHome()
if errGetConfig != nil {
log.Errorf("failed to fetch config from home: %v", errGetConfig)
return
}
parsed, errParseConfig := config.ParseConfigBytes(raw)
if errParseConfig != nil {
log.Errorf("failed to parse config payload from home: %v", errParseConfig)
return
}
if parsed == nil {
parsed = &config.Config{}
}
parsed.Home = homeCfg
parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config
parsed.UsageStatisticsEnabled = true
cfg = parsed
// Keep a non-empty config path for downstream components (log paths, management assets, etc),
// but do not require the file to exist when loading config from home.
if strings.TrimSpace(configPath) != "" {
configFilePath = configPath
} else {
configFilePath = filepath.Join(wd, "config.yaml")
}
// Local stores are intentionally disabled when config is loaded from home. // Local stores are intentionally disabled when config is loaded from home.
usePostgresStore = false usePostgresStore = false
useObjectStore = false useObjectStore = false
+2 -22
View File
@@ -11,26 +11,6 @@ tls:
cert: "" cert: ""
key: "" key: ""
# Optional "home" control plane integration over Redis protocol.
home:
enabled: false
host: "127.0.0.1"
port: 6379
password: ""
# Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
# Useful when Home is behind NAT, Docker networking, or a reverse proxy.
disable-cluster-discovery: false
# Optional TLS for the outbound Redis connection to the home control plane.
# Enable this when connecting through rediss:// or an SSL stream proxy.
tls:
enable: false
# Optional SNI/certificate name override. Leave empty to use the configured home host.
server-name: ""
# Trust a private CA bundle in addition to system roots.
ca-cert: ""
# Only for testing self-signed endpoints; disables certificate verification.
insecure-skip-verify: false
# Management API settings # Management API settings
remote-management: remote-management:
# Whether to allow remote (non-localhost) management access. # Whether to allow remote (non-localhost) management access.
@@ -86,8 +66,8 @@ error-logs-max-files: 10
# When false, disable in-memory usage statistics aggregation # When false, disable in-memory usage statistics aggregation
usage-statistics-enabled: false usage-statistics-enabled: false
# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP). # How long (in seconds) usage queue items are retained in memory for the Management API.
# Note: the in-process Redis RESP usage output is disabled when home.enabled is true. # The local Redis RESP usage output is disabled.
# Default: 60. Max: 3600. # Default: 60. Max: 3600.
redis-usage-queue-retention-seconds: 60 redis-usage-queue-retention-seconds: 60
+1 -13
View File
@@ -103,20 +103,8 @@ func (s *Server) routeMuxConnection(conn net.Conn, httpListener *muxListener) {
} }
if isRedisRESPPrefix(prefix[0]) { if isRedisRESPPrefix(prefix[0]) {
if s.cfg != nil && s.cfg.Home.Enabled {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close redis connection while home mode is enabled: %v", errClose)
}
return
}
if !s.managementRoutesEnabled.Load() {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
}
return
}
_ = conn.SetReadDeadline(time.Time{}) _ = conn.SetReadDeadline(time.Time{})
s.handleRedisConnection(conn, reader) s.handleRedisConnection(conn)
return return
} }
+5 -548
View File
@@ -2,25 +2,11 @@ package api
import ( import (
"bufio" "bufio"
"errors"
"fmt"
"io"
"net" "net"
"net/http"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const redisUsageChannel = "usage"
type redisSubscriptionCommand struct {
args []string
err error
}
func isRedisRESPPrefix(prefix byte) bool { func isRedisRESPPrefix(prefix byte) bool {
switch prefix { switch prefix {
case '*', '$', '+', '-', ':': case '*', '$', '+', '-', ':':
@@ -30,13 +16,11 @@ func isRedisRESPPrefix(prefix byte) bool {
} }
} }
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { func (s *Server) handleRedisConnection(conn net.Conn) {
if s == nil || conn == nil || reader == nil { if s == nil || conn == nil {
return return
} }
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
authed := false
writer := bufio.NewWriter(conn) writer := bufio.NewWriter(conn)
defer func() { defer func() {
if errClose := conn.Close(); errClose != nil { if errClose := conn.Close(); errClose != nil {
@@ -44,432 +28,10 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
} }
}() }()
flush := func() bool { _ = writeRedisError(writer, "ERR RESP AUTH disabled; use mTLS")
if errFlush := writer.Flush(); errFlush != nil { if errFlush := writer.Flush(); errFlush != nil {
log.Errorf("redis protocol flush error: %v", errFlush) log.Errorf("redis protocol flush error: %v", errFlush)
return false
}
return true
} }
if s.cfg != nil && s.cfg.Home.Enabled {
_ = writeRedisError(writer, "ERR redis usage output disabled in home mode")
_ = writer.Flush()
return
}
for {
if !s.managementRoutesEnabled.Load() {
return
}
args, err := readRESPArray(reader)
if err != nil {
if !errors.Is(err, io.EOF) {
_ = writeRedisError(writer, "ERR "+err.Error())
_ = writer.Flush()
}
return
}
if len(args) == 0 {
_ = writeRedisError(writer, "ERR empty command")
if !flush() {
return
}
continue
}
cmd := strings.ToUpper(strings.TrimSpace(args[0]))
if cmd != "AUTH" && !authed {
if s.mgmt != nil {
_, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "")
if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") {
_ = writeRedisError(writer, "ERR "+errMsg)
} else {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
}
} else {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
}
if !flush() {
return
}
continue
}
switch cmd {
case "AUTH":
password, ok := parseAuthPassword(args)
if !ok {
if s.mgmt != nil {
_, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "")
if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") {
_ = writeRedisError(writer, "ERR "+errMsg)
if !flush() {
return
}
continue
}
}
_ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command")
if !flush() {
return
}
continue
}
if s.mgmt == nil {
_ = writeRedisError(writer, "ERR remote management disabled")
if !flush() {
return
}
continue
}
allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password)
if !allowed {
_ = writeRedisError(writer, "ERR "+errMsg)
if !flush() {
return
}
continue
}
authed = true
_ = writeRedisSimpleString(writer, "OK")
if !flush() {
return
}
case "SUBSCRIBE":
if !authed {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
if !flush() {
return
}
continue
}
channel, ok := parseSubscribeChannel(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command")
if !flush() {
return
}
continue
}
if !strings.EqualFold(channel, redisUsageChannel) {
_ = writeRedisError(writer, fmt.Sprintf("ERR unsupported channel '%s'", channel))
if !flush() {
return
}
continue
}
messages, unsubscribe := redisqueue.SubscribeUsage()
if errWrite := writeRedisPubSubSubscribe(writer, redisUsageChannel, 1); errWrite != nil {
unsubscribe()
log.Errorf("redis protocol subscribe response error: %v", errWrite)
return
}
if !flush() {
unsubscribe()
return
}
s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe)
return
case "LPOP", "RPOP":
if !authed {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
if !flush() {
return
}
continue
}
count, hasCount, ok := parsePopCount(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
if !flush() {
return
}
continue
}
if count <= 0 {
_ = writeRedisError(writer, "ERR value is not an integer or out of range")
if !flush() {
return
}
continue
}
items := redisqueue.PopOldest(count)
if hasCount {
_ = writeRedisArrayOfBulkStrings(writer, items)
if !flush() {
return
}
continue
}
if len(items) == 0 {
_ = writeRedisNilBulkString(writer)
if !flush() {
return
}
continue
}
_ = writeRedisBulkString(writer, items[0])
if !flush() {
return
}
default:
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
if !flush() {
return
}
}
}
}
func (s *Server) streamRedisUsageSubscription(reader *bufio.Reader, writer *bufio.Writer, messages <-chan []byte, unsubscribe func()) {
if unsubscribe == nil {
return
}
defer unsubscribe()
done := make(chan struct{})
defer close(done)
commands := make(chan redisSubscriptionCommand, 1)
go readRedisSubscriptionCommands(reader, commands, done)
for {
select {
case msg, ok := <-messages:
if !ok {
return
}
if errWrite := writeRedisPubSubMessage(writer, redisUsageChannel, msg); errWrite != nil {
log.Errorf("redis protocol publish message error: %v", errWrite)
return
}
if errFlush := writer.Flush(); errFlush != nil {
log.Errorf("redis protocol flush error: %v", errFlush)
return
}
case command, ok := <-commands:
if !ok {
return
}
keepOpen := handleRedisSubscriptionCommand(writer, command)
if errFlush := writer.Flush(); errFlush != nil {
log.Errorf("redis protocol flush error: %v", errFlush)
return
}
if !keepOpen {
return
}
}
}
}
func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSubscriptionCommand, done <-chan struct{}) {
defer close(commands)
for {
args, err := readRESPArray(reader)
if err != nil {
if !errors.Is(err, io.EOF) {
select {
case commands <- redisSubscriptionCommand{err: err}:
case <-done:
}
}
return
}
select {
case commands <- redisSubscriptionCommand{args: args}:
case <-done:
return
}
}
}
func handleRedisSubscriptionCommand(writer *bufio.Writer, command redisSubscriptionCommand) bool {
if command.err != nil {
_ = writeRedisError(writer, "ERR "+command.err.Error())
return false
}
if len(command.args) == 0 {
_ = writeRedisError(writer, "ERR empty command")
return true
}
cmd := strings.ToUpper(strings.TrimSpace(command.args[0]))
switch cmd {
case "PING":
payload := []byte(nil)
if len(command.args) > 1 {
payload = []byte(command.args[1])
}
_ = writeRedisPubSubPong(writer, payload)
return true
case "UNSUBSCRIBE":
_ = writeRedisPubSubUnsubscribe(writer, redisUsageChannel, 0)
return false
case "QUIT":
_ = writeRedisSimpleString(writer, "OK")
return false
default:
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
return true
}
}
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
if addr == nil {
return "", false
}
var host string
switch a := addr.(type) {
case *net.TCPAddr:
if a != nil && a.IP != nil {
if ip4 := a.IP.To4(); ip4 != nil {
host = ip4.String()
} else {
host = a.IP.String()
}
}
default:
host = addr.String()
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
host = strings.TrimSpace(host)
if raw, _, ok := strings.Cut(host, "%"); ok {
host = raw
}
if parsed := net.ParseIP(host); parsed != nil {
if ip4 := parsed.To4(); ip4 != nil {
host = ip4.String()
} else {
host = parsed.String()
}
}
}
host = strings.TrimSpace(host)
localClient = host == "127.0.0.1" || host == "::1"
return host, localClient
}
func parseAuthPassword(args []string) (string, bool) {
switch len(args) {
case 2:
return args[1], true
case 3:
// Support AUTH <username> <password> by ignoring username for compatibility.
return args[2], true
default:
return "", false
}
}
func parseSubscribeChannel(args []string) (string, bool) {
if len(args) != 2 {
return "", false
}
return strings.TrimSpace(args[1]), true
}
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
if len(args) != 2 && len(args) != 3 {
return 0, false, false
}
if len(args) == 2 {
return 1, false, true
}
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
if err != nil {
return 0, true, true
}
return parsed, true, true
}
func readRESPArray(reader *bufio.Reader) ([]string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return nil, err
}
if prefix != '*' {
return nil, fmt.Errorf("protocol error")
}
line, err := readRESPLine(reader)
if err != nil {
return nil, err
}
count, err := strconv.Atoi(line)
if err != nil || count < 0 {
return nil, fmt.Errorf("protocol error")
}
args := make([]string, 0, count)
for i := 0; i < count; i++ {
value, err := readRESPString(reader)
if err != nil {
return nil, err
}
args = append(args, value)
}
return args, nil
}
func readRESPString(reader *bufio.Reader) (string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return "", err
}
switch prefix {
case '$':
return readRESPBulkString(reader)
case '+', ':':
return readRESPLine(reader)
default:
return "", fmt.Errorf("protocol error")
}
}
func readRESPBulkString(reader *bufio.Reader) (string, error) {
line, err := readRESPLine(reader)
if err != nil {
return "", err
}
length, err := strconv.Atoi(line)
if err != nil {
return "", fmt.Errorf("protocol error")
}
if length < 0 {
return "", nil
}
buf := make([]byte, length+2)
if _, err := io.ReadFull(reader, buf); err != nil {
return "", err
}
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
return "", fmt.Errorf("protocol error")
}
return string(buf[:length]), nil
}
func readRESPLine(reader *bufio.Reader) (string, error) {
line, err := reader.ReadString('\n')
if err != nil {
return "", err
}
line = strings.TrimSuffix(line, "\n")
line = strings.TrimSuffix(line, "\r")
return line, nil
}
func writeRedisSimpleString(writer *bufio.Writer, value string) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("+" + value + "\r\n")
return err
} }
func writeRedisError(writer *bufio.Writer, message string) error { func writeRedisError(writer *bufio.Writer, message string) error {
@@ -479,108 +41,3 @@ func writeRedisError(writer *bufio.Writer, message string) error {
_, err := writer.WriteString("-" + message + "\r\n") _, err := writer.WriteString("-" + message + "\r\n")
return err return err
} }
func writeRedisNilBulkString(writer *bufio.Writer) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("$-1\r\n")
return err
}
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
if writer == nil {
return net.ErrClosed
}
if payload == nil {
return writeRedisNilBulkString(writer)
}
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
return err
}
if _, err := writer.Write(payload); err != nil {
return err
}
_, err := writer.WriteString("\r\n")
return err
}
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
if writer == nil {
return net.ErrClosed
}
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
return err
}
for i := range items {
if err := writeRedisBulkString(writer, items[i]); err != nil {
return err
}
}
return nil
}
func writeRedisInteger(writer *bufio.Writer, value int) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
return err
}
func writeRedisArrayHeader(writer *bufio.Writer, count int) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
return err
}
func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
}
return writeRedisInteger(writer, count)
}
func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
}
return writeRedisInteger(writer, count)
}
func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte("message")); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
}
return writeRedisBulkString(writer, payload)
}
func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error {
if err := writeRedisArrayHeader(writer, 2); err != nil {
return err
}
if err := writeRedisBulkString(writer, []byte("pong")); err != nil {
return err
}
return writeRedisBulkString(writer, payload)
}
@@ -3,14 +3,9 @@ package api
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"net/http/httptest"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -18,18 +13,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue" "github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
) )
type remoteAddrConn struct {
net.Conn
remoteAddr net.Addr
}
func (c *remoteAddrConn) RemoteAddr() net.Addr {
if c == nil {
return nil
}
return c.remoteAddr
}
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) { func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
t.Helper() t.Helper()
@@ -86,17 +69,6 @@ func readTestRESPLine(r *bufio.Reader) (string, error) {
return strings.TrimSuffix(line, "\r\n"), nil return strings.TrimSuffix(line, "\r\n"), nil
} }
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
prefix, err := r.ReadByte()
if err != nil {
return "", err
}
if prefix != '+' {
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
}
return readTestRESPLine(r)
}
func readTestRESPError(r *bufio.Reader) (string, error) { func readTestRESPError(r *bufio.Reader) (string, error) {
prefix, err := r.ReadByte() prefix, err := r.ReadByte()
if err != nil { if err != nil {
@@ -108,171 +80,6 @@ func readTestRESPError(r *bufio.Reader) (string, error) {
return readTestRESPLine(r) return readTestRESPLine(r)
} }
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
}
if prefix != '$' {
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
}
if length == -1 {
return nil, nil
}
if length < -1 {
return nil, fmt.Errorf("invalid bulk string length %d", length)
}
payload := make([]byte, length+2)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
if payload[length] != '\r' || payload[length+1] != '\n' {
return nil, fmt.Errorf("invalid bulk string terminator")
}
return payload[:length], nil
}
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
}
if prefix != '*' {
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
}
count, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
}
if count < 0 {
return nil, fmt.Errorf("invalid array length %d", count)
}
out := make([][]byte, 0, count)
for i := 0; i < count; i++ {
item, err := readTestRESPBulkString(r)
if err != nil {
return nil, err
}
out = append(out, item)
}
return out, nil
}
func readTestRESPInteger(r *bufio.Reader) (int, error) {
prefix, err := r.ReadByte()
if err != nil {
return 0, err
}
if prefix != ':' {
return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return 0, err
}
value, err := strconv.Atoi(line)
if err != nil {
return 0, fmt.Errorf("invalid integer %q: %v", line, err)
}
return value, nil
}
func readTestRESPArrayHeader(r *bufio.Reader) (int, error) {
prefix, err := r.ReadByte()
if err != nil {
return 0, err
}
if prefix != '*' {
return 0, fmt.Errorf("expected array prefix '*', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return 0, err
}
count, err := strconv.Atoi(line)
if err != nil {
return 0, fmt.Errorf("invalid array length %q: %v", line, err)
}
if count < 0 {
return 0, fmt.Errorf("invalid array length %d", count)
}
return count, nil
}
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
count, err := readTestRESPArrayHeader(r)
if err != nil {
return "", 0, err
}
if count != 3 {
return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count)
}
kind, err := readTestRESPBulkString(r)
if err != nil {
return "", 0, err
}
if string(kind) != "subscribe" {
return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind))
}
channel, err := readTestRESPBulkString(r)
if err != nil {
return "", 0, err
}
subscriptions, err := readTestRESPInteger(r)
if err != nil {
return "", 0, err
}
return string(channel), subscriptions, nil
}
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
count, err := readTestRESPArrayHeader(r)
if err != nil {
return "", nil, err
}
if count != 3 {
return "", nil, fmt.Errorf("message array length = %d, want 3", count)
}
kind, err := readTestRESPBulkString(r)
if err != nil {
return "", nil, err
}
if string(kind) != "message" {
return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind))
}
channel, err := readTestRESPBulkString(r)
if err != nil {
return "", nil, err
}
payload, err := readTestRESPBulkString(r)
if err != nil {
return "", nil, err
}
return string(channel), payload, nil
}
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) { func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "") t.Setenv("MANAGEMENT_PASSWORD", "")
redisqueue.SetEnabled(false) redisqueue.SetEnabled(false)
@@ -296,13 +103,19 @@ func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
t.Fatalf("failed to write RESP command: %v", errWrite) t.Fatalf("failed to write RESP command: %v", errWrite)
} }
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
t.Fatalf("failed to read disabled RESP error: %v", err)
} else if msg != "ERR RESP AUTH disabled; use mTLS" {
t.Fatalf("unexpected disabled RESP error: %q", msg)
}
buf := make([]byte, 1) buf := make([]byte, 1)
_, errRead := conn.Read(buf) _, errRead := conn.Read(buf)
if errRead == nil { if errRead == nil {
t.Fatalf("expected connection to be closed when management is disabled") t.Fatalf("expected connection to be closed after disabled RESP error")
} }
if ne, ok := errRead.(net.Error); ok && ne.Timeout() { if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead) t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead)
} }
} }
@@ -333,17 +146,23 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) {
_ = conn.SetDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetDeadline(time.Now().Add(2 * time.Second))
_ = writeTestRESPCommand(conn, "PING") _ = writeTestRESPCommand(conn, "PING")
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
t.Fatalf("failed to read disabled RESP error: %v", err)
} else if msg != "ERR RESP AUTH disabled; use mTLS" {
t.Fatalf("unexpected disabled RESP error: %q", msg)
}
buf := make([]byte, 1) buf := make([]byte, 1)
_, errRead := conn.Read(buf) _, errRead := conn.Read(buf)
if errRead == nil { if errRead == nil {
t.Fatalf("expected connection to be closed when home mode is enabled") t.Fatalf("expected connection to be closed after disabled RESP error")
} }
if ne, ok := errRead.(net.Error); ok && ne.Timeout() { if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
t.Fatalf("expected connection to be closed when home mode is enabled, got timeout: %v", errRead) t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead)
} }
} }
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { func TestRedisProtocol_AUTH_DisabledAndClosesConnection(t *testing.T) {
const managementPassword = "test-management-password" const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword) t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
@@ -368,369 +187,21 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
_ = conn.SetDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
} else if msg != "NOAUTH Authentication required." {
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil { if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite) t.Fatalf("failed to write AUTH command: %v", errWrite)
} }
if msg, err := readTestRESPSimpleString(reader); err != nil { if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH response: %v", err) t.Fatalf("failed to read disabled AUTH error: %v", err)
} else if msg != "OK" { } else if msg != "ERR RESP AUTH disabled; use mTLS" {
t.Fatalf("unexpected AUTH response: %q", msg) t.Fatalf("unexpected disabled AUTH error: %q", msg)
} }
if !redisqueue.Enabled() { buf := make([]byte, 1)
t.Fatalf("expected redisqueue to be enabled") _, errRead := conn.Read(buf)
if errRead == nil {
t.Fatalf("expected connection to be closed after disabled AUTH error")
} }
redisqueue.Enqueue([]byte("a")) if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
redisqueue.Enqueue([]byte("b")) t.Fatalf("expected connection to be closed after disabled AUTH error, got timeout: %v", errRead)
redisqueue.Enqueue([]byte("c"))
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write RPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read RPOP response: %v", err)
} else if string(item) != "a" {
t.Fatalf("unexpected RPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read LPOP response: %v", err)
} else if string(item) != "b" {
t.Fatalf("unexpected LPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
t.Fatalf("failed to write RPOP count command: %v", errWrite)
}
items, errItems := readRESPArrayOfBulkStrings(reader)
if errItems != nil {
t.Fatalf("failed to read RPOP count response: %v", errItems)
}
if len(items) != 1 || string(items[0]) != "c" {
t.Fatalf("unexpected RPOP count items: %#v", items)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
}
item, errItem := readTestRESPBulkString(reader)
if errItem != nil {
t.Fatalf("failed to read LPOP empty response: %v", errItem)
}
if item != nil {
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
}
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
if errEmpty != nil {
t.Fatalf("failed to read RPOP empty count response: %v", errEmpty)
}
if len(emptyItems) != 0 {
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
}
}
func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second)
if errDialFirst != nil {
t.Fatalf("failed to dial first redis listener: %v", errDialFirst)
}
t.Cleanup(func() { _ = firstConn.Close() })
firstReader := bufio.NewReader(firstConn)
_ = firstConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write first AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(firstReader); err != nil {
t.Fatalf("failed to read first AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected first AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil {
t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite)
}
if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil {
t.Fatalf("failed to read first SUBSCRIBE response: %v", err)
} else if channel != "usage" || count != 1 {
t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count)
}
secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second)
if errDialSecond != nil {
t.Fatalf("failed to dial second redis listener: %v", errDialSecond)
}
t.Cleanup(func() { _ = secondConn.Close() })
secondReader := bufio.NewReader(secondConn)
_ = secondConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write second AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(secondReader); err != nil {
t.Fatalf("failed to read second AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected second AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil {
t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite)
}
if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil {
t.Fatalf("failed to read second SUBSCRIBE response: %v", err)
} else if channel != "usage" || count != 1 {
t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count)
}
redisqueue.Enqueue([]byte(`{"id":1}`))
if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil {
t.Fatalf("failed to read first pubsub message: %v", err)
} else if channel != "usage" || string(payload) != `{"id":1}` {
t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload))
}
if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil {
t.Fatalf("failed to read second pubsub message: %v", err)
} else if channel != "usage" || string(payload) != `{"id":1}` {
t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload))
}
popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second)
if errDialPop != nil {
t.Fatalf("failed to dial pop redis listener: %v", errDialPop)
}
t.Cleanup(func() { _ = popConn.Close() })
popReader := bufio.NewReader(popConn)
_ = popConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write pop AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(popReader); err != nil {
t.Fatalf("failed to read pop AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected pop AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil {
t.Fatalf("failed to write pop LPOP command: %v", errWrite)
}
item, errItem := readTestRESPBulkString(popReader)
if errItem != nil {
t.Fatalf("failed to read pop LPOP response: %v", errItem)
}
if item != nil {
t.Fatalf("expected subscribed usage to skip queue, got %q", string(item))
}
managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil)
managementReq.Header.Set("Authorization", "Bearer "+managementPassword)
managementRR := httptest.NewRecorder()
server.engine.ServeHTTP(managementRR, managementReq)
if managementRR.Code != http.StatusOK {
t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String())
}
var managementPayload []json.RawMessage
if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil {
t.Fatalf("unmarshal management usage response: %v", errUnmarshal)
}
if len(managementPayload) != 0 {
t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String())
}
}
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
clientConn, serverConn := net.Pipe()
t.Cleanup(func() { _ = clientConn.Close() })
t.Cleanup(func() { _ = serverConn.Close() })
fakeRemote := &net.TCPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
reader := bufio.NewReader(clientConn)
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
} else if msg != "NOAUTH Authentication required." {
t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg)
}
}
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command after failures: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read LPOP banned error: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected LPOP banned error: %q", msg)
}
}
func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
clientConn, serverConn := net.Pipe()
t.Cleanup(func() { _ = clientConn.Close() })
t.Cleanup(func() { _ = serverConn.Close() })
fakeRemote := &net.TCPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
reader := bufio.NewReader(clientConn)
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
}
}
for i := 0; i < 2; i++ {
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command after failures: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg)
}
}
if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
}
}
func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
if errDial != nil {
t.Fatalf("failed to dial redis listener: %v", errDial)
}
t.Cleanup(func() { _ = conn.Close() })
reader := bufio.NewReader(conn)
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
}
}
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
} }
} }
+1
View File
@@ -6,6 +6,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
+4 -4
View File
@@ -37,8 +37,8 @@ type Config struct {
// TLS config controls HTTPS server settings. // TLS config controls HTTPS server settings.
TLS TLSConfig `yaml:"tls" json:"tls"` TLS TLSConfig `yaml:"tls" json:"tls"`
// Home config enables the Redis-based control plane integration. // Home config is runtime-only and is populated from -home-jwt.
Home HomeConfig `yaml:"home" json:"-"` Home HomeConfig `yaml:"-" json:"-"`
// RemoteManagement nests management-related options under 'remote-management'. // RemoteManagement nests management-related options under 'remote-management'.
RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"` RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"`
@@ -69,8 +69,8 @@ type Config struct {
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded. // UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"` UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items // RedisUsageQueueRetentionSeconds controls how long usage queue items are retained
// are retained in memory for the Redis RESP interface (LPOP/RPOP). // in memory for Management API consumers.
// Default: 60. Max: 3600. // Default: 60. Max: 3600.
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"` RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
+1 -2
View File
@@ -1,11 +1,10 @@
package config package config
// HomeConfig configures the optional "home" control plane integration over Redis protocol. // HomeConfig stores runtime-only Home control plane settings from -home-jwt.
type HomeConfig struct { type HomeConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
Host string `yaml:"host" json:"-"` Host string `yaml:"host" json:"-"`
Port int `yaml:"port" json:"-"` Port int `yaml:"port" json:"-"`
Password string `yaml:"password" json:"-"`
DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"` DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"`
TLS HomeTLSConfig `yaml:"tls" json:"-"` TLS HomeTLSConfig `yaml:"tls" json:"-"`
} }
+17 -21
View File
@@ -2,13 +2,12 @@ package config
import "testing" import "testing"
func TestParseConfigBytesHomeTLS(t *testing.T) { func TestParseConfigBytesIgnoresHomeConfig(t *testing.T) {
cfg, err := ParseConfigBytes([]byte(` cfg, err := ParseConfigBytes([]byte(`
home: home:
enabled: true enabled: true
host: home.example.com host: home.example.com
port: 444 port: 444
password: secret
disable-cluster-discovery: true disable-cluster-discovery: true
tls: tls:
enable: true enable: true
@@ -20,31 +19,28 @@ home:
t.Fatalf("ParseConfigBytes() error = %v", err) t.Fatalf("ParseConfigBytes() error = %v", err)
} }
if !cfg.Home.Enabled { if cfg.Home.Enabled {
t.Fatal("Home.Enabled = false, want true") t.Fatal("Home.Enabled = true, want false")
} }
if cfg.Home.Host != "home.example.com" { if cfg.Home.Host != "" {
t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host) t.Fatalf("Home.Host = %q, want empty", cfg.Home.Host)
} }
if cfg.Home.Port != 444 { if cfg.Home.Port != 0 {
t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port) t.Fatalf("Home.Port = %d, want 0", cfg.Home.Port)
} }
if cfg.Home.Password != "secret" { if cfg.Home.DisableClusterDiscovery {
t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password) t.Fatal("Home.DisableClusterDiscovery = true, want false")
} }
if !cfg.Home.DisableClusterDiscovery { if cfg.Home.TLS.Enable {
t.Fatal("Home.DisableClusterDiscovery = false, want true") t.Fatal("Home.TLS.Enable = true, want false")
} }
if !cfg.Home.TLS.Enable { if cfg.Home.TLS.ServerName != "" {
t.Fatal("Home.TLS.Enable = false, want true") t.Fatalf("Home.TLS.ServerName = %q, want empty", cfg.Home.TLS.ServerName)
} }
if cfg.Home.TLS.ServerName != "home.example.com" { if cfg.Home.TLS.CACert != "" {
t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName) t.Fatalf("Home.TLS.CACert = %q, want empty", cfg.Home.TLS.CACert)
} }
if cfg.Home.TLS.CACert != "C:/certs/ca.pem" { if cfg.Home.TLS.InsecureSkipVerify {
t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert) t.Fatal("Home.TLS.InsecureSkipVerify = true, want false")
}
if !cfg.Home.TLS.InsecureSkipVerify {
t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
} }
} }
-1
View File
@@ -180,7 +180,6 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
} }
return &redis.Options{ return &redis.Options{
Addr: addr, Addr: addr,
Password: c.homeCfg.Password,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
DialTimeout: homeRedisOperationTimeout, DialTimeout: homeRedisOperationTimeout,
ReadTimeout: homeRedisOperationTimeout, ReadTimeout: homeRedisOperationTimeout,
+5 -6
View File
@@ -37,10 +37,9 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {
func TestRedisOptionsHomeTLSDisabled(t *testing.T) { func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
client := New(config.HomeConfig{ client := New(config.HomeConfig{
Enabled: true, Enabled: true,
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 6379, Port: 6379,
Password: "secret",
}) })
client.mu.Lock() client.mu.Lock()
@@ -53,8 +52,8 @@ func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
if options.TLSConfig != nil { if options.TLSConfig != nil {
t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig) t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
} }
if options.Password != "secret" { if options.Password != "" {
t.Fatalf("Password = %q, want secret", options.Password) t.Fatalf("Password = %q, want empty", options.Password)
} }
} }