Merge pull request #3430 from router-for-me/home

Implement Redis integration with TLS support and cluster discovery option
This commit is contained in:
Luis Pater
2026-05-16 20:44:20 +08:00
committed by GitHub
8 changed files with 511 additions and 34 deletions
+125
View File
@@ -0,0 +1,125 @@
package main
import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
)
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
rawAddr = strings.TrimSpace(rawAddr)
if rawAddr == "" {
return config.HomeConfig{}, fmt.Errorf("address is empty")
}
if strings.Contains(rawAddr, "://") {
return parseHomeURLConfig(rawAddr, password)
}
host, portStr, errSplit := net.SplitHostPort(rawAddr)
if errSplit != nil {
return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
}
host = strings.TrimSpace(host)
if host == "" {
return config.HomeConfig{}, fmt.Errorf("host is empty")
}
port, errPort := parseHomePort(portStr)
if errPort != nil {
return config.HomeConfig{}, errPort
}
return config.HomeConfig{
Enabled: true,
Host: host,
Port: port,
Password: password,
}, nil
}
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
parsed, errParse := url.Parse(rawAddr)
if errParse != nil {
return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse)
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "redis" && scheme != "rediss" {
return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme)
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return config.HomeConfig{}, fmt.Errorf("host is empty")
}
port, errPort := parseHomePort(parsed.Port())
if errPort != nil {
return config.HomeConfig{}, errPort
}
if password == "" && parsed.User != nil {
if urlPassword, ok := parsed.User.Password(); ok {
password = urlPassword
}
}
homeCfg := config.HomeConfig{
Enabled: true,
Host: host,
Port: port,
Password: password,
}
query := parsed.Query()
homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery")
if scheme == "rediss" {
homeCfg.TLS.Enable = true
homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name"))
homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert"))
}
return homeCfg, nil
}
func parseHomePort(rawPort string) (int, error) {
rawPort = strings.TrimSpace(rawPort)
if rawPort == "" {
return 0, fmt.Errorf("port is empty")
}
port, errPort := strconv.Atoi(rawPort)
if errPort != nil || port <= 0 || port > 65535 {
return 0, fmt.Errorf("invalid port %q", rawPort)
}
return port, nil
}
func firstHomeQueryValue(values url.Values, keys ...string) string {
for _, key := range keys {
if value := values.Get(key); value != "" {
return value
}
}
return ""
}
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
for _, key := range keys {
value := strings.TrimSpace(values.Get(key))
if value == "" {
continue
}
parsed, errParse := strconv.ParseBool(value)
return errParse == nil && parsed
}
return false
}
+77
View File
@@ -0,0 +1,77 @@
package main
import "testing"
func TestParseHomeFlagConfigHostPort(t *testing.T) {
cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if !cfg.Enabled {
t.Fatal("Enabled = false, want true")
}
if cfg.Host != "home.example.com" {
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
}
if cfg.Port != 8327 {
t.Fatalf("Port = %d, want 8327", cfg.Port)
}
if cfg.Password != "secret" {
t.Fatalf("Password = %q, want secret", cfg.Password)
}
if cfg.TLS.Enable {
t.Fatal("TLS.Enable = true, want false")
}
}
func TestParseHomeFlagConfigRediss(t *testing.T) {
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if cfg.Host != "home.example.com" {
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
}
if cfg.Port != 444 {
t.Fatalf("Port = %d, want 444", cfg.Port)
}
if cfg.Password != "url-secret" {
t.Fatalf("Password = %q, want url-secret", cfg.Password)
}
if !cfg.TLS.Enable {
t.Fatal("TLS.Enable = false, want true")
}
if cfg.TLS.ServerName != "home.example.com" {
t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName)
}
if !cfg.TLS.InsecureSkipVerify {
t.Fatal("TLS.InsecureSkipVerify = false, want true")
}
if cfg.TLS.CACert != "C:/certs/ca.pem" {
t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert)
}
}
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if cfg.Password != "flag-secret" {
t.Fatalf("Password = %q, want flag-secret", cfg.Password)
}
}
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "")
if err != nil {
t.Fatalf("parseHomeFlagConfig() error = %v", err)
}
if !cfg.DisableClusterDiscovery {
t.Fatal("DisableClusterDiscovery = false, want true")
}
}
+8 -22
View File
@@ -10,11 +10,9 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time" "time"
@@ -75,6 +73,7 @@ func main() {
var password string var password string
var homeAddr string var homeAddr string
var homePassword string var homePassword string
var homeDisableClusterDiscovery bool
var tuiMode bool var tuiMode bool
var standalone bool var standalone bool
var localModel bool var localModel bool
@@ -93,8 +92,9 @@ func main() {
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)") flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
flag.StringVar(&password, "password", "", "") flag.StringVar(&password, "password", "", "")
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port format (loads config from home and skips local config file)") flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching") flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
@@ -247,27 +247,13 @@ func main() {
if strings.TrimSpace(homeAddr) != "" { if strings.TrimSpace(homeAddr) != "" {
configLoadedFromHome = true configLoadedFromHome = true
trimmedHomePassword := strings.TrimSpace(homePassword) trimmedHomePassword := strings.TrimSpace(homePassword)
host, portStr, errSplit := net.SplitHostPort(strings.TrimSpace(homeAddr)) homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
if errSplit != nil { if errHomeCfg != nil {
log.Errorf("invalid -home address %q (expected host:port): %v", homeAddr, errSplit) log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
return return
} }
host = strings.TrimSpace(host) if homeDisableClusterDiscovery {
if host == "" { homeCfg.DisableClusterDiscovery = true
log.Errorf("invalid -home address %q: host is empty", homeAddr)
return
}
port, errPort := strconv.Atoi(strings.TrimSpace(portStr))
if errPort != nil || port <= 0 {
log.Errorf("invalid -home address %q: invalid port %q", homeAddr, portStr)
return
}
homeCfg := config.HomeConfig{
Enabled: true,
Host: host,
Port: port,
Password: trimmedHomePassword,
} }
homeClient := home.New(homeCfg) homeClient := home.New(homeCfg)
defer homeClient.Close() defer homeClient.Close()
+13
View File
@@ -17,6 +17,19 @@ home:
host: "127.0.0.1" host: "127.0.0.1"
port: 6379 port: 6379
password: "" password: ""
# Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
# Useful when Home is behind NAT, Docker networking, or a reverse proxy.
disable-cluster-discovery: false
# Optional TLS for the outbound Redis connection to the home control plane.
# Enable this when connecting through rediss:// or an SSL stream proxy.
tls:
enable: false
# Optional SNI/certificate name override. Leave empty to use the configured home host.
server-name: ""
# Trust a private CA bundle in addition to system roots.
ca-cert: ""
# Only for testing self-signed endpoints; disables certificate verification.
insecure-skip-verify: false
# Management API settings # Management API settings
remote-management: remote-management:
+14 -4
View File
@@ -2,8 +2,18 @@ package config
// HomeConfig configures the optional "home" control plane integration over Redis protocol. // HomeConfig configures the optional "home" control plane integration over Redis protocol.
type HomeConfig struct { type HomeConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
Host string `yaml:"host" json:"-"` Host string `yaml:"host" json:"-"`
Port int `yaml:"port" json:"-"` Port int `yaml:"port" json:"-"`
Password string `yaml:"password" json:"-"` Password string `yaml:"password" json:"-"`
DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"`
TLS HomeTLSConfig `yaml:"tls" json:"-"`
}
// HomeTLSConfig configures client-side TLS for the home Redis connection.
type HomeTLSConfig struct {
Enable bool `yaml:"enable" json:"-"`
ServerName string `yaml:"server-name" json:"-"`
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
CACert string `yaml:"ca-cert" json:"-"`
} }
+50
View File
@@ -0,0 +1,50 @@
package config
import "testing"
func TestParseConfigBytesHomeTLS(t *testing.T) {
cfg, err := ParseConfigBytes([]byte(`
home:
enabled: true
host: home.example.com
port: 444
password: secret
disable-cluster-discovery: true
tls:
enable: true
server-name: home.example.com
ca-cert: C:/certs/ca.pem
insecure-skip-verify: true
`))
if err != nil {
t.Fatalf("ParseConfigBytes() error = %v", err)
}
if !cfg.Home.Enabled {
t.Fatal("Home.Enabled = false, want true")
}
if cfg.Home.Host != "home.example.com" {
t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host)
}
if cfg.Home.Port != 444 {
t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port)
}
if cfg.Home.Password != "secret" {
t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password)
}
if !cfg.Home.DisableClusterDiscovery {
t.Fatal("Home.DisableClusterDiscovery = false, want true")
}
if !cfg.Home.TLS.Enable {
t.Fatal("Home.TLS.Enable = false, want true")
}
if cfg.Home.TLS.ServerName != "home.example.com" {
t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName)
}
if cfg.Home.TLS.CACert != "C:/certs/ca.pem" {
t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert)
}
if !cfg.Home.TLS.InsecureSkipVerify {
t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
}
}
+97 -8
View File
@@ -2,11 +2,14 @@ package home
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -151,20 +154,83 @@ func (c *Client) ensureClients() error {
} }
if c.cmd == nil { if c.cmd == nil {
c.cmd = redis.NewClient(&redis.Options{ options, errOptions := c.redisOptionsLocked(addr)
Addr: addr, if errOptions != nil {
Password: c.homeCfg.Password, return errOptions
}) }
c.cmd = redis.NewClient(options)
} }
if c.sub == nil { if c.sub == nil {
c.sub = redis.NewClient(&redis.Options{ options, errOptions := c.redisOptionsLocked(addr)
Addr: addr, if errOptions != nil {
Password: c.homeCfg.Password, return errOptions
}) }
c.sub = redis.NewClient(options)
} }
return nil return nil
} }
func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
tlsConfig, errTLS := c.homeTLSConfigLocked()
if errTLS != nil {
return nil, errTLS
}
return &redis.Options{
Addr: addr,
Password: c.homeCfg.Password,
TLSConfig: tlsConfig,
}, nil
}
func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
if serverName == "" {
serverName = strings.TrimSpace(c.seedHost)
}
if serverName == "" {
serverName = strings.TrimSpace(c.homeCfg.Host)
}
return newHomeTLSConfig(c.homeCfg.TLS, serverName)
}
func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
if !cfg.Enable {
return nil, nil
}
serverName := strings.TrimSpace(cfg.ServerName)
if serverName == "" {
serverName = strings.TrimSpace(fallbackServerName)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: serverName,
InsecureSkipVerify: cfg.InsecureSkipVerify,
}
caCertPath := strings.TrimSpace(cfg.CACert)
if caCertPath == "" {
return tlsConfig, nil
}
caCertPEM, errRead := os.ReadFile(caCertPath)
if errRead != nil {
return nil, fmt.Errorf("home tls: read ca-cert: %w", errRead)
}
certPool, errPool := x509.SystemCertPool()
if errPool != nil || certPool == nil {
certPool = x509.NewCertPool()
}
if !certPool.AppendCertsFromPEM(caCertPEM) {
return nil, fmt.Errorf("home tls: ca-cert contains no PEM certificates")
}
tlsConfig.RootCAs = certPool
return tlsConfig, nil
}
func (c *Client) commandClient() (*redis.Client, error) { func (c *Client) commandClient() (*redis.Client, error) {
if errEnsure := c.ensureClients(); errEnsure != nil { if errEnsure := c.ensureClients(); errEnsure != nil {
return nil, errEnsure return nil, errEnsure
@@ -199,7 +265,23 @@ func (c *Client) Ping(ctx context.Context) error {
return cmd.Ping(ctx).Err() return cmd.Ping(ctx).Err()
} }
func (c *Client) clusterDiscoveryEnabled() bool {
if c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
return c.clusterDiscoveryEnabledLocked()
}
func (c *Client) clusterDiscoveryEnabledLocked() bool {
return !c.homeCfg.DisableClusterDiscovery
}
func (c *Client) refreshBestClusterNode(ctx context.Context) { func (c *Client) refreshBestClusterNode(ctx context.Context) {
if !c.clusterDiscoveryEnabled() {
return
}
switched, errRefresh := c.refreshClusterNodes(ctx) switched, errRefresh := c.refreshClusterNodes(ctx)
if errRefresh != nil { if errRefresh != nil {
log.Debugf("home cluster nodes unavailable: %v", errRefresh) log.Debugf("home cluster nodes unavailable: %v", errRefresh)
@@ -213,6 +295,9 @@ func (c *Client) refreshBestClusterNode(ctx context.Context) {
} }
func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) { func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) {
if !c.clusterDiscoveryEnabled() {
return false, nil
}
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@@ -287,6 +372,10 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if !c.clusterDiscoveryEnabledLocked() {
c.reconnectFailures = 0
return false, ""
}
c.reconnectFailures++ c.reconnectFailures++
if c.reconnectFailures < homeReconnectFailoverThreshold { if c.reconnectFailures < homeReconnectFailoverThreshold {
return false, "" return false, ""
+127
View File
@@ -1,9 +1,13 @@
package home package home
import ( import (
"context"
"crypto/tls"
"encoding/json" "encoding/json"
"net/http" "net/http"
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
) )
func TestAuthDispatchRequestIncludesCount(t *testing.T) { func TestAuthDispatchRequestIncludesCount(t *testing.T) {
@@ -30,3 +34,126 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {
t.Fatalf("count = %d, want 1", req.Count) t.Fatalf("count = %d, want 1", req.Count)
} }
} }
func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "127.0.0.1",
Port: 6379,
Password: "secret",
})
client.mu.Lock()
options, err := client.redisOptionsLocked("127.0.0.1:6379")
client.mu.Unlock()
if err != nil {
t.Fatalf("redisOptionsLocked() error = %v", err)
}
if options.TLSConfig != nil {
t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
}
if options.Password != "secret" {
t.Fatalf("Password = %q, want secret", options.Password)
}
}
func TestRedisOptionsHomeTLSEnabledUsesSeedHostAsServerName(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "home.example.com",
Port: 444,
TLS: config.HomeTLSConfig{
Enable: true,
},
})
client.homeCfg.Host = "127.0.0.1"
client.mu.Lock()
options, err := client.redisOptionsLocked("127.0.0.1:444")
client.mu.Unlock()
if err != nil {
t.Fatalf("redisOptionsLocked() error = %v", err)
}
if options.TLSConfig == nil {
t.Fatal("TLSConfig is nil")
}
if options.TLSConfig.ServerName != "home.example.com" {
t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName)
}
if options.TLSConfig.MinVersion != tls.VersionTLS12 {
t.Fatalf("MinVersion = %d, want TLS 1.2", options.TLSConfig.MinVersion)
}
}
func TestRedisOptionsHomeTLSEnabledUsesExplicitServerName(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "127.0.0.1",
Port: 444,
TLS: config.HomeTLSConfig{
Enable: true,
ServerName: "home.example.com",
InsecureSkipVerify: true,
},
})
client.mu.Lock()
options, err := client.redisOptionsLocked("127.0.0.1:444")
client.mu.Unlock()
if err != nil {
t.Fatalf("redisOptionsLocked() error = %v", err)
}
if options.TLSConfig == nil {
t.Fatal("TLSConfig is nil")
}
if options.TLSConfig.ServerName != "home.example.com" {
t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName)
}
if !options.TLSConfig.InsecureSkipVerify {
t.Fatal("InsecureSkipVerify = false, want true")
}
}
func TestRefreshClusterNodesDisabledSkipsRedisCommand(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "127.0.0.1",
Port: 1,
DisableClusterDiscovery: true,
})
switched, err := client.refreshClusterNodes(context.Background())
if err != nil {
t.Fatalf("refreshClusterNodes() error = %v", err)
}
if switched {
t.Fatal("refreshClusterNodes() switched = true, want false")
}
if client.cmd != nil || client.sub != nil {
t.Fatalf("redis clients were initialized when cluster discovery was disabled")
}
}
func TestFailoverAfterReconnectFailureDisabledDoesNotSwitchToClusterNode(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "seed.example.com",
Port: 8327,
DisableClusterDiscovery: true,
})
client.mu.Lock()
client.clusterNodes = []clusterNode{{IP: "other.example.com", Port: 8327}}
client.reconnectFailures = homeReconnectFailoverThreshold - 1
client.mu.Unlock()
switched, addr := client.failoverAfterReconnectFailure()
if switched {
t.Fatalf("failoverAfterReconnectFailure() switched to %s, want no switch", addr)
}
if got, _ := client.addr(); got != "seed.example.com:8327" {
t.Fatalf("addr() = %q, want seed.example.com:8327", got)
}
}