diff --git a/sdk/proxyutil/proxy.go b/sdk/proxyutil/proxy.go index 591ec9d9..029efeb7 100644 --- a/sdk/proxyutil/proxy.go +++ b/sdk/proxyutil/proxy.go @@ -68,14 +68,18 @@ func Parse(raw string) (Setting, error) { } } +func cloneDefaultTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil { + return transport.Clone() + } + return &http.Transport{} +} + // NewDirectTransport returns a transport that bypasses environment proxies. func NewDirectTransport() *http.Transport { - if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil { - clone := transport.Clone() - clone.Proxy = nil - return clone - } - return &http.Transport{Proxy: nil} + clone := cloneDefaultTransport() + clone.Proxy = nil + return clone } // BuildHTTPTransport constructs an HTTP transport for the provided proxy setting. @@ -102,14 +106,16 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) { if errSOCKS5 != nil { return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) } - return &http.Transport{ - Proxy: nil, - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - }, setting.Mode, nil + transport := cloneDefaultTransport() + transport.Proxy = nil + transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + return transport, setting.Mode, nil } - return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil + transport := cloneDefaultTransport() + transport.Proxy = http.ProxyURL(setting.URL) + return transport, setting.Mode, nil default: return nil, setting.Mode, nil } diff --git a/sdk/proxyutil/proxy_test.go b/sdk/proxyutil/proxy_test.go index bea413dc..5b250117 100644 --- a/sdk/proxyutil/proxy_test.go +++ b/sdk/proxyutil/proxy_test.go @@ -5,6 +5,16 @@ import ( "testing" ) +func mustDefaultTransport(t *testing.T) *http.Transport { + t.Helper() + + transport, ok := http.DefaultTransport.(*http.Transport) + if !ok || transport == nil { + t.Fatal("http.DefaultTransport is not an *http.Transport") + } + return transport +} + func TestParse(t *testing.T) { t.Parallel() @@ -86,4 +96,44 @@ func TestBuildHTTPTransportHTTPProxy(t *testing.T) { if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" { t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL) } + + defaultTransport := mustDefaultTransport(t) + if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 { + t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2) + } + if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout { + t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout) + } + if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout { + t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout) + } +} + +func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("socks5://proxy.example.com:1080") + if errBuild != nil { + t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeProxy { + t.Fatalf("mode = %d, want %d", mode, ModeProxy) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + if transport.Proxy != nil { + t.Fatal("expected SOCKS5 transport to bypass http proxy function") + } + + defaultTransport := mustDefaultTransport(t) + if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 { + t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2) + } + if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout { + t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout) + } + if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout { + t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout) + } }