feat(server): add mTLS certificate bootstrap via JWT for Home connections

- Introduced `-home-jwt` flag and `HOME_JWT` environment variable to provide JWT for mTLS certificate generation.
- Added new APIs to handle certificate requests, validate JWT claims, and manage local certificate files.
- Updated Home TLS configuration to support client certificates, keys, and dynamic server name resolution.
This commit is contained in:
Luis Pater
2026-05-19 00:53:40 +08:00
parent cc0cb057b3
commit 77ba15f71b
4 changed files with 414 additions and 8 deletions
+56 -1
View File
@@ -190,6 +190,7 @@ func main() {
var password string var password string
var homeAddr string var homeAddr string
var homePassword string var homePassword string
var homeJWT string
var homeDisableClusterDiscovery bool var homeDisableClusterDiscovery bool
var tuiMode bool var tuiMode bool
var standalone bool var standalone bool
@@ -212,6 +213,7 @@ func main() {
flag.StringVar(&password, "password", "", "") flag.StringVar(&password, "password", "", "")
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)") flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address") flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
@@ -311,6 +313,11 @@ func main() {
homePassword = v homePassword = v
} }
} }
if strings.TrimSpace(homeJWT) == "" {
if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
homeJWT = v
}
}
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok { if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
usePostgresStore = true usePostgresStore = true
@@ -375,7 +382,55 @@ func main() {
// Determine and load the configuration file. // Determine and load the configuration file.
// Prefer the Postgres store when configured, otherwise fallback to git or local files. // Prefer the Postgres store when configured, otherwise fallback to git or local files.
var configFilePath string var configFilePath string
if strings.TrimSpace(homeAddr) != "" { if strings.TrimSpace(homeJWT) != "" {
configLoadedFromHome = true
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
homeCfg, errHomeCfg := home.ConfigFromJWT(ctxHome, homeJWT)
cancelHome()
if errHomeCfg != nil {
log.Errorf("invalid -home-jwt: %v", errHomeCfg)
return
}
if homeDisableClusterDiscovery {
homeCfg.DisableClusterDiscovery = true
}
homeClient := home.New(homeCfg)
defer homeClient.Close()
ctxHomeConfig, cancelHomeConfig := context.WithTimeout(context.Background(), 30*time.Second)
raw, errGetConfig := homeClient.GetConfig(ctxHomeConfig)
cancelHomeConfig()
if errGetConfig != nil {
log.Errorf("failed to fetch config from home: %v", errGetConfig)
return
}
parsed, errParseConfig := config.ParseConfigBytes(raw)
if errParseConfig != nil {
log.Errorf("failed to parse config payload from home: %v", errParseConfig)
return
}
if parsed == nil {
parsed = &config.Config{}
}
parsed.Home = homeCfg
parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config
parsed.UsageStatisticsEnabled = true
cfg = parsed
// Keep a non-empty config path for downstream components (log paths, management assets, etc),
// but do not require the file to exist when loading config from home.
if strings.TrimSpace(configPath) != "" {
configFilePath = configPath
} else {
configFilePath = filepath.Join(wd, "config.yaml")
}
// Local stores are intentionally disabled when config is loaded from home.
usePostgresStore = false
useObjectStore = false
useGitStore = false
} else if strings.TrimSpace(homeAddr) != "" {
configLoadedFromHome = true configLoadedFromHome = true
trimmedHomePassword := strings.TrimSpace(homePassword) trimmedHomePassword := strings.TrimSpace(homePassword)
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword) homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
+7 -4
View File
@@ -12,8 +12,11 @@ type HomeConfig struct {
// HomeTLSConfig configures client-side TLS for the home Redis connection. // HomeTLSConfig configures client-side TLS for the home Redis connection.
type HomeTLSConfig struct { type HomeTLSConfig struct {
Enable bool `yaml:"enable" json:"-"` Enable bool `yaml:"enable" json:"-"`
ServerName string `yaml:"server-name" json:"-"` ServerName string `yaml:"server-name" json:"-"`
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"` InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
CACert string `yaml:"ca-cert" json:"-"` CACert string `yaml:"ca-cert" json:"-"`
ClientCert string `yaml:"-" json:"-"`
ClientKey string `yaml:"-" json:"-"`
UseTargetServerName bool `yaml:"-" json:"-"`
} }
+323
View File
@@ -0,0 +1,323 @@
package home
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
)
const homeCertificateRequestTimeout = 30 * time.Second
type homeJWTClaims struct {
CertificateID string `json:"certificate_id"`
IP string `json:"ip"`
Port int `json:"port"`
IssuedAt int64 `json:"iat"`
}
type certificateRequestResponse struct {
OK bool `json:"ok"`
Certificate string `json:"certificate"`
CA string `json:"ca"`
}
type certificatePaths struct {
Dir string
ClientCert string
ClientKey string
CACert string
}
// ConfigFromJWT prepares a Home config from the JWT and ensures local mTLS files exist.
func ConfigFromJWT(ctx context.Context, rawJWT string) (config.HomeConfig, error) {
claims, errClaims := parseHomeJWTClaims(rawJWT)
if errClaims != nil {
return config.HomeConfig{}, errClaims
}
paths, errPaths := defaultCertificatePaths()
if errPaths != nil {
return config.HomeConfig{}, errPaths
}
if errEnsure := ensureHomeCertificateFiles(ctx, claims, paths); errEnsure != nil {
return config.HomeConfig{}, errEnsure
}
return config.HomeConfig{
Enabled: true,
Host: strings.TrimSpace(claims.IP),
Port: claims.Port,
TLS: config.HomeTLSConfig{
Enable: true,
CACert: paths.CACert,
ClientCert: paths.ClientCert,
ClientKey: paths.ClientKey,
UseTargetServerName: true,
},
}, nil
}
func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) {
var claims homeJWTClaims
parts := strings.Split(strings.TrimSpace(rawJWT), ".")
if len(parts) != 3 {
return claims, fmt.Errorf("home jwt is invalid")
}
payload, errDecode := decodeJWTPart(parts[1])
if errDecode != nil {
return claims, errDecode
}
if errUnmarshal := json.Unmarshal(payload, &claims); errUnmarshal != nil {
return claims, errUnmarshal
}
if strings.TrimSpace(claims.CertificateID) == "" {
return claims, fmt.Errorf("home jwt certificate_id is required")
}
if strings.TrimSpace(claims.IP) == "" || claims.Port <= 0 {
return claims, fmt.Errorf("home jwt target address is invalid")
}
return claims, nil
}
func decodeJWTPart(part string) ([]byte, error) {
if decoded, errDecode := base64.RawURLEncoding.DecodeString(part); errDecode == nil {
return decoded, nil
}
return base64.URLEncoding.DecodeString(part)
}
func defaultCertificatePaths() (certificatePaths, error) {
homeDir, errHome := os.UserHomeDir()
if errHome != nil {
return certificatePaths{}, errHome
}
dir := filepath.Join(homeDir, ".cli-proxy-api")
return certificatePaths{
Dir: dir,
ClientCert: filepath.Join(dir, "client-crt.pem"),
ClientKey: filepath.Join(dir, "client-key.pem"),
CACert: filepath.Join(dir, "home-ca-crt.pem"),
}, nil
}
func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths certificatePaths) error {
if fileExists(paths.ClientCert) && fileExists(paths.ClientKey) {
if !fileExists(paths.CACert) {
return fmt.Errorf("home ca certificate file is missing")
}
if errChmod := chmodCertificateFiles(paths); errChmod != nil {
return errChmod
}
return nil
}
if errMkdir := os.MkdirAll(paths.Dir, 0o700); errMkdir != nil {
return errMkdir
}
key, errKey := loadOrCreateClientKey(paths.ClientKey)
if errKey != nil {
return errKey
}
csrPEM, errCSR := createClientCSR(claims.CertificateID, key)
if errCSR != nil {
return errCSR
}
response, errRequest := requestClientCertificate(ctx, claims, csrPEM)
if errRequest != nil {
return errRequest
}
if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" {
return fmt.Errorf("home certificate response is incomplete")
}
if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil {
return errWrite
}
if errWrite := writeFile0600(paths.CACert, []byte(response.CA)); errWrite != nil {
return errWrite
}
return nil
}
func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) {
if fileExists(path) {
raw, errRead := os.ReadFile(path)
if errRead != nil {
return nil, errRead
}
key, errParse := parseRSAPrivateKeyPEM(raw)
if errParse != nil {
return nil, errParse
}
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
return nil, errChmod
}
return key, nil
}
key, errKey := rsa.GenerateKey(rand.Reader, 2048)
if errKey != nil {
return nil, errKey
}
raw := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
if errWrite := writeFile0600(path, raw); errWrite != nil {
return nil, errWrite
}
return key, nil
}
func writeFile0600(path string, raw []byte) error {
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
return errWrite
}
return os.Chmod(path, 0o600)
}
func chmodCertificateFiles(paths certificatePaths) error {
for _, path := range []string{paths.ClientCert, paths.ClientKey, paths.CACert} {
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
return errChmod
}
}
return nil
}
func parseRSAPrivateKeyPEM(raw []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(raw)
if block == nil {
return nil, fmt.Errorf("client key pem is invalid")
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
key, errParse := x509.ParsePKCS8PrivateKey(block.Bytes)
if errParse != nil {
return nil, errParse
}
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("client key is not rsa")
}
return rsaKey, nil
default:
return nil, fmt.Errorf("client key pem type %q is unsupported", block.Type)
}
}
func createClientCSR(certificateID string, key *rsa.PrivateKey) ([]byte, error) {
certificateID = strings.TrimSpace(certificateID)
if certificateID == "" {
return nil, fmt.Errorf("certificate id is required")
}
template := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: certificateID,
},
}
der, errCreate := x509.CreateCertificateRequest(rand.Reader, template, key)
if errCreate != nil {
return nil, errCreate
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}), nil
}
func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM []byte) (certificateRequestResponse, error) {
var response certificateRequestResponse
if ctx == nil {
ctx = context.Background()
}
dialCtx, cancel := context.WithTimeout(ctx, homeCertificateRequestTimeout)
defer cancel()
addr := net.JoinHostPort(strings.TrimSpace(claims.IP), strconv.Itoa(claims.Port))
conn, errDial := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
if errDial != nil {
return response, errDial
}
defer func() {
_ = conn.Close()
}()
if deadline, ok := dialCtx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, string(csrPEM))); errWrite != nil {
return response, errWrite
}
raw, errRead := readRESPBulk(bufio.NewReader(conn))
if errRead != nil {
return response, errRead
}
if errUnmarshal := json.Unmarshal(raw, &response); errUnmarshal != nil {
return response, errUnmarshal
}
if !response.OK {
return response, fmt.Errorf("home certificate request failed")
}
return response, nil
}
func encodeRESPArray(args ...string) []byte {
var buf bytes.Buffer
buf.WriteString("*")
buf.WriteString(strconv.Itoa(len(args)))
buf.WriteString("\r\n")
for _, arg := range args {
buf.WriteString("$")
buf.WriteString(strconv.Itoa(len(arg)))
buf.WriteString("\r\n")
buf.WriteString(arg)
buf.WriteString("\r\n")
}
return buf.Bytes()
}
func readRESPBulk(reader *bufio.Reader) ([]byte, error) {
prefix, errRead := reader.ReadByte()
if errRead != nil {
return nil, errRead
}
switch prefix {
case '$':
line, errLine := reader.ReadString('\n')
if errLine != nil {
return nil, errLine
}
size, errSize := strconv.Atoi(strings.TrimSpace(line))
if errSize != nil {
return nil, errSize
}
if size < 0 {
return nil, fmt.Errorf("home certificate request returned nil")
}
payload := make([]byte, size+2)
if _, errFull := io.ReadFull(reader, payload); errFull != nil {
return nil, errFull
}
return payload[:size], nil
case '-':
line, errLine := reader.ReadString('\n')
if errLine != nil {
return nil, errLine
}
return nil, fmt.Errorf("%s", strings.TrimSpace(line))
default:
return nil, fmt.Errorf("home certificate request returned unsupported resp prefix %q", prefix)
}
}
func fileExists(path string) bool {
info, errStat := os.Stat(path)
return errStat == nil && !info.IsDir()
}
+28 -3
View File
@@ -172,7 +172,7 @@ func (c *Client) ensureClients() error {
} }
func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) { func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
tlsConfig, errTLS := c.homeTLSConfigLocked() tlsConfig, errTLS := c.homeTLSConfigLocked(addr)
if errTLS != nil { if errTLS != nil {
return nil, errTLS return nil, errTLS
} }
@@ -183,10 +183,14 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
}, nil }, nil
} }
func (c *Client) homeTLSConfigLocked() (*tls.Config, error) { func (c *Client) homeTLSConfigLocked(addr string) (*tls.Config, error) {
serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName) serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
if serverName == "" { if serverName == "" {
serverName = strings.TrimSpace(c.seedHost) if c.homeCfg.TLS.UseTargetServerName {
serverName = hostFromAddress(addr)
} else {
serverName = strings.TrimSpace(c.seedHost)
}
} }
if serverName == "" { if serverName == "" {
serverName = strings.TrimSpace(c.homeCfg.Host) serverName = strings.TrimSpace(c.homeCfg.Host)
@@ -194,6 +198,14 @@ func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
return newHomeTLSConfig(c.homeCfg.TLS, serverName) return newHomeTLSConfig(c.homeCfg.TLS, serverName)
} }
func hostFromAddress(addr string) string {
host, _, errSplit := net.SplitHostPort(strings.TrimSpace(addr))
if errSplit == nil {
return strings.TrimSpace(host)
}
return strings.TrimSpace(addr)
}
func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) { func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
if !cfg.Enable { if !cfg.Enable {
return nil, nil return nil, nil
@@ -210,6 +222,19 @@ func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls
InsecureSkipVerify: cfg.InsecureSkipVerify, InsecureSkipVerify: cfg.InsecureSkipVerify,
} }
clientCertPath := strings.TrimSpace(cfg.ClientCert)
clientKeyPath := strings.TrimSpace(cfg.ClientKey)
if clientCertPath != "" || clientKeyPath != "" {
if clientCertPath == "" || clientKeyPath == "" {
return nil, fmt.Errorf("home tls: client certificate and key must be set together")
}
certPair, errLoad := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if errLoad != nil {
return nil, fmt.Errorf("home tls: load client certificate: %w", errLoad)
}
tlsConfig.Certificates = []tls.Certificate{certPair}
}
caCertPath := strings.TrimSpace(cfg.CACert) caCertPath := strings.TrimSpace(cfg.CACert)
if caCertPath == "" { if caCertPath == "" {
return tlsConfig, nil return tlsConfig, nil