Skip to content

Commit 137f3ec

Browse files
authored
Merge pull request #169 from concourse/aws-sdk-v2
Upgrade to v2 of the aws-sdk
2 parents 506b2e4 + 33f96c9 commit 137f3ec

11 files changed

Lines changed: 213 additions & 416 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ the bucket.
6565

6666
* `region_name`: *Optional. Default `us-east-1`.* The region the bucket is in.
6767

68-
* `endpoint`: *Optional.* Custom endpoint for using S3 compatible provider.
68+
* `endpoint`: *Optional.* Custom endpoint for using S3 compatible provider. Can be a hostname or URL.
6969

7070
* `disable_ssl`: *Optional.* Disable SSL for the endpoint, useful for S3 compatible providers without SSL.
7171

@@ -74,7 +74,7 @@ the bucket.
7474
* `server_side_encryption`: *Optional.* The server-side encryption algorithm
7575
used when storing the version object (e.g. `AES256`, `aws:kms`).
7676

77-
* `use_v2_signing`: *Optional.* Use v2 signing, default false.
77+
* `use_v2_signing`: *Deprecated.* No longer used after upgrading to v2 of the AWS Go SDK.
7878

7979
### `swift` Driver
8080

check/check_test.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ package main_test
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"os"
78
"os/exec"
89
"path"
910
"time"
1011

11-
"github.com/aws/aws-sdk-go/aws"
12-
"github.com/aws/aws-sdk-go/aws/credentials"
13-
"github.com/aws/aws-sdk-go/aws/session"
14-
"github.com/aws/aws-sdk-go/service/s3"
12+
"github.com/aws/aws-sdk-go-v2/aws"
13+
"github.com/aws/aws-sdk-go-v2/config"
14+
"github.com/aws/aws-sdk-go-v2/credentials"
15+
"github.com/aws/aws-sdk-go-v2/service/s3"
1516
"github.com/concourse/semver-resource/models"
1617
"github.com/google/uuid"
1718
. "github.com/onsi/ginkgo/v2"
@@ -45,24 +46,24 @@ var _ = Describe("Check", func() {
4546
Context("when executed", func() {
4647
var request models.CheckRequest
4748
var response models.CheckResponse
48-
var svc *s3.S3
49+
var svc *s3.Client
4950

5051
BeforeEach(func() {
5152
guid, err := uuid.NewRandom()
5253
Expect(err).NotTo(HaveOccurred())
5354

5455
key = guid.String()
5556

56-
creds := credentials.NewStaticCredentials(accessKeyID, secretAccessKey, "")
57-
awsConfig := &aws.Config{
58-
Region: aws.String(regionName),
59-
Credentials: creds,
60-
S3ForcePathStyle: aws.Bool(true),
61-
MaxRetries: aws.Int(12),
62-
}
63-
sess, err := session.NewSession(awsConfig)
57+
creds := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
58+
cfg, err := config.LoadDefaultConfig(context.TODO(),
59+
config.WithRegion(regionName),
60+
config.WithRetryMaxAttempts(12),
61+
config.WithCredentialsProvider(creds),
62+
)
6463
Expect(err).NotTo(HaveOccurred())
65-
svc = s3.New(sess)
64+
svc = s3.NewFromConfig(cfg, func(o *s3.Options) {
65+
o.UsePathStyle = true
66+
})
6667

6768
request = models.CheckRequest{
6869
Version: models.Version{},
@@ -80,7 +81,7 @@ var _ = Describe("Check", func() {
8081
})
8182

8283
AfterEach(func() {
83-
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
84+
_, err := svc.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
8485
Bucket: aws.String(bucketName),
8586
Key: aws.String(key),
8687
})
@@ -105,12 +106,11 @@ var _ = Describe("Check", func() {
105106
})
106107

107108
putVersion := func(version string) {
108-
_, err := svc.PutObject(&s3.PutObjectInput{
109+
_, err := svc.PutObject(context.TODO(), &s3.PutObjectInput{
109110
Bucket: aws.String(bucketName),
110111
Key: aws.String(key),
111112
ContentType: aws.String("text/plain"),
112113
Body: bytes.NewReader([]byte(version)),
113-
ACL: aws.String(s3.ObjectCannedACLPrivate),
114114
})
115115
Expect(err).NotTo(HaveOccurred())
116116
}

driver/driver.go

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
package driver
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"net/http"
7-
8-
"github.com/aws/aws-sdk-go/aws"
9-
"github.com/aws/aws-sdk-go/aws/credentials"
10-
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
11-
"github.com/aws/aws-sdk-go/aws/session"
12-
"github.com/aws/aws-sdk-go/service/s3"
8+
"net/url"
9+
10+
"github.com/aws/aws-sdk-go-v2/aws"
11+
"github.com/aws/aws-sdk-go-v2/config"
12+
"github.com/aws/aws-sdk-go-v2/credentials"
13+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
14+
"github.com/aws/aws-sdk-go-v2/service/s3"
15+
"github.com/aws/aws-sdk-go-v2/service/sts"
1316
"github.com/blang/semver"
1417
"github.com/concourse/semver-resource/models"
1518
"github.com/concourse/semver-resource/version"
@@ -38,16 +41,14 @@ func FromSource(source models.Source) (Driver, error) {
3841

3942
switch source.Driver {
4043
case models.DriverUnspecified, models.DriverS3:
41-
var creds *credentials.Credentials
44+
var credsProvider aws.CredentialsProvider
4245

43-
if source.AccessKeyID == "" && source.SecretAccessKey == "" {
44-
creds = credentials.AnonymousCredentials
45-
} else {
46-
creds = credentials.NewStaticCredentials(source.AccessKeyID, source.SecretAccessKey, source.SessionToken)
46+
if source.AccessKeyID != "" && source.SecretAccessKey != "" {
47+
credsProvider = credentials.NewStaticCredentialsProvider(source.AccessKeyID, source.SecretAccessKey, source.SessionToken)
4748
}
4849

4950
regionName := source.RegionName
50-
if len(regionName) == 0 {
51+
if regionName == "" {
5152
regionName = "us-east-1"
5253
}
5354

@@ -60,39 +61,67 @@ func FromSource(source models.Source) (Driver, error) {
6061
httpClient = http.DefaultClient
6162
}
6263

63-
awsConfig := &aws.Config{
64-
Region: aws.String(regionName),
65-
Credentials: creds,
66-
S3ForcePathStyle: aws.Bool(true),
67-
MaxRetries: aws.Int(maxRetries),
68-
DisableSSL: aws.Bool(source.DisableSSL),
69-
HTTPClient: httpClient,
64+
cfg, err := config.LoadDefaultConfig(context.TODO(),
65+
config.WithRegion(regionName),
66+
config.WithHTTPClient(httpClient),
67+
config.WithRetryMaxAttempts(maxRetries),
68+
config.WithCredentialsProvider(credsProvider),
69+
)
70+
if err != nil {
71+
return nil, fmt.Errorf("error loading default aws config: %w", err)
7072
}
7173

72-
if len(source.Endpoint) != 0 {
73-
awsConfig.Endpoint = aws.String(source.Endpoint)
74+
if source.AssumeRoleArn != "" {
75+
stsClient := sts.NewFromConfig(cfg)
76+
roleCreds := stscreds.NewAssumeRoleProvider(stsClient, source.AssumeRoleArn)
77+
creds, err := roleCreds.Retrieve(context.TODO())
78+
if err != nil {
79+
return nil, fmt.Errorf("error assuming role: %w", err)
80+
}
81+
82+
cfg.Credentials = aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
83+
creds.AccessKeyID,
84+
creds.SecretAccessKey,
85+
creds.SessionToken,
86+
))
7487
}
7588

76-
s3Session := session.New(awsConfig)
89+
s3Opts := []func(*s3.Options){
90+
func(o *s3.Options) {
91+
o.UsePathStyle = true
92+
},
93+
}
7794

78-
var s3Client *s3.S3
79-
if source.AssumeRoleArn != "" {
80-
creds := stscreds.NewCredentials(s3Session, source.AssumeRoleArn)
81-
s3Client = s3.New(s3Session, &aws.Config{Credentials: creds})
82-
} else {
83-
s3Client = s3.New(s3Session)
95+
if source.Endpoint != "" {
96+
endpoint := source.Endpoint
97+
u, err := url.Parse(source.Endpoint)
98+
if err != nil {
99+
return nil, fmt.Errorf("error parsing given endpoint: %w", err)
100+
}
101+
if u.Scheme == "" {
102+
// source.Endpoint is a hostname
103+
scheme := "https://"
104+
if source.DisableSSL {
105+
scheme = "http://"
106+
}
107+
endpoint = scheme + source.Endpoint
108+
}
109+
110+
s3Opts = append(s3Opts, func(o *s3.Options) {
111+
o.BaseEndpoint = &endpoint
112+
})
84113
}
85114

86-
svc := s3Client
115+
s3Client := s3.NewFromConfig(cfg, s3Opts...)
87116

88117
if source.UseV2Signing {
89-
setv2Handlers(svc)
118+
//TODO: warn this setting is deprecated. The SDK only has v4 signing
90119
}
91120

92121
return &S3Driver{
93122
InitialVersion: initialVersion,
94123

95-
Svc: svc,
124+
Svc: s3Client,
96125
BucketName: source.Bucket,
97126
Key: source.Key,
98127
ServerSideEncryption: source.ServerSideEncryption,

driver/driver_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package driver_test
33
import (
44
"net/http"
55

6-
"github.com/aws/aws-sdk-go/service/s3"
6+
"github.com/aws/aws-sdk-go-v2/service/s3"
77
"github.com/concourse/semver-resource/driver"
88
"github.com/concourse/semver-resource/models"
99
. "github.com/onsi/ginkgo/v2"
@@ -25,9 +25,9 @@ var _ = Describe("Driver", func() {
2525
s3Driver, ok := aDriver.(*driver.S3Driver)
2626
Expect(ok).To(BeTrue())
2727
Expect(s3Driver.Svc).To(Not(BeNil()))
28-
svc, ok := s3Driver.Svc.(*s3.S3)
28+
svc, ok := s3Driver.Svc.(*s3.Client)
2929
Expect(ok).To(BeTrue())
30-
Expect(svc.Client.Config.HTTPClient).Should(BeEquivalentTo(http.DefaultClient))
30+
Expect(svc.Options().HTTPClient).Should(BeEquivalentTo(http.DefaultClient))
3131
})
3232
It("returns a s3 driver with a transport that ignores ssl verification", func() {
3333
src.SkipSSLVerification = true
@@ -37,10 +37,11 @@ var _ = Describe("Driver", func() {
3737
s3Driver, ok := aDriver.(*driver.S3Driver)
3838
Expect(ok).To(BeTrue())
3939
Expect(s3Driver.Svc).To(Not(BeNil()))
40-
svc, ok := s3Driver.Svc.(*s3.S3)
40+
svc, ok := s3Driver.Svc.(*s3.Client)
4141
Expect(ok).To(BeTrue())
42-
Expect(svc.Client.Config.HTTPClient.Transport).ToNot(BeNil())
43-
transport, ok := svc.Client.Config.HTTPClient.Transport.(*http.Transport)
42+
httpClient, ok := svc.Options().HTTPClient.(*http.Client)
43+
Expect(httpClient.Transport).ToNot(BeNil())
44+
transport, ok := httpClient.Transport.(*http.Transport)
4445
Expect(ok).To(BeTrue())
4546
Expect(transport.TLSClientConfig.InsecureSkipVerify).Should(BeTrue())
4647
})

driver/s3.go

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,22 @@ package driver
22

33
import (
44
"bytes"
5+
"context"
6+
"errors"
57
"fmt"
68
"io"
79
"strings"
810

9-
"github.com/aws/aws-sdk-go/aws"
10-
"github.com/aws/aws-sdk-go/aws/awserr"
11-
"github.com/aws/aws-sdk-go/service/s3"
11+
"github.com/aws/aws-sdk-go-v2/aws"
12+
"github.com/aws/aws-sdk-go-v2/service/s3"
13+
"github.com/aws/aws-sdk-go-v2/service/s3/types"
1214
"github.com/blang/semver"
1315
"github.com/concourse/semver-resource/version"
1416
)
1517

1618
type Servicer interface {
17-
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
18-
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
19+
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
20+
PutObject(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)
1921
}
2022

2123
type S3Driver struct {
@@ -30,10 +32,11 @@ type S3Driver struct {
3032
func (driver *S3Driver) Bump(bump version.Bump) (semver.Version, error) {
3133
var currentVersion semver.Version
3234

33-
resp, err := driver.Svc.GetObject(&s3.GetObjectInput{
34-
Bucket: aws.String(driver.BucketName),
35-
Key: aws.String(driver.Key),
36-
})
35+
resp, err := driver.Svc.GetObject(context.TODO(),
36+
&s3.GetObjectInput{
37+
Bucket: aws.String(driver.BucketName),
38+
Key: aws.String(driver.Key),
39+
})
3740
if err == nil {
3841
bucketNumberPayload, err := io.ReadAll(resp.Body)
3942
if err != nil {
@@ -46,10 +49,13 @@ func (driver *S3Driver) Bump(bump version.Bump) (semver.Version, error) {
4649
if err != nil {
4750
return semver.Version{}, err
4851
}
49-
} else if s3err, ok := err.(awserr.RequestFailure); ok && s3err.StatusCode() == 404 {
50-
currentVersion = driver.InitialVersion
5152
} else {
52-
return semver.Version{}, err
53+
var noSuchKey *types.NoSuchKey
54+
if errors.As(err, &noSuchKey) {
55+
currentVersion = driver.InitialVersion
56+
} else {
57+
return semver.Version{}, err
58+
}
5359
}
5460

5561
newVersion := bump.Apply(currentVersion)
@@ -68,21 +74,20 @@ func (driver *S3Driver) Set(newVersion semver.Version) error {
6874
Key: aws.String(driver.Key),
6975
ContentType: aws.String("text/plain"),
7076
Body: bytes.NewReader([]byte(newVersion.String())),
71-
ACL: aws.String(s3.ObjectCannedACLPrivate),
7277
}
7378

7479
if len(driver.ServerSideEncryption) > 0 {
75-
params.ServerSideEncryption = aws.String(driver.ServerSideEncryption)
80+
params.ServerSideEncryption = types.ServerSideEncryption(driver.ServerSideEncryption)
7681
}
7782

78-
_, err := driver.Svc.PutObject(params)
83+
_, err := driver.Svc.PutObject(context.TODO(), params)
7984
return err
8085
}
8186

8287
func (driver *S3Driver) Check(cursor *semver.Version) ([]semver.Version, error) {
8388
var bucketNumber string
8489

85-
resp, err := driver.Svc.GetObject(&s3.GetObjectInput{
90+
resp, err := driver.Svc.GetObject(context.TODO(), &s3.GetObjectInput{
8691
Bucket: aws.String(driver.BucketName),
8792
Key: aws.String(driver.Key),
8893
})
@@ -94,14 +99,17 @@ func (driver *S3Driver) Check(cursor *semver.Version) ([]semver.Version, error)
9499
defer resp.Body.Close()
95100

96101
bucketNumber = string(bucketNumberPayload)
97-
} else if s3err, ok := err.(awserr.RequestFailure); ok && s3err.StatusCode() == 404 {
98-
if cursor == nil {
99-
return []semver.Version{driver.InitialVersion}, nil
102+
} else {
103+
var noSuchKey *types.NoSuchKey
104+
if errors.As(err, &noSuchKey) {
105+
if cursor == nil {
106+
return []semver.Version{driver.InitialVersion}, nil
107+
} else {
108+
return []semver.Version{}, nil
109+
}
100110
} else {
101-
return []semver.Version{}, nil
111+
return nil, err
102112
}
103-
} else {
104-
return nil, err
105113
}
106114

107115
bucketVersion, err := semver.Parse(bucketNumber)

0 commit comments

Comments
 (0)