feat(server): add mTLS certificate bootstrap via JWT for Home connections
- Introduced `-home-jwt` flag and `HOME_JWT` environment variable to provide JWT for mTLS certificate generation. - Added new APIs to handle certificate requests, validate JWT claims, and manage local certificate files. - Updated Home TLS configuration to support client certificates, keys, and dynamic server name resolution.
This commit is contained in:
+56
-1
@@ -190,6 +190,7 @@ func main() {
|
||||
var password string
|
||||
var homeAddr string
|
||||
var homePassword string
|
||||
var homeJWT string
|
||||
var homeDisableClusterDiscovery bool
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
@@ -212,6 +213,7 @@ func main() {
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
|
||||
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
|
||||
flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
|
||||
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
@@ -311,6 +313,11 @@ func main() {
|
||||
homePassword = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(homeJWT) == "" {
|
||||
if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
|
||||
homeJWT = v
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
|
||||
usePostgresStore = true
|
||||
@@ -375,7 +382,55 @@ func main() {
|
||||
// Determine and load the configuration file.
|
||||
// Prefer the Postgres store when configured, otherwise fallback to git or local files.
|
||||
var configFilePath string
|
||||
if strings.TrimSpace(homeAddr) != "" {
|
||||
if strings.TrimSpace(homeJWT) != "" {
|
||||
configLoadedFromHome = true
|
||||
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
homeCfg, errHomeCfg := home.ConfigFromJWT(ctxHome, homeJWT)
|
||||
cancelHome()
|
||||
if errHomeCfg != nil {
|
||||
log.Errorf("invalid -home-jwt: %v", errHomeCfg)
|
||||
return
|
||||
}
|
||||
if homeDisableClusterDiscovery {
|
||||
homeCfg.DisableClusterDiscovery = true
|
||||
}
|
||||
homeClient := home.New(homeCfg)
|
||||
defer homeClient.Close()
|
||||
|
||||
ctxHomeConfig, cancelHomeConfig := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
raw, errGetConfig := homeClient.GetConfig(ctxHomeConfig)
|
||||
cancelHomeConfig()
|
||||
if errGetConfig != nil {
|
||||
log.Errorf("failed to fetch config from home: %v", errGetConfig)
|
||||
return
|
||||
}
|
||||
|
||||
parsed, errParseConfig := config.ParseConfigBytes(raw)
|
||||
if errParseConfig != nil {
|
||||
log.Errorf("failed to parse config payload from home: %v", errParseConfig)
|
||||
return
|
||||
}
|
||||
if parsed == nil {
|
||||
parsed = &config.Config{}
|
||||
}
|
||||
parsed.Home = homeCfg
|
||||
parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config
|
||||
parsed.UsageStatisticsEnabled = true
|
||||
cfg = parsed
|
||||
|
||||
// Keep a non-empty config path for downstream components (log paths, management assets, etc),
|
||||
// but do not require the file to exist when loading config from home.
|
||||
if strings.TrimSpace(configPath) != "" {
|
||||
configFilePath = configPath
|
||||
} else {
|
||||
configFilePath = filepath.Join(wd, "config.yaml")
|
||||
}
|
||||
|
||||
// Local stores are intentionally disabled when config is loaded from home.
|
||||
usePostgresStore = false
|
||||
useObjectStore = false
|
||||
useGitStore = false
|
||||
} else if strings.TrimSpace(homeAddr) != "" {
|
||||
configLoadedFromHome = true
|
||||
trimmedHomePassword := strings.TrimSpace(homePassword)
|
||||
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
|
||||
|
||||
@@ -16,4 +16,7 @@ type HomeTLSConfig struct {
|
||||
ServerName string `yaml:"server-name" json:"-"`
|
||||
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
|
||||
CACert string `yaml:"ca-cert" json:"-"`
|
||||
ClientCert string `yaml:"-" json:"-"`
|
||||
ClientKey string `yaml:"-" json:"-"`
|
||||
UseTargetServerName bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
)
|
||||
|
||||
const homeCertificateRequestTimeout = 30 * time.Second
|
||||
|
||||
type homeJWTClaims struct {
|
||||
CertificateID string `json:"certificate_id"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
}
|
||||
|
||||
type certificateRequestResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Certificate string `json:"certificate"`
|
||||
CA string `json:"ca"`
|
||||
}
|
||||
|
||||
type certificatePaths struct {
|
||||
Dir string
|
||||
ClientCert string
|
||||
ClientKey string
|
||||
CACert string
|
||||
}
|
||||
|
||||
// ConfigFromJWT prepares a Home config from the JWT and ensures local mTLS files exist.
|
||||
func ConfigFromJWT(ctx context.Context, rawJWT string) (config.HomeConfig, error) {
|
||||
claims, errClaims := parseHomeJWTClaims(rawJWT)
|
||||
if errClaims != nil {
|
||||
return config.HomeConfig{}, errClaims
|
||||
}
|
||||
paths, errPaths := defaultCertificatePaths()
|
||||
if errPaths != nil {
|
||||
return config.HomeConfig{}, errPaths
|
||||
}
|
||||
if errEnsure := ensureHomeCertificateFiles(ctx, claims, paths); errEnsure != nil {
|
||||
return config.HomeConfig{}, errEnsure
|
||||
}
|
||||
return config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: strings.TrimSpace(claims.IP),
|
||||
Port: claims.Port,
|
||||
TLS: config.HomeTLSConfig{
|
||||
Enable: true,
|
||||
CACert: paths.CACert,
|
||||
ClientCert: paths.ClientCert,
|
||||
ClientKey: paths.ClientKey,
|
||||
UseTargetServerName: true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) {
|
||||
var claims homeJWTClaims
|
||||
parts := strings.Split(strings.TrimSpace(rawJWT), ".")
|
||||
if len(parts) != 3 {
|
||||
return claims, fmt.Errorf("home jwt is invalid")
|
||||
}
|
||||
payload, errDecode := decodeJWTPart(parts[1])
|
||||
if errDecode != nil {
|
||||
return claims, errDecode
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(payload, &claims); errUnmarshal != nil {
|
||||
return claims, errUnmarshal
|
||||
}
|
||||
if strings.TrimSpace(claims.CertificateID) == "" {
|
||||
return claims, fmt.Errorf("home jwt certificate_id is required")
|
||||
}
|
||||
if strings.TrimSpace(claims.IP) == "" || claims.Port <= 0 {
|
||||
return claims, fmt.Errorf("home jwt target address is invalid")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func decodeJWTPart(part string) ([]byte, error) {
|
||||
if decoded, errDecode := base64.RawURLEncoding.DecodeString(part); errDecode == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(part)
|
||||
}
|
||||
|
||||
func defaultCertificatePaths() (certificatePaths, error) {
|
||||
homeDir, errHome := os.UserHomeDir()
|
||||
if errHome != nil {
|
||||
return certificatePaths{}, errHome
|
||||
}
|
||||
dir := filepath.Join(homeDir, ".cli-proxy-api")
|
||||
return certificatePaths{
|
||||
Dir: dir,
|
||||
ClientCert: filepath.Join(dir, "client-crt.pem"),
|
||||
ClientKey: filepath.Join(dir, "client-key.pem"),
|
||||
CACert: filepath.Join(dir, "home-ca-crt.pem"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths certificatePaths) error {
|
||||
if fileExists(paths.ClientCert) && fileExists(paths.ClientKey) {
|
||||
if !fileExists(paths.CACert) {
|
||||
return fmt.Errorf("home ca certificate file is missing")
|
||||
}
|
||||
if errChmod := chmodCertificateFiles(paths); errChmod != nil {
|
||||
return errChmod
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if errMkdir := os.MkdirAll(paths.Dir, 0o700); errMkdir != nil {
|
||||
return errMkdir
|
||||
}
|
||||
key, errKey := loadOrCreateClientKey(paths.ClientKey)
|
||||
if errKey != nil {
|
||||
return errKey
|
||||
}
|
||||
csrPEM, errCSR := createClientCSR(claims.CertificateID, key)
|
||||
if errCSR != nil {
|
||||
return errCSR
|
||||
}
|
||||
response, errRequest := requestClientCertificate(ctx, claims, csrPEM)
|
||||
if errRequest != nil {
|
||||
return errRequest
|
||||
}
|
||||
if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" {
|
||||
return fmt.Errorf("home certificate response is incomplete")
|
||||
}
|
||||
if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeFile0600(paths.CACert, []byte(response.CA)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) {
|
||||
if fileExists(path) {
|
||||
raw, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
key, errParse := parseRSAPrivateKeyPEM(raw)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
|
||||
return nil, errChmod
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
key, errKey := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if errKey != nil {
|
||||
return nil, errKey
|
||||
}
|
||||
raw := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
if errWrite := writeFile0600(path, raw); errWrite != nil {
|
||||
return nil, errWrite
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func writeFile0600(path string, raw []byte) error {
|
||||
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return os.Chmod(path, 0o600)
|
||||
}
|
||||
|
||||
func chmodCertificateFiles(paths certificatePaths) error {
|
||||
for _, path := range []string{paths.ClientCert, paths.ClientKey, paths.CACert} {
|
||||
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
|
||||
return errChmod
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRSAPrivateKeyPEM(raw []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(raw)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("client key pem is invalid")
|
||||
}
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
key, errParse := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("client key is not rsa")
|
||||
}
|
||||
return rsaKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("client key pem type %q is unsupported", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func createClientCSR(certificateID string, key *rsa.PrivateKey) ([]byte, error) {
|
||||
certificateID = strings.TrimSpace(certificateID)
|
||||
if certificateID == "" {
|
||||
return nil, fmt.Errorf("certificate id is required")
|
||||
}
|
||||
template := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: certificateID,
|
||||
},
|
||||
}
|
||||
der, errCreate := x509.CreateCertificateRequest(rand.Reader, template, key)
|
||||
if errCreate != nil {
|
||||
return nil, errCreate
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}), nil
|
||||
}
|
||||
|
||||
func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM []byte) (certificateRequestResponse, error) {
|
||||
var response certificateRequestResponse
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
dialCtx, cancel := context.WithTimeout(ctx, homeCertificateRequestTimeout)
|
||||
defer cancel()
|
||||
addr := net.JoinHostPort(strings.TrimSpace(claims.IP), strconv.Itoa(claims.Port))
|
||||
conn, errDial := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
|
||||
if errDial != nil {
|
||||
return response, errDial
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
if deadline, ok := dialCtx.Deadline(); ok {
|
||||
_ = conn.SetDeadline(deadline)
|
||||
}
|
||||
if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, string(csrPEM))); errWrite != nil {
|
||||
return response, errWrite
|
||||
}
|
||||
raw, errRead := readRESPBulk(bufio.NewReader(conn))
|
||||
if errRead != nil {
|
||||
return response, errRead
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(raw, &response); errUnmarshal != nil {
|
||||
return response, errUnmarshal
|
||||
}
|
||||
if !response.OK {
|
||||
return response, fmt.Errorf("home certificate request failed")
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func encodeRESPArray(args ...string) []byte {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("*")
|
||||
buf.WriteString(strconv.Itoa(len(args)))
|
||||
buf.WriteString("\r\n")
|
||||
for _, arg := range args {
|
||||
buf.WriteString("$")
|
||||
buf.WriteString(strconv.Itoa(len(arg)))
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString(arg)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func readRESPBulk(reader *bufio.Reader) ([]byte, error) {
|
||||
prefix, errRead := reader.ReadByte()
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
switch prefix {
|
||||
case '$':
|
||||
line, errLine := reader.ReadString('\n')
|
||||
if errLine != nil {
|
||||
return nil, errLine
|
||||
}
|
||||
size, errSize := strconv.Atoi(strings.TrimSpace(line))
|
||||
if errSize != nil {
|
||||
return nil, errSize
|
||||
}
|
||||
if size < 0 {
|
||||
return nil, fmt.Errorf("home certificate request returned nil")
|
||||
}
|
||||
payload := make([]byte, size+2)
|
||||
if _, errFull := io.ReadFull(reader, payload); errFull != nil {
|
||||
return nil, errFull
|
||||
}
|
||||
return payload[:size], nil
|
||||
case '-':
|
||||
line, errLine := reader.ReadString('\n')
|
||||
if errLine != nil {
|
||||
return nil, errLine
|
||||
}
|
||||
return nil, fmt.Errorf("%s", strings.TrimSpace(line))
|
||||
default:
|
||||
return nil, fmt.Errorf("home certificate request returned unsupported resp prefix %q", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
info, errStat := os.Stat(path)
|
||||
return errStat == nil && !info.IsDir()
|
||||
}
|
||||
+27
-2
@@ -172,7 +172,7 @@ func (c *Client) ensureClients() error {
|
||||
}
|
||||
|
||||
func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
|
||||
tlsConfig, errTLS := c.homeTLSConfigLocked()
|
||||
tlsConfig, errTLS := c.homeTLSConfigLocked(addr)
|
||||
if errTLS != nil {
|
||||
return nil, errTLS
|
||||
}
|
||||
@@ -183,17 +183,29 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
|
||||
func (c *Client) homeTLSConfigLocked(addr string) (*tls.Config, error) {
|
||||
serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
|
||||
if serverName == "" {
|
||||
if c.homeCfg.TLS.UseTargetServerName {
|
||||
serverName = hostFromAddress(addr)
|
||||
} else {
|
||||
serverName = strings.TrimSpace(c.seedHost)
|
||||
}
|
||||
}
|
||||
if serverName == "" {
|
||||
serverName = strings.TrimSpace(c.homeCfg.Host)
|
||||
}
|
||||
return newHomeTLSConfig(c.homeCfg.TLS, serverName)
|
||||
}
|
||||
|
||||
func hostFromAddress(addr string) string {
|
||||
host, _, errSplit := net.SplitHostPort(strings.TrimSpace(addr))
|
||||
if errSplit == nil {
|
||||
return strings.TrimSpace(host)
|
||||
}
|
||||
return strings.TrimSpace(addr)
|
||||
}
|
||||
|
||||
func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
|
||||
if !cfg.Enable {
|
||||
return nil, nil
|
||||
@@ -210,6 +222,19 @@ func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
clientCertPath := strings.TrimSpace(cfg.ClientCert)
|
||||
clientKeyPath := strings.TrimSpace(cfg.ClientKey)
|
||||
if clientCertPath != "" || clientKeyPath != "" {
|
||||
if clientCertPath == "" || clientKeyPath == "" {
|
||||
return nil, fmt.Errorf("home tls: client certificate and key must be set together")
|
||||
}
|
||||
certPair, errLoad := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
|
||||
if errLoad != nil {
|
||||
return nil, fmt.Errorf("home tls: load client certificate: %w", errLoad)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{certPair}
|
||||
}
|
||||
|
||||
caCertPath := strings.TrimSpace(cfg.CACert)
|
||||
if caCertPath == "" {
|
||||
return tlsConfig, nil
|
||||
|
||||
Reference in New Issue
Block a user