Skip to content

Commit 511c5dc

Browse files
committed
simplify MkdirTmp usage
Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
1 parent 10d040d commit 511c5dc

3 files changed

Lines changed: 97 additions & 10 deletions

File tree

pkg/archive/archive.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,13 @@ func extractNestedArchive(ctx context.Context, c malcontent.Config, d string, f
164164
// Some packages may have archives and files with colliding names
165165
// e.g., demo_page.css and demo_page.css.gz
166166
// the former is the uncompressed version of the latter
167-
// if we encounter this, replace the name with something that won't collide
167+
// if we encounter this, use os.MkdirTemp to create a unique directory
168168
if _, err := os.Stat(archivePath); err == nil {
169169
logger.Debugf("duplicate file name already exists, modifying directory name for %s", archivePath)
170-
var err error
171-
archivePath, err = os.MkdirTemp(d, strings.TrimSuffix(f, programkind.GetExt(f))+"_*")
172-
if err != nil {
173-
return fmt.Errorf("failed to create unique extraction directory: %w", err)
170+
var mkErr error
171+
archivePath, mkErr = os.MkdirTemp(filepath.Dir(archivePath), filepath.Base(archivePath)+"_*")
172+
if mkErr != nil {
173+
return fmt.Errorf("failed to create unique extraction directory: %w", mkErr)
174174
}
175175
} else if err := os.MkdirAll(archivePath, 0o700); err != nil {
176176
return fmt.Errorf("failed to create extraction directory: %w", err)

pkg/archive/symlink_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ import (
77
"context"
88
"os"
99
"path/filepath"
10+
"sync"
1011
"testing"
12+
13+
"github.com/chainguard-dev/clog"
14+
"github.com/chainguard-dev/malcontent/pkg/malcontent"
1115
)
1216

1317
func TestSymlinkExtraction(t *testing.T) {
@@ -109,6 +113,88 @@ func TestValidateResolvedPath(t *testing.T) {
109113
}
110114
}
111115

116+
// TestExtractNestedArchiveWithSubdirectory verifies that extractNestedArchive
117+
// handles archives located in subdirectories (where the relative path contains
118+
// path separators). This is a regression test for a bug where os.MkdirTemp
119+
// was called with a pattern containing path separators, which is not allowed.
120+
func TestExtractNestedArchiveWithSubdirectory(t *testing.T) {
121+
t.Parallel()
122+
123+
tmpDir, err := os.MkdirTemp("", "nested-archive-test-*")
124+
if err != nil {
125+
t.Fatalf("failed to create temp dir: %v", err)
126+
}
127+
defer os.RemoveAll(tmpDir)
128+
129+
// Create a nested archive inside a subdirectory, simulating what happens
130+
// when an archive contains another archive at a path like "subdir/inner.tar.gz"
131+
subDir := filepath.Join(tmpDir, "subdir")
132+
if err := os.MkdirAll(subDir, 0o700); err != nil {
133+
t.Fatalf("failed to create subdir: %v", err)
134+
}
135+
136+
// Copy a real .gz test file into the subdirectory
137+
srcData, err := os.ReadFile("../../pkg/action/testdata/apko.gz")
138+
if err != nil {
139+
t.Fatalf("failed to read test archive: %v", err)
140+
}
141+
nestedArchive := filepath.Join(subDir, "apko.gz")
142+
if err := os.WriteFile(nestedArchive, srcData, 0o600); err != nil {
143+
t.Fatalf("failed to write nested archive: %v", err)
144+
}
145+
146+
ctx := context.Background()
147+
logger := clog.FromContext(ctx)
148+
cfg := malcontent.Config{}
149+
var extracted sync.Map
150+
151+
// This is the call that previously failed with "pattern contains path separator"
152+
err = extractNestedArchive(ctx, cfg, tmpDir, "subdir/apko.gz", &extracted, logger, 1)
153+
if err != nil {
154+
t.Fatalf("extractNestedArchive failed: %v", err)
155+
}
156+
}
157+
158+
// TestExtractNestedArchiveCollision verifies that extractNestedArchive handles
159+
// name collisions by falling back to os.MkdirTemp when the deterministic path
160+
// already exists.
161+
func TestExtractNestedArchiveCollision(t *testing.T) {
162+
t.Parallel()
163+
164+
tmpDir, err := os.MkdirTemp("", "nested-collision-test-*")
165+
if err != nil {
166+
t.Fatalf("failed to create temp dir: %v", err)
167+
}
168+
defer os.RemoveAll(tmpDir)
169+
170+
// Create a file that will collide with the extraction directory name.
171+
// When extracting "apko.gz", the extraction dir would be "apko" — create
172+
// that as a file first to force the collision path.
173+
collisionPath := filepath.Join(tmpDir, "apko")
174+
if err := os.WriteFile(collisionPath, []byte("existing"), 0o600); err != nil {
175+
t.Fatalf("failed to create collision file: %v", err)
176+
}
177+
178+
srcData, err := os.ReadFile("../../pkg/action/testdata/apko.gz")
179+
if err != nil {
180+
t.Fatalf("failed to read test archive: %v", err)
181+
}
182+
archivePath := filepath.Join(tmpDir, "apko.gz")
183+
if err := os.WriteFile(archivePath, srcData, 0o600); err != nil {
184+
t.Fatalf("failed to write archive: %v", err)
185+
}
186+
187+
ctx := context.Background()
188+
logger := clog.FromContext(ctx)
189+
cfg := malcontent.Config{}
190+
var extracted sync.Map
191+
192+
err = extractNestedArchive(ctx, cfg, tmpDir, "apko.gz", &extracted, logger, 1)
193+
if err != nil {
194+
t.Fatalf("extractNestedArchive with collision failed: %v", err)
195+
}
196+
}
197+
112198
func TestHandleSymlink(t *testing.T) {
113199
t.Parallel()
114200
tmpDir, err := os.MkdirTemp("", "symlink-test-*")

pkg/archive/zip.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ func extractFile(ctx context.Context, zf *zip.File, destDir string, logger *clog
114114
// this case insensitivity will break scans, so rename files that collide with existing directories
115115
if runtime.GOOS == "darwin" {
116116
if _, err := os.Stat(filepath.Join(destDir, zf.Name)); err == nil {
117-
tmpFile, err := os.CreateTemp(destDir, filepath.Base(zf.Name)+"_*")
118-
if err == nil {
119-
zf.Name = filepath.Base(tmpFile.Name())
120-
tmpFile.Close()
121-
os.Remove(tmpFile.Name())
117+
uniqueDir, mkErr := os.MkdirTemp(filepath.Join(destDir, filepath.Dir(zf.Name)), filepath.Base(zf.Name)+"_*")
118+
if mkErr == nil {
119+
rel, relErr := filepath.Rel(destDir, uniqueDir)
120+
if relErr == nil {
121+
zf.Name = rel
122+
}
122123
}
123124
}
124125
}

0 commit comments

Comments
 (0)