Merge pull request #3430 from router-for-me/home
Implement Redis integration with TLS support and cluster discovery option
This commit is contained in:
@@ -0,0 +1,125 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||||
|
rawAddr = strings.TrimSpace(rawAddr)
|
||||||
|
if rawAddr == "" {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("address is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(rawAddr, "://") {
|
||||||
|
return parseHomeURLConfig(rawAddr, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
host, portStr, errSplit := net.SplitHostPort(rawAddr)
|
||||||
|
if errSplit != nil {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
|
||||||
|
}
|
||||||
|
|
||||||
|
host = strings.TrimSpace(host)
|
||||||
|
if host == "" {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
port, errPort := parseHomePort(portStr)
|
||||||
|
if errPort != nil {
|
||||||
|
return config.HomeConfig{}, errPort
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
Password: password,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||||
|
parsed, errParse := url.Parse(rawAddr)
|
||||||
|
if errParse != nil {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse)
|
||||||
|
}
|
||||||
|
|
||||||
|
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||||
|
if scheme != "redis" && scheme != "rediss" {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := strings.TrimSpace(parsed.Hostname())
|
||||||
|
if host == "" {
|
||||||
|
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
port, errPort := parseHomePort(parsed.Port())
|
||||||
|
if errPort != nil {
|
||||||
|
return config.HomeConfig{}, errPort
|
||||||
|
}
|
||||||
|
|
||||||
|
if password == "" && parsed.User != nil {
|
||||||
|
if urlPassword, ok := parsed.User.Password(); ok {
|
||||||
|
password = urlPassword
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
homeCfg := config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
query := parsed.Query()
|
||||||
|
homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery")
|
||||||
|
|
||||||
|
if scheme == "rediss" {
|
||||||
|
homeCfg.TLS.Enable = true
|
||||||
|
homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name"))
|
||||||
|
homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
|
||||||
|
homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return homeCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHomePort(rawPort string) (int, error) {
|
||||||
|
rawPort = strings.TrimSpace(rawPort)
|
||||||
|
if rawPort == "" {
|
||||||
|
return 0, fmt.Errorf("port is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
port, errPort := strconv.Atoi(rawPort)
|
||||||
|
if errPort != nil || port <= 0 || port > 65535 {
|
||||||
|
return 0, fmt.Errorf("invalid port %q", rawPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstHomeQueryValue(values url.Values, keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
if value := values.Get(key); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
|
||||||
|
for _, key := range keys {
|
||||||
|
value := strings.TrimSpace(values.Get(key))
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsed, errParse := strconv.ParseBool(value)
|
||||||
|
return errParse == nil && parsed
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseHomeFlagConfigHostPort(t *testing.T) {
|
||||||
|
cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Enabled {
|
||||||
|
t.Fatal("Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Host != "home.example.com" {
|
||||||
|
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||||
|
}
|
||||||
|
if cfg.Port != 8327 {
|
||||||
|
t.Fatalf("Port = %d, want 8327", cfg.Port)
|
||||||
|
}
|
||||||
|
if cfg.Password != "secret" {
|
||||||
|
t.Fatalf("Password = %q, want secret", cfg.Password)
|
||||||
|
}
|
||||||
|
if cfg.TLS.Enable {
|
||||||
|
t.Fatal("TLS.Enable = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHomeFlagConfigRediss(t *testing.T) {
|
||||||
|
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Host != "home.example.com" {
|
||||||
|
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||||
|
}
|
||||||
|
if cfg.Port != 444 {
|
||||||
|
t.Fatalf("Port = %d, want 444", cfg.Port)
|
||||||
|
}
|
||||||
|
if cfg.Password != "url-secret" {
|
||||||
|
t.Fatalf("Password = %q, want url-secret", cfg.Password)
|
||||||
|
}
|
||||||
|
if !cfg.TLS.Enable {
|
||||||
|
t.Fatal("TLS.Enable = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.TLS.ServerName != "home.example.com" {
|
||||||
|
t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName)
|
||||||
|
}
|
||||||
|
if !cfg.TLS.InsecureSkipVerify {
|
||||||
|
t.Fatal("TLS.InsecureSkipVerify = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.TLS.CACert != "C:/certs/ca.pem" {
|
||||||
|
t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
|
||||||
|
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Password != "flag-secret" {
|
||||||
|
t.Fatalf("Password = %q, want flag-secret", cfg.Password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
|
||||||
|
cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.DisableClusterDiscovery {
|
||||||
|
t.Fatal("DisableClusterDiscovery = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
+8
-22
@@ -10,11 +10,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -75,6 +73,7 @@ func main() {
|
|||||||
var password string
|
var password string
|
||||||
var homeAddr string
|
var homeAddr string
|
||||||
var homePassword string
|
var homePassword string
|
||||||
|
var homeDisableClusterDiscovery bool
|
||||||
var tuiMode bool
|
var tuiMode bool
|
||||||
var standalone bool
|
var standalone bool
|
||||||
var localModel bool
|
var localModel bool
|
||||||
@@ -93,8 +92,9 @@ func main() {
|
|||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port format (loads config from home and skips local config file)")
|
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
|
||||||
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
|
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
|
||||||
|
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
|
||||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||||
@@ -247,27 +247,13 @@ func main() {
|
|||||||
if strings.TrimSpace(homeAddr) != "" {
|
if strings.TrimSpace(homeAddr) != "" {
|
||||||
configLoadedFromHome = true
|
configLoadedFromHome = true
|
||||||
trimmedHomePassword := strings.TrimSpace(homePassword)
|
trimmedHomePassword := strings.TrimSpace(homePassword)
|
||||||
host, portStr, errSplit := net.SplitHostPort(strings.TrimSpace(homeAddr))
|
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
|
||||||
if errSplit != nil {
|
if errHomeCfg != nil {
|
||||||
log.Errorf("invalid -home address %q (expected host:port): %v", homeAddr, errSplit)
|
log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
host = strings.TrimSpace(host)
|
if homeDisableClusterDiscovery {
|
||||||
if host == "" {
|
homeCfg.DisableClusterDiscovery = true
|
||||||
log.Errorf("invalid -home address %q: host is empty", homeAddr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
port, errPort := strconv.Atoi(strings.TrimSpace(portStr))
|
|
||||||
if errPort != nil || port <= 0 {
|
|
||||||
log.Errorf("invalid -home address %q: invalid port %q", homeAddr, portStr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
homeCfg := config.HomeConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Host: host,
|
|
||||||
Port: port,
|
|
||||||
Password: trimmedHomePassword,
|
|
||||||
}
|
}
|
||||||
homeClient := home.New(homeCfg)
|
homeClient := home.New(homeCfg)
|
||||||
defer homeClient.Close()
|
defer homeClient.Close()
|
||||||
|
|||||||
@@ -17,6 +17,19 @@ home:
|
|||||||
host: "127.0.0.1"
|
host: "127.0.0.1"
|
||||||
port: 6379
|
port: 6379
|
||||||
password: ""
|
password: ""
|
||||||
|
# Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
|
||||||
|
# Useful when Home is behind NAT, Docker networking, or a reverse proxy.
|
||||||
|
disable-cluster-discovery: false
|
||||||
|
# Optional TLS for the outbound Redis connection to the home control plane.
|
||||||
|
# Enable this when connecting through rediss:// or an SSL stream proxy.
|
||||||
|
tls:
|
||||||
|
enable: false
|
||||||
|
# Optional SNI/certificate name override. Leave empty to use the configured home host.
|
||||||
|
server-name: ""
|
||||||
|
# Trust a private CA bundle in addition to system roots.
|
||||||
|
ca-cert: ""
|
||||||
|
# Only for testing self-signed endpoints; disables certificate verification.
|
||||||
|
insecure-skip-verify: false
|
||||||
|
|
||||||
# Management API settings
|
# Management API settings
|
||||||
remote-management:
|
remote-management:
|
||||||
|
|||||||
+14
-4
@@ -2,8 +2,18 @@ package config
|
|||||||
|
|
||||||
// HomeConfig configures the optional "home" control plane integration over Redis protocol.
|
// HomeConfig configures the optional "home" control plane integration over Redis protocol.
|
||||||
type HomeConfig struct {
|
type HomeConfig struct {
|
||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
Host string `yaml:"host" json:"-"`
|
Host string `yaml:"host" json:"-"`
|
||||||
Port int `yaml:"port" json:"-"`
|
Port int `yaml:"port" json:"-"`
|
||||||
Password string `yaml:"password" json:"-"`
|
Password string `yaml:"password" json:"-"`
|
||||||
|
DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"`
|
||||||
|
TLS HomeTLSConfig `yaml:"tls" json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HomeTLSConfig configures client-side TLS for the home Redis connection.
|
||||||
|
type HomeTLSConfig struct {
|
||||||
|
Enable bool `yaml:"enable" json:"-"`
|
||||||
|
ServerName string `yaml:"server-name" json:"-"`
|
||||||
|
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
|
||||||
|
CACert string `yaml:"ca-cert" json:"-"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseConfigBytesHomeTLS(t *testing.T) {
|
||||||
|
cfg, err := ParseConfigBytes([]byte(`
|
||||||
|
home:
|
||||||
|
enabled: true
|
||||||
|
host: home.example.com
|
||||||
|
port: 444
|
||||||
|
password: secret
|
||||||
|
disable-cluster-discovery: true
|
||||||
|
tls:
|
||||||
|
enable: true
|
||||||
|
server-name: home.example.com
|
||||||
|
ca-cert: C:/certs/ca.pem
|
||||||
|
insecure-skip-verify: true
|
||||||
|
`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseConfigBytes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Home.Enabled {
|
||||||
|
t.Fatal("Home.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Home.Host != "home.example.com" {
|
||||||
|
t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host)
|
||||||
|
}
|
||||||
|
if cfg.Home.Port != 444 {
|
||||||
|
t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port)
|
||||||
|
}
|
||||||
|
if cfg.Home.Password != "secret" {
|
||||||
|
t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password)
|
||||||
|
}
|
||||||
|
if !cfg.Home.DisableClusterDiscovery {
|
||||||
|
t.Fatal("Home.DisableClusterDiscovery = false, want true")
|
||||||
|
}
|
||||||
|
if !cfg.Home.TLS.Enable {
|
||||||
|
t.Fatal("Home.TLS.Enable = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Home.TLS.ServerName != "home.example.com" {
|
||||||
|
t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName)
|
||||||
|
}
|
||||||
|
if cfg.Home.TLS.CACert != "C:/certs/ca.pem" {
|
||||||
|
t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert)
|
||||||
|
}
|
||||||
|
if !cfg.Home.TLS.InsecureSkipVerify {
|
||||||
|
t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
+97
-8
@@ -2,11 +2,14 @@ package home
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -151,20 +154,83 @@ func (c *Client) ensureClients() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.cmd == nil {
|
if c.cmd == nil {
|
||||||
c.cmd = redis.NewClient(&redis.Options{
|
options, errOptions := c.redisOptionsLocked(addr)
|
||||||
Addr: addr,
|
if errOptions != nil {
|
||||||
Password: c.homeCfg.Password,
|
return errOptions
|
||||||
})
|
}
|
||||||
|
c.cmd = redis.NewClient(options)
|
||||||
}
|
}
|
||||||
if c.sub == nil {
|
if c.sub == nil {
|
||||||
c.sub = redis.NewClient(&redis.Options{
|
options, errOptions := c.redisOptionsLocked(addr)
|
||||||
Addr: addr,
|
if errOptions != nil {
|
||||||
Password: c.homeCfg.Password,
|
return errOptions
|
||||||
})
|
}
|
||||||
|
c.sub = redis.NewClient(options)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
|
||||||
|
tlsConfig, errTLS := c.homeTLSConfigLocked()
|
||||||
|
if errTLS != nil {
|
||||||
|
return nil, errTLS
|
||||||
|
}
|
||||||
|
return &redis.Options{
|
||||||
|
Addr: addr,
|
||||||
|
Password: c.homeCfg.Password,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
|
||||||
|
serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = strings.TrimSpace(c.seedHost)
|
||||||
|
}
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = strings.TrimSpace(c.homeCfg.Host)
|
||||||
|
}
|
||||||
|
return newHomeTLSConfig(c.homeCfg.TLS, serverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
|
||||||
|
if !cfg.Enable {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serverName := strings.TrimSpace(cfg.ServerName)
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = strings.TrimSpace(fallbackServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: serverName,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPath := strings.TrimSpace(cfg.CACert)
|
||||||
|
if caCertPath == "" {
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPEM, errRead := os.ReadFile(caCertPath)
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, fmt.Errorf("home tls: read ca-cert: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPool, errPool := x509.SystemCertPool()
|
||||||
|
if errPool != nil || certPool == nil {
|
||||||
|
certPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
if !certPool.AppendCertsFromPEM(caCertPEM) {
|
||||||
|
return nil, fmt.Errorf("home tls: ca-cert contains no PEM certificates")
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = certPool
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) commandClient() (*redis.Client, error) {
|
func (c *Client) commandClient() (*redis.Client, error) {
|
||||||
if errEnsure := c.ensureClients(); errEnsure != nil {
|
if errEnsure := c.ensureClients(); errEnsure != nil {
|
||||||
return nil, errEnsure
|
return nil, errEnsure
|
||||||
@@ -199,7 +265,23 @@ func (c *Client) Ping(ctx context.Context) error {
|
|||||||
return cmd.Ping(ctx).Err()
|
return cmd.Ping(ctx).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) clusterDiscoveryEnabled() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.clusterDiscoveryEnabledLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) clusterDiscoveryEnabledLocked() bool {
|
||||||
|
return !c.homeCfg.DisableClusterDiscovery
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) refreshBestClusterNode(ctx context.Context) {
|
func (c *Client) refreshBestClusterNode(ctx context.Context) {
|
||||||
|
if !c.clusterDiscoveryEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
switched, errRefresh := c.refreshClusterNodes(ctx)
|
switched, errRefresh := c.refreshClusterNodes(ctx)
|
||||||
if errRefresh != nil {
|
if errRefresh != nil {
|
||||||
log.Debugf("home cluster nodes unavailable: %v", errRefresh)
|
log.Debugf("home cluster nodes unavailable: %v", errRefresh)
|
||||||
@@ -213,6 +295,9 @@ func (c *Client) refreshBestClusterNode(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) {
|
func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) {
|
||||||
|
if !c.clusterDiscoveryEnabled() {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
@@ -287,6 +372,10 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.clusterDiscoveryEnabledLocked() {
|
||||||
|
c.reconnectFailures = 0
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
c.reconnectFailures++
|
c.reconnectFailures++
|
||||||
if c.reconnectFailures < homeReconnectFailoverThreshold {
|
if c.reconnectFailures < homeReconnectFailoverThreshold {
|
||||||
return false, ""
|
return false, ""
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthDispatchRequestIncludesCount(t *testing.T) {
|
func TestAuthDispatchRequestIncludesCount(t *testing.T) {
|
||||||
@@ -30,3 +34,126 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {
|
|||||||
t.Fatalf("count = %d, want 1", req.Count)
|
t.Fatalf("count = %d, want 1", req.Count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
|
||||||
|
client := New(config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 6379,
|
||||||
|
Password: "secret",
|
||||||
|
})
|
||||||
|
|
||||||
|
client.mu.Lock()
|
||||||
|
options, err := client.redisOptionsLocked("127.0.0.1:6379")
|
||||||
|
client.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("redisOptionsLocked() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.TLSConfig != nil {
|
||||||
|
t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
|
||||||
|
}
|
||||||
|
if options.Password != "secret" {
|
||||||
|
t.Fatalf("Password = %q, want secret", options.Password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisOptionsHomeTLSEnabledUsesSeedHostAsServerName(t *testing.T) {
|
||||||
|
client := New(config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: "home.example.com",
|
||||||
|
Port: 444,
|
||||||
|
TLS: config.HomeTLSConfig{
|
||||||
|
Enable: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
client.homeCfg.Host = "127.0.0.1"
|
||||||
|
|
||||||
|
client.mu.Lock()
|
||||||
|
options, err := client.redisOptionsLocked("127.0.0.1:444")
|
||||||
|
client.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("redisOptionsLocked() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.TLSConfig == nil {
|
||||||
|
t.Fatal("TLSConfig is nil")
|
||||||
|
}
|
||||||
|
if options.TLSConfig.ServerName != "home.example.com" {
|
||||||
|
t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
if options.TLSConfig.MinVersion != tls.VersionTLS12 {
|
||||||
|
t.Fatalf("MinVersion = %d, want TLS 1.2", options.TLSConfig.MinVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisOptionsHomeTLSEnabledUsesExplicitServerName(t *testing.T) {
|
||||||
|
client := New(config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 444,
|
||||||
|
TLS: config.HomeTLSConfig{
|
||||||
|
Enable: true,
|
||||||
|
ServerName: "home.example.com",
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
client.mu.Lock()
|
||||||
|
options, err := client.redisOptionsLocked("127.0.0.1:444")
|
||||||
|
client.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("redisOptionsLocked() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.TLSConfig == nil {
|
||||||
|
t.Fatal("TLSConfig is nil")
|
||||||
|
}
|
||||||
|
if options.TLSConfig.ServerName != "home.example.com" {
|
||||||
|
t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
if !options.TLSConfig.InsecureSkipVerify {
|
||||||
|
t.Fatal("InsecureSkipVerify = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshClusterNodesDisabledSkipsRedisCommand(t *testing.T) {
|
||||||
|
client := New(config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 1,
|
||||||
|
DisableClusterDiscovery: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
switched, err := client.refreshClusterNodes(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("refreshClusterNodes() error = %v", err)
|
||||||
|
}
|
||||||
|
if switched {
|
||||||
|
t.Fatal("refreshClusterNodes() switched = true, want false")
|
||||||
|
}
|
||||||
|
if client.cmd != nil || client.sub != nil {
|
||||||
|
t.Fatalf("redis clients were initialized when cluster discovery was disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailoverAfterReconnectFailureDisabledDoesNotSwitchToClusterNode(t *testing.T) {
|
||||||
|
client := New(config.HomeConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Host: "seed.example.com",
|
||||||
|
Port: 8327,
|
||||||
|
DisableClusterDiscovery: true,
|
||||||
|
})
|
||||||
|
client.mu.Lock()
|
||||||
|
client.clusterNodes = []clusterNode{{IP: "other.example.com", Port: 8327}}
|
||||||
|
client.reconnectFailures = homeReconnectFailoverThreshold - 1
|
||||||
|
client.mu.Unlock()
|
||||||
|
|
||||||
|
switched, addr := client.failoverAfterReconnectFailure()
|
||||||
|
if switched {
|
||||||
|
t.Fatalf("failoverAfterReconnectFailure() switched to %s, want no switch", addr)
|
||||||
|
}
|
||||||
|
if got, _ := client.addr(); got != "seed.example.com:8327" {
|
||||||
|
t.Fatalf("addr() = %q, want seed.example.com:8327", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user