Skip to content

Commit dd522bd

Browse files
committed
Improve S3 remote URL handling
1 parent 240ab84 commit dd522bd

3 files changed

Lines changed: 43 additions & 7 deletions

File tree

pkg/artifact/kitfile.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ func (kf *KitFile) LoadModel(kitfileContent io.ReadCloser) error {
137137
if err := decoder.Decode(kf); err != nil {
138138
return err
139139
}
140+
if err := kf.Validate(); err != nil {
141+
return err
142+
}
140143
return nil
141144
}
142145

@@ -210,6 +213,10 @@ func (kf *KitFile) Validate() error {
210213
if dataset.RemoteHash == "" {
211214
return fmt.Errorf("remoteHash is required when remote dataset paths are used (%s)", dataset.RemotePath)
212215
}
216+
} else {
217+
if dataset.RemoteHash != "" {
218+
return fmt.Errorf("remote hash is only applicable when remotePath is set")
219+
}
213220
}
214221
}
215222
for _, doc := range kf.Docs {

pkg/lib/external/s3api/api.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ func ParseS3ObjectReference(ref string, hash string) (*S3ObjectReference, error)
6565
bucketKey := strings.TrimPrefix(s3url.Path, "/")
6666
version := s3url.Query().Get("versionId")
6767

68+
if bucketName == "" {
69+
return nil, fmt.Errorf("bucket name is required for S3 references")
70+
}
71+
if bucketKey == "" {
72+
return nil, fmt.Errorf("bucket key is required for S3 references")
73+
}
74+
6875
return &S3ObjectReference{
6976
Bucket: bucketName,
7077
Key: bucketKey,
@@ -104,6 +111,7 @@ func DownloadObject(ctx context.Context, client S3ClientAPI, ref *S3ObjectRefere
104111
if err != nil {
105112
return fmt.Errorf("failed to get object from S3 bucket: %w", err)
106113
}
114+
defer obj.Body.Close()
107115

108116
outfile, err := os.Create(outputPath)
109117
if err != nil {
@@ -143,6 +151,9 @@ func SetUpClient(ctx context.Context) (S3ClientAPI, error) {
143151
// For now, use path style URLs (endpoint.com/bucket/key) instead of virtual-hosted style (bucket.endpoint.com/key) when
144152
// an alternate endpoint is specified; this is used/supported by most S3-compatible APIs
145153
clientOpts = append(clientOpts, func(o *s3.Options) { o.UsePathStyle = true })
154+
} else {
155+
// If we're using the default endpoint, enable UseARNRegion to allow finding buckets across regions
156+
clientOpts = append(clientOpts, func(o *s3.Options) { o.UseARNRegion = true })
146157
}
147158

148159
s3cfg, err := s3config.LoadDefaultConfig(ctx, cfgOpts...)

pkg/lib/filesystem/unpack/core.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,20 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str
176176

177177
case mediatype.DatasetBaseType:
178178
// Since some datasets may be remote, we need to search the Kitfile for the next non-remote dataset
179-
var entry artifact.DataSet
179+
var entry *artifact.DataSet
180180
for idx := datasetIdx; idx < len(config.DataSets); idx++ {
181181
dataset := config.DataSets[idx]
182182
if dataset.RemotePath != "" {
183183
continue
184184
}
185-
entry = dataset
185+
entry = &dataset
186186
datasetIdx = idx + 1
187187
break
188188
}
189-
if !shouldUnpackLayer(entry, opts.FilterConfs) {
189+
if entry == nil {
190+
continue
191+
}
192+
if !shouldUnpackLayer(*entry, opts.FilterConfs) {
190193
continue
191194
}
192195
layerInfo, layerPath = entry.LayerInfo, entry.Path
@@ -249,15 +252,30 @@ func unpackRecursive(ctx context.Context, opts *UnpackOptions, visitedRefs []str
249252
return err
250253
}
251254
for path, s3Ref := range remoteFiles {
255+
_, relPath, err := filesystem.VerifySubpath(opts.UnpackDir, path)
256+
if err != nil {
257+
return fmt.Errorf("error verifying path %s for remote reference: %w", path, err)
258+
}
259+
252260
output.Debugf("Downloading remote dataset: Bucket: %s, Key: %s", s3Ref.Bucket, s3Ref.Key)
253-
if _, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) && !opts.Overwrite {
254-
return fmt.Errorf("failed to unpack remote dataset: path '%s' already exists", path)
261+
if fi, exists := filesystem.PathExists(relPath); exists {
262+
if opts.IgnoreExisting {
263+
output.Debugf("File %s already exists; skipping", path)
264+
continue
265+
}
266+
if !opts.Overwrite {
267+
return fmt.Errorf("failed to unpack remote dataset: path '%s' already exists", path)
268+
}
269+
if !fi.Mode().IsRegular() {
270+
return fmt.Errorf("failed to unpack remote dataset: path '%s' already exists and is not a regular file", path)
271+
}
255272
}
256-
pathDir := filepath.Dir(path)
273+
274+
pathDir := filepath.Dir(relPath)
257275
if err := os.MkdirAll(pathDir, 0755); err != nil {
258276
return fmt.Errorf("failed to create directory %s: %w", pathDir, err)
259277
}
260-
if err := s3api.DownloadObject(ctx, client, &s3Ref, path); err != nil {
278+
if err := s3api.DownloadObject(ctx, client, &s3Ref, relPath); err != nil {
261279
return fmt.Errorf("failed to download remote dataset for path %s: %w", path, err)
262280
}
263281
output.Infof("Downloaded remote S3 dataset for path %s", path)

0 commit comments

Comments
 (0)