Skip to content

Commit 0e003a8

Browse files
authored
Improve export download file handling (#350)
1 parent 2d2090d commit 0e003a8

2 files changed

Lines changed: 239 additions & 4 deletions

File tree

internal/util/download.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import (
2121
"net/http"
2222
"os"
2323
"path/filepath"
24+
"runtime"
25+
"strings"
26+
"unicode"
2427

2528
"github.com/go-resty/resty/v2"
2629
)
@@ -60,17 +63,74 @@ func GetResponse(url string, debug bool) (*http.Response, error) {
6063

6164
// CreateFile creates a file if it does not exist
6265
func CreateFile(path, fileName string) (*os.File, error) {
63-
filePath := filepath.Join(path, fileName)
64-
if _, err := os.Stat(filePath); err == nil {
65-
return nil, fmt.Errorf("file already exists")
66+
filePath, err := safeDownloadPath(fileName)
67+
if err != nil {
68+
return nil, err
6669
}
67-
file, err := os.Create(filePath)
70+
71+
if path == "" {
72+
path = "."
73+
}
74+
root, err := os.OpenRoot(path)
6875
if err != nil {
6976
return nil, err
7077
}
78+
defer root.Close()
79+
80+
file, err := root.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
81+
if err != nil {
82+
if os.IsExist(err) {
83+
return nil, fmt.Errorf("file already exists")
84+
}
85+
return nil, fmt.Errorf("download destination %q cannot be created: %w", fileName, err)
86+
}
7187
return file, nil
7288
}
7389

90+
func safeDownloadPath(fileName string) (string, error) {
91+
if fileName == "" {
92+
return "", downloadDestinationError(fileName, "cannot be empty")
93+
}
94+
if fileName == "." || fileName == ".." {
95+
return "", downloadDestinationError(fileName, "must refer to a file")
96+
}
97+
if filepath.IsAbs(fileName) || isWindowsAbs(fileName) {
98+
return "", downloadDestinationError(fileName, "must be relative to the output path")
99+
}
100+
if containsControlCharacter(fileName) {
101+
return "", downloadDestinationError(fileName, "contains unsupported characters")
102+
}
103+
104+
clean := filepath.Clean(fileName)
105+
if clean == "." || clean == ".." {
106+
return "", downloadDestinationError(fileName, "must refer to a file")
107+
}
108+
if strings.HasPrefix(clean, ".."+string(os.PathSeparator)) || filepath.IsAbs(clean) {
109+
return "", downloadDestinationError(fileName, "is outside the output path")
110+
}
111+
112+
return clean, nil
113+
}
114+
115+
func downloadDestinationError(fileName, reason string) error {
116+
return fmt.Errorf("download destination %q %s", fileName, reason)
117+
}
118+
119+
func containsControlCharacter(s string) bool {
120+
return strings.ContainsRune(s, 0) || strings.IndexFunc(s, unicode.IsControl) >= 0
121+
}
122+
123+
func isWindowsAbs(path string) bool {
124+
if runtime.GOOS == "windows" {
125+
return false
126+
}
127+
if len(path) >= 3 && path[1] == ':' && (path[2] == '\\' || path[2] == '/') {
128+
c := path[0]
129+
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')
130+
}
131+
return strings.HasPrefix(path, `\\`)
132+
}
133+
74134
// CreateFolder creates a folder if it does not exist
75135
func CreateFolder(path string) error {
76136
if path == "" {

internal/util/download_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
// Copyright 2026 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package util
16+
17+
import (
18+
"os"
19+
"path/filepath"
20+
"runtime"
21+
"testing"
22+
)
23+
24+
func TestCreateFileAllowsRelativeDestinations(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
fileName string
28+
}{
29+
{
30+
name: "normal file name",
31+
fileName: "a.sql.gz",
32+
},
33+
{
34+
name: "subdirectory",
35+
fileName: filepath.Join("folder", "a.sql.gz"),
36+
},
37+
{
38+
name: "nested subdirectory",
39+
fileName: filepath.Join("a", "b", "c.sql.gz"),
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(tt.name, func(t *testing.T) {
45+
base := t.TempDir()
46+
parent := filepath.Dir(filepath.Join(base, tt.fileName))
47+
if err := os.MkdirAll(parent, 0755); err != nil {
48+
t.Fatalf("MkdirAll() error = %v", err)
49+
}
50+
51+
file, err := CreateFile(base, tt.fileName)
52+
if err != nil {
53+
t.Fatalf("CreateFile() error = %v", err)
54+
}
55+
file.Close()
56+
57+
if _, err := os.Stat(filepath.Join(base, tt.fileName)); err != nil {
58+
t.Fatalf("expected file inside base: %v", err)
59+
}
60+
})
61+
}
62+
}
63+
64+
func TestCreateFileRejectsUnsupportedDestinations(t *testing.T) {
65+
base := t.TempDir()
66+
67+
tests := []struct {
68+
name string
69+
fileName string
70+
}{
71+
{
72+
name: "path outside output path",
73+
fileName: filepath.Join("..", "..", "tmp", "pwned"),
74+
},
75+
{
76+
name: "absolute path",
77+
fileName: filepath.Join(string(os.PathSeparator), "tmp", "pwned"),
78+
},
79+
{
80+
name: "empty file name",
81+
fileName: "",
82+
},
83+
{
84+
name: "current directory",
85+
fileName: ".",
86+
},
87+
{
88+
name: "parent directory",
89+
fileName: "..",
90+
},
91+
{
92+
name: "cleans to current directory",
93+
fileName: filepath.Join("a", ".."),
94+
},
95+
}
96+
97+
for _, tt := range tests {
98+
t.Run(tt.name, func(t *testing.T) {
99+
file, err := CreateFile(base, tt.fileName)
100+
if err == nil {
101+
file.Close()
102+
t.Fatalf("CreateFile() succeeded for %q", tt.fileName)
103+
}
104+
})
105+
}
106+
}
107+
108+
func TestCreateFileDoesNotOverwriteExistingDestination(t *testing.T) {
109+
base := t.TempDir()
110+
path := filepath.Join(base, "a.sql.gz")
111+
if err := os.WriteFile(path, []byte("existing"), 0644); err != nil {
112+
t.Fatalf("WriteFile() error = %v", err)
113+
}
114+
115+
file, err := CreateFile(base, "a.sql.gz")
116+
if err == nil {
117+
file.Close()
118+
t.Fatalf("CreateFile() succeeded for existing destination")
119+
}
120+
121+
content, err := os.ReadFile(path)
122+
if err != nil {
123+
t.Fatalf("ReadFile() error = %v", err)
124+
}
125+
if string(content) != "existing" {
126+
t.Fatalf("existing destination was overwritten: %q", string(content))
127+
}
128+
}
129+
130+
func TestCreateFileUsesCurrentDirectoryWhenBaseIsEmpty(t *testing.T) {
131+
wd, err := os.Getwd()
132+
if err != nil {
133+
t.Fatalf("Getwd() error = %v", err)
134+
}
135+
base := t.TempDir()
136+
if err := os.Chdir(base); err != nil {
137+
t.Fatalf("Chdir() error = %v", err)
138+
}
139+
t.Cleanup(func() {
140+
if err := os.Chdir(wd); err != nil {
141+
t.Fatalf("restore working directory: %v", err)
142+
}
143+
})
144+
145+
file, err := CreateFile("", "a.sql.gz")
146+
if err != nil {
147+
t.Fatalf("CreateFile() error = %v", err)
148+
}
149+
file.Close()
150+
151+
if _, err := os.Stat(filepath.Join(base, "a.sql.gz")); err != nil {
152+
t.Fatalf("expected file in current directory: %v", err)
153+
}
154+
}
155+
156+
func TestCreateFileRejectsSymlinkedParentOutsideBase(t *testing.T) {
157+
if runtime.GOOS == "windows" {
158+
t.Skip("symlink behavior requires additional privileges on Windows")
159+
}
160+
161+
base := t.TempDir()
162+
outside := t.TempDir()
163+
if err := os.Symlink(outside, filepath.Join(base, "link")); err != nil {
164+
t.Fatalf("Symlink() error = %v", err)
165+
}
166+
167+
file, err := CreateFile(base, filepath.Join("link", "export.sql.gz"))
168+
if err == nil {
169+
file.Close()
170+
t.Fatalf("CreateFile() succeeded through symlinked parent")
171+
}
172+
if _, statErr := os.Stat(filepath.Join(outside, "export.sql.gz")); !os.IsNotExist(statErr) {
173+
t.Fatalf("expected no file outside base, stat error = %v", statErr)
174+
}
175+
}

0 commit comments

Comments
 (0)