Skip to content

Commit 8957607

Browse files
committed
fix(ocireg): preserve default transport settings for HTTPS registries
1 parent b776b68 commit 8957607

2 files changed

Lines changed: 51 additions & 3 deletions

File tree

api/oci/extensions/repositories/ocireg/repository.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ type RepositoryImpl struct {
5353
info *RepositoryInfo
5454
}
5555

56+
func newHTTPTransport(conf *tls.Config) *http.Transport {
57+
base, ok := http.DefaultTransport.(*http.Transport)
58+
if !ok {
59+
return &http.Transport{TLSClientConfig: conf}
60+
}
61+
62+
transport := base.Clone()
63+
transport.TLSClientConfig = conf
64+
65+
return transport
66+
}
67+
5668
var (
5769
_ cpi.RepositoryImpl = (*RepositoryImpl)(nil)
5870
_ credentials.ConsumerIdentityProvider = &RepositoryImpl{}
@@ -172,9 +184,7 @@ func (r *RepositoryImpl) getResolver(comp string) (oras.Resolver, error) {
172184
return rootCAs
173185
}(),
174186
}
175-
client.Transport = ocmlog.NewRoundTripper(retry.NewTransport(&http.Transport{
176-
TLSClientConfig: conf,
177-
}), logger)
187+
client.Transport = ocmlog.NewRoundTripper(retry.NewTransport(newHTTPTransport(conf)), logger)
178188
}
179189

180190
authClient := &auth.Client{
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package ocireg
2+
3+
import (
4+
"crypto/tls"
5+
"net/http"
6+
"reflect"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestNewHTTPTransportClonesDefaultTransport(t *testing.T) {
14+
t.Parallel()
15+
16+
base, ok := http.DefaultTransport.(*http.Transport)
17+
require.True(t, ok, "default transport must be *http.Transport")
18+
19+
conf := &tls.Config{MinVersion: tls.VersionTLS12}
20+
21+
got := newHTTPTransport(conf)
22+
require.NotNil(t, got)
23+
24+
assert.NotSame(t, base, got)
25+
assert.Same(t, conf, got.TLSClientConfig)
26+
assert.Equal(t, base.ForceAttemptHTTP2, got.ForceAttemptHTTP2)
27+
assert.Equal(t, base.MaxIdleConns, got.MaxIdleConns)
28+
assert.Equal(t, base.MaxIdleConnsPerHost, got.MaxIdleConnsPerHost)
29+
assert.Equal(t, base.MaxConnsPerHost, got.MaxConnsPerHost)
30+
assert.Equal(t, base.IdleConnTimeout, got.IdleConnTimeout)
31+
assert.Equal(t, base.TLSHandshakeTimeout, got.TLSHandshakeTimeout)
32+
assert.Equal(t, base.ExpectContinueTimeout, got.ExpectContinueTimeout)
33+
34+
if base.Proxy != nil {
35+
require.NotNil(t, got.Proxy)
36+
assert.Equal(t, reflect.ValueOf(base.Proxy).Pointer(), reflect.ValueOf(got.Proxy).Pointer())
37+
}
38+
}

0 commit comments

Comments
 (0)