feat(api): implement protocol multiplexer and Redis queue for usage integration
- Added `protocol_multiplexer.go`, enabling support for both HTTP and Redis protocols on a single listener. - Introduced `redis_queue_protocol.go` to handle Redis-compatible RESP commands for queue management. - Integrated `redisqueue` package, supporting in-memory queuing with expiration pruning. - Updated server initialization to manage a shared listener and multiplex connections. - Adjusted `Handler` to adopt `AuthenticateManagementKey` for modular key validation, supporting both HTTP and Redis flows.
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.reader == nil {
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *bufferedConn) ConnectionState() tls.ConnectionState {
|
||||
if c == nil || c.Conn == nil {
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
if stater, ok := c.Conn.(interface{ ConnectionState() tls.ConnectionState }); ok {
|
||||
return stater.ConnectionState()
|
||||
}
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
@@ -152,9 +152,6 @@ func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
const maxFailures = 5
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-CPA-VERSION", buildinfo.Version)
|
||||
c.Header("X-CPA-COMMIT", buildinfo.Commit)
|
||||
@@ -162,64 +159,6 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
|
||||
cfg := h.cfg
|
||||
var (
|
||||
allowRemote bool
|
||||
secretHash string
|
||||
)
|
||||
if cfg != nil {
|
||||
allowRemote = cfg.RemoteManagement.AllowRemote
|
||||
secretHash = cfg.RemoteManagement.SecretKey
|
||||
}
|
||||
if h.allowRemoteOverride {
|
||||
allowRemote = true
|
||||
}
|
||||
envSecret := h.envSecret
|
||||
|
||||
fail := func() {}
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
ai := h.failedAttempts[clientIP]
|
||||
if ai != nil {
|
||||
if !ai.blockedUntil.IsZero() {
|
||||
if time.Now().Before(ai.blockedUntil) {
|
||||
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
||||
h.attemptsMu.Unlock()
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
|
||||
return
|
||||
}
|
||||
// Ban expired, reset state
|
||||
ai.blockedUntil = time.Time{}
|
||||
ai.count = 0
|
||||
}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
|
||||
if !allowRemote {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
fail = func() {
|
||||
h.attemptsMu.Lock()
|
||||
aip := h.failedAttempts[clientIP]
|
||||
if aip == nil {
|
||||
aip = &attemptInfo{}
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
}
|
||||
if secretHash == "" && envSecret == "" {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
|
||||
return
|
||||
}
|
||||
|
||||
// Accept either Authorization: Bearer <key> or X-Management-Key
|
||||
var provided string
|
||||
@@ -235,44 +174,98 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
provided = c.GetHeader("X-Management-Key")
|
||||
}
|
||||
|
||||
if provided == "" {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
|
||||
allowed, statusCode, errMsg := h.AuthenticateManagementKey(clientIP, localClient, provided)
|
||||
if !allowed {
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": errMsg})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
if localClient {
|
||||
if lp := h.localPassword; lp != "" {
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
||||
c.Next()
|
||||
return
|
||||
// AuthenticateManagementKey verifies the provided management key for the given client.
|
||||
// It mirrors the behaviour of Middleware() so non-HTTP callers can reuse the same logic.
|
||||
func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, provided string) (bool, int, string) {
|
||||
const maxFailures = 5
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
if h == nil {
|
||||
return false, http.StatusForbidden, "remote management disabled"
|
||||
}
|
||||
|
||||
cfg := h.cfg
|
||||
var (
|
||||
allowRemote bool
|
||||
secretHash string
|
||||
)
|
||||
if cfg != nil {
|
||||
allowRemote = cfg.RemoteManagement.AllowRemote
|
||||
secretHash = cfg.RemoteManagement.SecretKey
|
||||
}
|
||||
if h.allowRemoteOverride {
|
||||
allowRemote = true
|
||||
}
|
||||
envSecret := h.envSecret
|
||||
|
||||
fail := func() {}
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
ai := h.failedAttempts[clientIP]
|
||||
if ai != nil {
|
||||
if !ai.blockedUntil.IsZero() {
|
||||
if time.Now().Before(ai.blockedUntil) {
|
||||
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
||||
h.attemptsMu.Unlock()
|
||||
return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)
|
||||
}
|
||||
// Ban expired, reset state
|
||||
ai.blockedUntil = time.Time{}
|
||||
ai.count = 0
|
||||
}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
|
||||
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
if !allowRemote {
|
||||
return false, http.StatusForbidden, "remote management disabled"
|
||||
}
|
||||
|
||||
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
||||
if !localClient {
|
||||
fail()
|
||||
fail = func() {
|
||||
h.attemptsMu.Lock()
|
||||
aip := h.failedAttempts[clientIP]
|
||||
if aip == nil {
|
||||
aip = &attemptInfo{}
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
|
||||
return
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if secretHash == "" && envSecret == "" {
|
||||
return false, http.StatusForbidden, "remote management key not set"
|
||||
}
|
||||
|
||||
if provided == "" {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
return false, http.StatusUnauthorized, "missing management key"
|
||||
}
|
||||
|
||||
if localClient {
|
||||
if lp := h.localPassword; lp != "" {
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
||||
return true, 0, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
@@ -281,9 +274,26 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
|
||||
c.Next()
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
return false, http.StatusUnauthorized, "invalid management key"
|
||||
}
|
||||
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
// persist saves the current in-memory config to disk.
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type muxListener struct {
|
||||
addr net.Addr
|
||||
connCh chan net.Conn
|
||||
closeCh chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newMuxListener(addr net.Addr, buffer int) *muxListener {
|
||||
if buffer <= 0 {
|
||||
buffer = 1
|
||||
}
|
||||
return &muxListener{
|
||||
addr: addr,
|
||||
connCh: make(chan net.Conn, buffer),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Put(conn net.Conn) error {
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
return net.ErrClosed
|
||||
case l.connCh <- conn:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
return nil, net.ErrClosed
|
||||
case conn := <-l.connCh:
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Close() error {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
l.once.Do(func() {
|
||||
close(l.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *muxListener) Addr() net.Addr {
|
||||
if l == nil {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
if l.addr == nil {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
return l.addr
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func normalizeHTTPServeError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeListenerError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) acceptMuxConnections(listener net.Listener, httpListener *muxListener) error {
|
||||
if s == nil || listener == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
for {
|
||||
conn, errAccept := listener.Accept()
|
||||
if errAccept != nil {
|
||||
return errAccept
|
||||
}
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if ok {
|
||||
if errHandshake := tlsConn.Handshake(); errHandshake != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after TLS handshake error: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
proto := strings.TrimSpace(tlsConn.ConnectionState().NegotiatedProtocol)
|
||||
if proto == "h2" || proto == "http/1.1" {
|
||||
if httpListener == nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errPut := httpListener.Put(tlsConn); errPut != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
prefix, errPeek := reader.Peek(1)
|
||||
if errPeek != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after protocol peek failure: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isRedisRESPPrefix(prefix[0]) {
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
go s.handleRedisConnection(conn, reader)
|
||||
continue
|
||||
}
|
||||
|
||||
if httpListener == nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection without HTTP listener: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if errPut := httpListener.Put(&bufferedConn{Conn: conn, reader: reader}); errPut != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isRedisRESPPrefix(prefix byte) bool {
|
||||
switch prefix {
|
||||
case '*', '$', '+', '-', ':':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
if s == nil || conn == nil || reader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
|
||||
authed := false
|
||||
writer := bufio.NewWriter(conn)
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("redis connection close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
flush := func() bool {
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = writeRedisError(writer, "ERR "+err.Error())
|
||||
_ = writer.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(args) == 0 {
|
||||
_ = writeRedisError(writer, "ERR empty command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(strings.TrimSpace(args[0]))
|
||||
switch cmd {
|
||||
case "AUTH":
|
||||
password, ok := parseAuthPassword(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if s.mgmt == nil {
|
||||
_ = writeRedisError(writer, "ERR remote management disabled")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password)
|
||||
if !allowed {
|
||||
_ = writeRedisError(writer, "ERR "+errMsg)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
authed = true
|
||||
_ = writeRedisSimpleString(writer, "OK")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
case "LPOP", "RPOP":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
count, hasCount, ok := parsePopCount(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count <= 0 {
|
||||
_ = writeRedisError(writer, "ERR value is not an integer or out of range")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
items := redisqueue.PopOldest(count)
|
||||
if hasCount {
|
||||
_ = writeRedisArrayOfBulkStrings(writer, items)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(items) == 0 {
|
||||
_ = writeRedisNilBulkString(writer)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = writeRedisBulkString(writer, items[0])
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
||||
if addr == nil {
|
||||
return "", false
|
||||
}
|
||||
host := addr.String()
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
host = strings.TrimSpace(host)
|
||||
localClient = host == "127.0.0.1" || host == "::1"
|
||||
return host, localClient
|
||||
}
|
||||
|
||||
func parseAuthPassword(args []string) (string, bool) {
|
||||
switch len(args) {
|
||||
case 2:
|
||||
return args[1], true
|
||||
case 3:
|
||||
// Support AUTH <username> <password> by ignoring username for compatibility.
|
||||
return args[2], true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
||||
if len(args) != 2 && len(args) != 3 {
|
||||
return 0, false, false
|
||||
}
|
||||
if len(args) == 2 {
|
||||
return 1, false, true
|
||||
}
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
|
||||
if err != nil {
|
||||
return 0, true, true
|
||||
}
|
||||
return parsed, true, true
|
||||
}
|
||||
|
||||
func readRESPArray(reader *bufio.Reader) ([]string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil || count < 0 {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
args := make([]string, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
value, err := readRESPString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, value)
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func readRESPString(reader *bufio.Reader) (string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch prefix {
|
||||
case '$':
|
||||
return readRESPBulkString(reader)
|
||||
case '+', ':':
|
||||
return readRESPLine(reader)
|
||||
default:
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func readRESPBulkString(reader *bufio.Reader) (string, error) {
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
if length < 0 {
|
||||
return "", nil
|
||||
}
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
func readRESPLine(reader *bufio.Reader) (string, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
line = strings.TrimSuffix(line, "\n")
|
||||
line = strings.TrimSuffix(line, "\r")
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func writeRedisSimpleString(writer *bufio.Writer, value string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("+" + value + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisError(writer *bufio.Writer, message string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("-" + message + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisNilBulkString(writer *bufio.Writer) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("$-1\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if payload == nil {
|
||||
return writeRedisNilBulkString(writer)
|
||||
}
|
||||
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := writer.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range items {
|
||||
if err := writeRedisBulkString(writer, items[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
)
|
||||
|
||||
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
|
||||
if errListen != nil {
|
||||
t.Fatalf("failed to listen: %v", errListen)
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.acceptMuxConnections(listener, nil)
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
_ = listener.Close()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("accept loop returned unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Errorf("timeout waiting for accept loop to exit")
|
||||
}
|
||||
}
|
||||
|
||||
return listener.Addr().String(), stop
|
||||
}
|
||||
|
||||
func writeTestRESPCommand(conn net.Conn, args ...string) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "*%d\r\n", len(args))
|
||||
for _, arg := range args {
|
||||
fmt.Fprintf(&buf, "$%d\r\n%s\r\n", len(arg), arg)
|
||||
}
|
||||
_, err := conn.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func readTestRESPLine(r *bufio.Reader) (string, error) {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !strings.HasSuffix(line, "\r\n") {
|
||||
return "", fmt.Errorf("invalid RESP line terminator: %q", line)
|
||||
}
|
||||
return strings.TrimSuffix(line, "\r\n"), nil
|
||||
}
|
||||
|
||||
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '+' {
|
||||
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '-' {
|
||||
return "", fmt.Errorf("expected error prefix '-', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '$' {
|
||||
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
|
||||
}
|
||||
if length == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
if length < -1 {
|
||||
return nil, fmt.Errorf("invalid bulk string length %d", length)
|
||||
}
|
||||
|
||||
payload := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload[length] != '\r' || payload[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("invalid bulk string terminator")
|
||||
}
|
||||
return payload[:length], nil
|
||||
}
|
||||
|
||||
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
}
|
||||
if count < 0 {
|
||||
return nil, fmt.Errorf("invalid array length %d", count)
|
||||
}
|
||||
|
||||
out := make([][]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
item, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
redisqueue.SetEnabled(false)
|
||||
|
||||
server := newTestServer(t)
|
||||
if server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be false")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
if errWrite := writeTestRESPCommand(conn, "PING"); errWrite != nil {
|
||||
t.Fatalf("failed to write RESP command: %v", errWrite)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed when management is disabled")
|
||||
}
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected AUTH response: %q", msg)
|
||||
}
|
||||
|
||||
if !redisqueue.Enabled() {
|
||||
t.Fatalf("expected redisqueue to be enabled")
|
||||
}
|
||||
redisqueue.Enqueue([]byte("a"))
|
||||
redisqueue.Enqueue([]byte("b"))
|
||||
redisqueue.Enqueue([]byte("c"))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read RPOP response: %v", err)
|
||||
} else if string(item) != "a" {
|
||||
t.Fatalf("unexpected RPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP response: %v", err)
|
||||
} else if string(item) != "b" {
|
||||
t.Fatalf("unexpected LPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP count command: %v", errWrite)
|
||||
}
|
||||
items, errItems := readRESPArrayOfBulkStrings(reader)
|
||||
if errItems != nil {
|
||||
t.Fatalf("failed to read RPOP count response: %v", errItems)
|
||||
}
|
||||
if len(items) != 1 || string(items[0]) != "c" {
|
||||
t.Fatalf("unexpected RPOP count items: %#v", items)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(reader)
|
||||
if errItem != nil {
|
||||
t.Fatalf("failed to read LPOP empty response: %v", errItem)
|
||||
}
|
||||
if item != nil {
|
||||
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
|
||||
}
|
||||
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
|
||||
if errEmpty != nil {
|
||||
t.Fatalf("failed to read RPOP empty count response: %v", errEmpty)
|
||||
}
|
||||
if len(emptyItems) != 0 {
|
||||
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
|
||||
}
|
||||
}
|
||||
+106
-13
@@ -7,8 +7,10 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -28,6 +30,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
@@ -38,6 +41,7 @@ import (
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -127,6 +131,12 @@ type Server struct {
|
||||
// server is the underlying HTTP server.
|
||||
server *http.Server
|
||||
|
||||
// muxBaseListener is the shared TCP listener used to serve both HTTP and Redis protocol traffic.
|
||||
muxBaseListener net.Listener
|
||||
|
||||
// muxHTTPListener receives HTTP connections selected by the multiplexer.
|
||||
muxHTTPListener *muxListener
|
||||
|
||||
// handlers contains the API handlers for processing requests.
|
||||
handlers *handlers.BaseAPIHandler
|
||||
|
||||
@@ -299,6 +309,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
redisqueue.SetEnabled(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
}
|
||||
@@ -797,26 +808,98 @@ func (s *Server) Start() error {
|
||||
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
||||
}
|
||||
|
||||
addr := s.server.Addr
|
||||
listener, errListen := net.Listen("tcp", addr)
|
||||
if errListen != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errListen)
|
||||
}
|
||||
|
||||
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
||||
if useTLS {
|
||||
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
key := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if cert == "" || key == "" {
|
||||
certPath := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
keyPath := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if certPath == "" || keyPath == "" {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
log.Errorf("failed to close listener after TLS validation failure: %v", errClose)
|
||||
}
|
||||
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
||||
}
|
||||
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
||||
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
||||
certPair, errLoad := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if errLoad != nil {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
log.Errorf("failed to close listener after TLS key pair load failure: %v", errClose)
|
||||
}
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errLoad)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{certPair},
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
s.server.TLSConfig = tlsConfig
|
||||
if errHTTP2 := http2.ConfigureServer(s.server, &http2.Server{}); errHTTP2 != nil {
|
||||
log.Warnf("failed to configure HTTP/2: %v", errHTTP2)
|
||||
}
|
||||
listener = tls.NewListener(listener, tlsConfig)
|
||||
log.Debugf("Starting API server on %s with TLS", addr)
|
||||
} else {
|
||||
log.Debugf("Starting API server on %s", addr)
|
||||
}
|
||||
|
||||
httpListener := newMuxListener(listener.Addr(), 1024)
|
||||
s.muxBaseListener = listener
|
||||
s.muxHTTPListener = httpListener
|
||||
|
||||
httpErrCh := make(chan error, 1)
|
||||
acceptErrCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
httpErrCh <- s.server.Serve(httpListener)
|
||||
}()
|
||||
go func() {
|
||||
acceptErrCh <- s.acceptMuxConnections(listener, httpListener)
|
||||
}()
|
||||
|
||||
select {
|
||||
case errServe := <-httpErrCh:
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener after HTTP serve exit: %v", errClose)
|
||||
}
|
||||
}
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
errAccept := <-acceptErrCh
|
||||
errServe = normalizeHTTPServeError(errServe)
|
||||
errAccept = normalizeListenerError(errAccept)
|
||||
if errServe != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
if errAccept != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
|
||||
}
|
||||
return nil
|
||||
case errAccept := <-acceptErrCh:
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener after accept loop exit: %v", errClose)
|
||||
}
|
||||
}
|
||||
errServe := <-httpErrCh
|
||||
errServe = normalizeHTTPServeError(errServe)
|
||||
errAccept = normalizeListenerError(errAccept)
|
||||
if errAccept != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
|
||||
}
|
||||
if errServe != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the API server without interrupting any
|
||||
@@ -837,6 +920,15 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener: %v", errClose)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the HTTP server.
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
||||
@@ -963,6 +1055,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.managementRoutesEnabled.Store(!newSecretEmpty)
|
||||
}
|
||||
}
|
||||
redisqueue.SetEnabled(s.managementRoutesEnabled.Load())
|
||||
|
||||
s.applyAccessConfig(oldCfg, cfg)
|
||||
s.cfg = cfg
|
||||
|
||||
Reference in New Issue
Block a user