Preserve default transport settings for proxy clients

This commit is contained in:
kwz
2026-03-25 15:33:09 +08:00
parent 76c064c729
commit c89d19b300
2 changed files with 69 additions and 13 deletions

View File

@@ -68,14 +68,18 @@ func Parse(raw string) (Setting, error) {
} }
} }
func cloneDefaultTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
return transport.Clone()
}
return &http.Transport{}
}
// NewDirectTransport returns a transport that bypasses environment proxies. // NewDirectTransport returns a transport that bypasses environment proxies.
func NewDirectTransport() *http.Transport { func NewDirectTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil { clone := cloneDefaultTransport()
clone := transport.Clone() clone.Proxy = nil
clone.Proxy = nil return clone
return clone
}
return &http.Transport{Proxy: nil}
} }
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting. // BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
@@ -102,14 +106,16 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
if errSOCKS5 != nil { if errSOCKS5 != nil {
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
} }
return &http.Transport{ transport := cloneDefaultTransport()
Proxy: nil, transport.Proxy = nil
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr) return dialer.Dial(network, addr)
}, }
}, setting.Mode, nil return transport, setting.Mode, nil
} }
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil transport := cloneDefaultTransport()
transport.Proxy = http.ProxyURL(setting.URL)
return transport, setting.Mode, nil
default: default:
return nil, setting.Mode, nil return nil, setting.Mode, nil
} }

View File

@@ -5,6 +5,16 @@ import (
"testing" "testing"
) )
func mustDefaultTransport(t *testing.T) *http.Transport {
t.Helper()
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
t.Fatal("http.DefaultTransport is not an *http.Transport")
}
return transport
}
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
t.Parallel() t.Parallel()
@@ -86,4 +96,44 @@ func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" { if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL) t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
} }
defaultTransport := mustDefaultTransport(t)
if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
}
if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
}
if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
}
}
func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testing.T) {
t.Parallel()
transport, mode, errBuild := BuildHTTPTransport("socks5://proxy.example.com:1080")
if errBuild != nil {
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
}
if mode != ModeProxy {
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
}
if transport == nil {
t.Fatal("expected transport, got nil")
}
if transport.Proxy != nil {
t.Fatal("expected SOCKS5 transport to bypass http proxy function")
}
defaultTransport := mustDefaultTransport(t)
if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
}
if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
}
if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
}
} }