From 77ba15f71b61d25653465d9fba3417cae1ec7055 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 19 May 2026 00:53:40 +0800 Subject: [PATCH] feat(server): add mTLS certificate bootstrap via JWT for Home connections - Introduced `-home-jwt` flag and `HOME_JWT` environment variable to provide JWT for mTLS certificate generation. - Added new APIs to handle certificate requests, validate JWT claims, and manage local certificate files. - Updated Home TLS configuration to support client certificates, keys, and dynamic server name resolution. --- cmd/server/main.go | 57 ++++++- internal/config/home.go | 11 +- internal/home/certificate.go | 323 +++++++++++++++++++++++++++++++++++ internal/home/client.go | 31 +++- 4 files changed, 414 insertions(+), 8 deletions(-) create mode 100644 internal/home/certificate.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 99d8780a..a42a7324 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -190,6 +190,7 @@ func main() { var password string var homeAddr string var homePassword string + var homeJWT string var homeDisableClusterDiscovery bool var tuiMode bool var standalone bool @@ -212,6 +213,7 @@ func main() { flag.StringVar(&password, "password", "", "") flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)") flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") + flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection") flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") @@ -311,6 +313,11 @@ func main() { homePassword = v } } + if strings.TrimSpace(homeJWT) == "" { + if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok { + homeJWT = v + } + } if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok { usePostgresStore = true @@ -375,7 +382,55 @@ func main() { // Determine and load the configuration file. // Prefer the Postgres store when configured, otherwise fallback to git or local files. var configFilePath string - if strings.TrimSpace(homeAddr) != "" { + if strings.TrimSpace(homeJWT) != "" { + configLoadedFromHome = true + ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second) + homeCfg, errHomeCfg := home.ConfigFromJWT(ctxHome, homeJWT) + cancelHome() + if errHomeCfg != nil { + log.Errorf("invalid -home-jwt: %v", errHomeCfg) + return + } + if homeDisableClusterDiscovery { + homeCfg.DisableClusterDiscovery = true + } + homeClient := home.New(homeCfg) + defer homeClient.Close() + + ctxHomeConfig, cancelHomeConfig := context.WithTimeout(context.Background(), 30*time.Second) + raw, errGetConfig := homeClient.GetConfig(ctxHomeConfig) + cancelHomeConfig() + if errGetConfig != nil { + log.Errorf("failed to fetch config from home: %v", errGetConfig) + return + } + + parsed, errParseConfig := config.ParseConfigBytes(raw) + if errParseConfig != nil { + log.Errorf("failed to parse config payload from home: %v", errParseConfig) + return + } + if parsed == nil { + parsed = &config.Config{} + } + parsed.Home = homeCfg + parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config + parsed.UsageStatisticsEnabled = true + cfg = parsed + + // Keep a non-empty config path for downstream components (log paths, management assets, etc), + // but do not require the file to exist when loading config from home. + if strings.TrimSpace(configPath) != "" { + configFilePath = configPath + } else { + configFilePath = filepath.Join(wd, "config.yaml") + } + + // Local stores are intentionally disabled when config is loaded from home. + usePostgresStore = false + useObjectStore = false + useGitStore = false + } else if strings.TrimSpace(homeAddr) != "" { configLoadedFromHome = true trimmedHomePassword := strings.TrimSpace(homePassword) homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword) diff --git a/internal/config/home.go b/internal/config/home.go index 8e7945b4..8cf323b6 100644 --- a/internal/config/home.go +++ b/internal/config/home.go @@ -12,8 +12,11 @@ type HomeConfig struct { // HomeTLSConfig configures client-side TLS for the home Redis connection. type HomeTLSConfig struct { - Enable bool `yaml:"enable" json:"-"` - ServerName string `yaml:"server-name" json:"-"` - InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"` - CACert string `yaml:"ca-cert" json:"-"` + Enable bool `yaml:"enable" json:"-"` + ServerName string `yaml:"server-name" json:"-"` + InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"` + CACert string `yaml:"ca-cert" json:"-"` + ClientCert string `yaml:"-" json:"-"` + ClientKey string `yaml:"-" json:"-"` + UseTargetServerName bool `yaml:"-" json:"-"` } diff --git a/internal/home/certificate.go b/internal/home/certificate.go new file mode 100644 index 00000000..bb0902f8 --- /dev/null +++ b/internal/home/certificate.go @@ -0,0 +1,323 @@ +package home + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" +) + +const homeCertificateRequestTimeout = 30 * time.Second + +type homeJWTClaims struct { + CertificateID string `json:"certificate_id"` + IP string `json:"ip"` + Port int `json:"port"` + IssuedAt int64 `json:"iat"` +} + +type certificateRequestResponse struct { + OK bool `json:"ok"` + Certificate string `json:"certificate"` + CA string `json:"ca"` +} + +type certificatePaths struct { + Dir string + ClientCert string + ClientKey string + CACert string +} + +// ConfigFromJWT prepares a Home config from the JWT and ensures local mTLS files exist. +func ConfigFromJWT(ctx context.Context, rawJWT string) (config.HomeConfig, error) { + claims, errClaims := parseHomeJWTClaims(rawJWT) + if errClaims != nil { + return config.HomeConfig{}, errClaims + } + paths, errPaths := defaultCertificatePaths() + if errPaths != nil { + return config.HomeConfig{}, errPaths + } + if errEnsure := ensureHomeCertificateFiles(ctx, claims, paths); errEnsure != nil { + return config.HomeConfig{}, errEnsure + } + return config.HomeConfig{ + Enabled: true, + Host: strings.TrimSpace(claims.IP), + Port: claims.Port, + TLS: config.HomeTLSConfig{ + Enable: true, + CACert: paths.CACert, + ClientCert: paths.ClientCert, + ClientKey: paths.ClientKey, + UseTargetServerName: true, + }, + }, nil +} + +func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) { + var claims homeJWTClaims + parts := strings.Split(strings.TrimSpace(rawJWT), ".") + if len(parts) != 3 { + return claims, fmt.Errorf("home jwt is invalid") + } + payload, errDecode := decodeJWTPart(parts[1]) + if errDecode != nil { + return claims, errDecode + } + if errUnmarshal := json.Unmarshal(payload, &claims); errUnmarshal != nil { + return claims, errUnmarshal + } + if strings.TrimSpace(claims.CertificateID) == "" { + return claims, fmt.Errorf("home jwt certificate_id is required") + } + if strings.TrimSpace(claims.IP) == "" || claims.Port <= 0 { + return claims, fmt.Errorf("home jwt target address is invalid") + } + return claims, nil +} + +func decodeJWTPart(part string) ([]byte, error) { + if decoded, errDecode := base64.RawURLEncoding.DecodeString(part); errDecode == nil { + return decoded, nil + } + return base64.URLEncoding.DecodeString(part) +} + +func defaultCertificatePaths() (certificatePaths, error) { + homeDir, errHome := os.UserHomeDir() + if errHome != nil { + return certificatePaths{}, errHome + } + dir := filepath.Join(homeDir, ".cli-proxy-api") + return certificatePaths{ + Dir: dir, + ClientCert: filepath.Join(dir, "client-crt.pem"), + ClientKey: filepath.Join(dir, "client-key.pem"), + CACert: filepath.Join(dir, "home-ca-crt.pem"), + }, nil +} + +func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths certificatePaths) error { + if fileExists(paths.ClientCert) && fileExists(paths.ClientKey) { + if !fileExists(paths.CACert) { + return fmt.Errorf("home ca certificate file is missing") + } + if errChmod := chmodCertificateFiles(paths); errChmod != nil { + return errChmod + } + return nil + } + if errMkdir := os.MkdirAll(paths.Dir, 0o700); errMkdir != nil { + return errMkdir + } + key, errKey := loadOrCreateClientKey(paths.ClientKey) + if errKey != nil { + return errKey + } + csrPEM, errCSR := createClientCSR(claims.CertificateID, key) + if errCSR != nil { + return errCSR + } + response, errRequest := requestClientCertificate(ctx, claims, csrPEM) + if errRequest != nil { + return errRequest + } + if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" { + return fmt.Errorf("home certificate response is incomplete") + } + if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil { + return errWrite + } + if errWrite := writeFile0600(paths.CACert, []byte(response.CA)); errWrite != nil { + return errWrite + } + return nil +} + +func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) { + if fileExists(path) { + raw, errRead := os.ReadFile(path) + if errRead != nil { + return nil, errRead + } + key, errParse := parseRSAPrivateKeyPEM(raw) + if errParse != nil { + return nil, errParse + } + if errChmod := os.Chmod(path, 0o600); errChmod != nil { + return nil, errChmod + } + return key, nil + } + key, errKey := rsa.GenerateKey(rand.Reader, 2048) + if errKey != nil { + return nil, errKey + } + raw := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if errWrite := writeFile0600(path, raw); errWrite != nil { + return nil, errWrite + } + return key, nil +} + +func writeFile0600(path string, raw []byte) error { + if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil { + return errWrite + } + return os.Chmod(path, 0o600) +} + +func chmodCertificateFiles(paths certificatePaths) error { + for _, path := range []string{paths.ClientCert, paths.ClientKey, paths.CACert} { + if errChmod := os.Chmod(path, 0o600); errChmod != nil { + return errChmod + } + } + return nil +} + +func parseRSAPrivateKeyPEM(raw []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(raw) + if block == nil { + return nil, fmt.Errorf("client key pem is invalid") + } + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, errParse := x509.ParsePKCS8PrivateKey(block.Bytes) + if errParse != nil { + return nil, errParse + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("client key is not rsa") + } + return rsaKey, nil + default: + return nil, fmt.Errorf("client key pem type %q is unsupported", block.Type) + } +} + +func createClientCSR(certificateID string, key *rsa.PrivateKey) ([]byte, error) { + certificateID = strings.TrimSpace(certificateID) + if certificateID == "" { + return nil, fmt.Errorf("certificate id is required") + } + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: certificateID, + }, + } + der, errCreate := x509.CreateCertificateRequest(rand.Reader, template, key) + if errCreate != nil { + return nil, errCreate + } + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}), nil +} + +func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM []byte) (certificateRequestResponse, error) { + var response certificateRequestResponse + if ctx == nil { + ctx = context.Background() + } + dialCtx, cancel := context.WithTimeout(ctx, homeCertificateRequestTimeout) + defer cancel() + addr := net.JoinHostPort(strings.TrimSpace(claims.IP), strconv.Itoa(claims.Port)) + conn, errDial := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr) + if errDial != nil { + return response, errDial + } + defer func() { + _ = conn.Close() + }() + if deadline, ok := dialCtx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, string(csrPEM))); errWrite != nil { + return response, errWrite + } + raw, errRead := readRESPBulk(bufio.NewReader(conn)) + if errRead != nil { + return response, errRead + } + if errUnmarshal := json.Unmarshal(raw, &response); errUnmarshal != nil { + return response, errUnmarshal + } + if !response.OK { + return response, fmt.Errorf("home certificate request failed") + } + return response, nil +} + +func encodeRESPArray(args ...string) []byte { + var buf bytes.Buffer + buf.WriteString("*") + buf.WriteString(strconv.Itoa(len(args))) + buf.WriteString("\r\n") + for _, arg := range args { + buf.WriteString("$") + buf.WriteString(strconv.Itoa(len(arg))) + buf.WriteString("\r\n") + buf.WriteString(arg) + buf.WriteString("\r\n") + } + return buf.Bytes() +} + +func readRESPBulk(reader *bufio.Reader) ([]byte, error) { + prefix, errRead := reader.ReadByte() + if errRead != nil { + return nil, errRead + } + switch prefix { + case '$': + line, errLine := reader.ReadString('\n') + if errLine != nil { + return nil, errLine + } + size, errSize := strconv.Atoi(strings.TrimSpace(line)) + if errSize != nil { + return nil, errSize + } + if size < 0 { + return nil, fmt.Errorf("home certificate request returned nil") + } + payload := make([]byte, size+2) + if _, errFull := io.ReadFull(reader, payload); errFull != nil { + return nil, errFull + } + return payload[:size], nil + case '-': + line, errLine := reader.ReadString('\n') + if errLine != nil { + return nil, errLine + } + return nil, fmt.Errorf("%s", strings.TrimSpace(line)) + default: + return nil, fmt.Errorf("home certificate request returned unsupported resp prefix %q", prefix) + } +} + +func fileExists(path string) bool { + info, errStat := os.Stat(path) + return errStat == nil && !info.IsDir() +} diff --git a/internal/home/client.go b/internal/home/client.go index 2652bc1c..cb0850e4 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -172,7 +172,7 @@ func (c *Client) ensureClients() error { } func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) { - tlsConfig, errTLS := c.homeTLSConfigLocked() + tlsConfig, errTLS := c.homeTLSConfigLocked(addr) if errTLS != nil { return nil, errTLS } @@ -183,10 +183,14 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) { }, nil } -func (c *Client) homeTLSConfigLocked() (*tls.Config, error) { +func (c *Client) homeTLSConfigLocked(addr string) (*tls.Config, error) { serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName) if serverName == "" { - serverName = strings.TrimSpace(c.seedHost) + if c.homeCfg.TLS.UseTargetServerName { + serverName = hostFromAddress(addr) + } else { + serverName = strings.TrimSpace(c.seedHost) + } } if serverName == "" { serverName = strings.TrimSpace(c.homeCfg.Host) @@ -194,6 +198,14 @@ func (c *Client) homeTLSConfigLocked() (*tls.Config, error) { return newHomeTLSConfig(c.homeCfg.TLS, serverName) } +func hostFromAddress(addr string) string { + host, _, errSplit := net.SplitHostPort(strings.TrimSpace(addr)) + if errSplit == nil { + return strings.TrimSpace(host) + } + return strings.TrimSpace(addr) +} + func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) { if !cfg.Enable { return nil, nil @@ -210,6 +222,19 @@ func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls InsecureSkipVerify: cfg.InsecureSkipVerify, } + clientCertPath := strings.TrimSpace(cfg.ClientCert) + clientKeyPath := strings.TrimSpace(cfg.ClientKey) + if clientCertPath != "" || clientKeyPath != "" { + if clientCertPath == "" || clientKeyPath == "" { + return nil, fmt.Errorf("home tls: client certificate and key must be set together") + } + certPair, errLoad := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if errLoad != nil { + return nil, fmt.Errorf("home tls: load client certificate: %w", errLoad) + } + tlsConfig.Certificates = []tls.Certificate{certPair} + } + caCertPath := strings.TrimSpace(cfg.CACert) if caCertPath == "" { return tlsConfig, nil