Skip to content

Commit 561c502

Browse files
authored
fix(cli): harden okdev cp downloads (#87)
1 parent 0b1ed63 commit 561c502

6 files changed

Lines changed: 339 additions & 21 deletions

File tree

internal/cli/cp.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,20 @@ func multiPodDownloadPath(localPath, shortName, remotePath string, remoteIsDir b
185185
return filepath.Join(podDir, filepath.Base(remotePath))
186186
}
187187

188+
func downloadTargetPath(localPath, shortName, remotePath string, remoteIsDir bool, podCount int) string {
189+
if podCount == 1 {
190+
return localPath
191+
}
192+
return multiPodDownloadPath(localPath, shortName, remotePath, remoteIsDir)
193+
}
194+
195+
func downloadSuccessDestination(localPath, shortName string, podCount int) string {
196+
if podCount == 1 {
197+
return localPath
198+
}
199+
return filepath.Join(localPath, shortName) + string(os.PathSeparator)
200+
}
201+
188202
type cpResult struct {
189203
pod string
190204
err error
@@ -263,6 +277,7 @@ func runMultiPodCp(cmd *cobra.Command, cc *commandContext, localPath, remotePath
263277
out := cmd.OutOrStdout()
264278
results := make(chan cpResult, len(pods))
265279
sem := make(chan struct{}, effectiveFanout)
280+
podCount := len(pods)
266281

267282
var wg sync.WaitGroup
268283
for _, pod := range pods {
@@ -277,17 +292,12 @@ func runMultiPodCp(cmd *cobra.Command, cc *commandContext, localPath, remotePath
277292
if upload {
278293
cpErr = runSinglePodCp(ctx, cc.kube, cc.namespace, pod.Name, targetContainer, localPath, remotePath, true, io.Discard)
279294
} else {
280-
podDir := filepath.Join(localPath, short)
281-
if err := os.MkdirAll(podDir, 0o755); err != nil {
282-
results <- cpResult{pod: pod.Name, err: err}
283-
return
284-
}
285295
isRemoteDir, err := cc.kube.IsRemoteDir(ctx, cc.namespace, pod.Name, targetContainer, remotePath)
286296
if err != nil {
287297
results <- cpResult{pod: pod.Name, err: err}
288298
return
289299
}
290-
podLocalPath := multiPodDownloadPath(localPath, short, remotePath, isRemoteDir)
300+
podLocalPath := downloadTargetPath(localPath, short, remotePath, isRemoteDir, podCount)
291301
cpErr = runSinglePodCp(ctx, cc.kube, cc.namespace, pod.Name, targetContainer, podLocalPath, remotePath, false, io.Discard)
292302
}
293303
results <- cpResult{pod: pod.Name, err: cpErr}
@@ -309,7 +319,7 @@ func runMultiPodCp(cmd *cobra.Command, cc *commandContext, localPath, remotePath
309319
if upload {
310320
fmt.Fprintf(out, "%s: copied %s -> :%s\n", short, localPath, remotePath)
311321
} else {
312-
fmt.Fprintf(out, "%s: copied :%s -> %s/%s/\n", short, remotePath, localPath, short)
322+
fmt.Fprintf(out, "%s: copied :%s -> %s\n", short, remotePath, downloadSuccessDestination(localPath, short, podCount))
313323
}
314324
}
315325
}

internal/cli/cp_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cli
22

33
import (
44
"fmt"
5+
"path/filepath"
56
"strings"
67
"testing"
78

@@ -99,6 +100,70 @@ func TestMultiPodDownloadPath(t *testing.T) {
99100
}
100101
}
101102

103+
func TestDownloadTargetPath(t *testing.T) {
104+
tests := []struct {
105+
name string
106+
localPath string
107+
shortName string
108+
remotePath string
109+
remoteIsDir bool
110+
podCount int
111+
want string
112+
}{
113+
{
114+
name: "single pod file stays flat",
115+
localPath: "/tmp/out.txt",
116+
shortName: "worker-0",
117+
remotePath: "/workspace/result.txt",
118+
podCount: 1,
119+
want: "/tmp/out.txt",
120+
},
121+
{
122+
name: "single pod directory stays flat",
123+
localPath: "/tmp/out",
124+
shortName: "worker-0",
125+
remotePath: "/workspace/results",
126+
remoteIsDir: true,
127+
podCount: 1,
128+
want: "/tmp/out",
129+
},
130+
{
131+
name: "multi pod file nests by pod",
132+
localPath: "/tmp/out",
133+
shortName: "worker-0",
134+
remotePath: "/workspace/result.txt",
135+
podCount: 2,
136+
want: filepath.Join("/tmp/out", "worker-0", "result.txt"),
137+
},
138+
{
139+
name: "multi pod directory nests by pod",
140+
localPath: "/tmp/out",
141+
shortName: "worker-0",
142+
remotePath: "/workspace/results",
143+
remoteIsDir: true,
144+
podCount: 2,
145+
want: filepath.Join("/tmp/out", "worker-0"),
146+
},
147+
}
148+
149+
for _, tt := range tests {
150+
t.Run(tt.name, func(t *testing.T) {
151+
if got := downloadTargetPath(tt.localPath, tt.shortName, tt.remotePath, tt.remoteIsDir, tt.podCount); got != tt.want {
152+
t.Fatalf("download target path = %q, want %q", got, tt.want)
153+
}
154+
})
155+
}
156+
}
157+
158+
func TestDownloadSuccessDestination(t *testing.T) {
159+
if got := downloadSuccessDestination("/tmp/out.txt", "worker-0", 1); got != "/tmp/out.txt" {
160+
t.Fatalf("single pod success destination = %q", got)
161+
}
162+
if got := downloadSuccessDestination("/tmp/out", "worker-0", 2); got != filepath.Join("/tmp/out", "worker-0")+string(filepath.Separator) {
163+
t.Fatalf("multi pod success destination = %q", got)
164+
}
165+
}
166+
102167
func TestCpReadinessCheckErrorsWhenPodsNotRunning(t *testing.T) {
103168
allPods := []kube.PodSummary{
104169
{Name: "sess-master-0", Phase: "Running", Labels: map[string]string{"okdev.io/workload-role": "Master"}},

internal/kube/client.go

Lines changed: 129 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,10 +1034,57 @@ func (c *Client) CopyFromPod(ctx context.Context, namespace, podName, remotePath
10341034
}
10351035

10361036
func (c *Client) CopyFromPodInContainer(ctx context.Context, namespace, podName, container, remotePath, localPath string) error {
1037+
var lastErr error
1038+
for attempt := 0; attempt < 3; attempt++ {
1039+
lastErr = c.copyFromPodInContainerOnce(ctx, namespace, podName, container, remotePath, localPath)
1040+
if lastErr == nil {
1041+
return nil
1042+
}
1043+
if !isRetryableCopyStreamError(lastErr) || attempt == 2 {
1044+
return lastErr
1045+
}
1046+
time.Sleep(time.Duration(attempt+1) * 500 * time.Millisecond)
1047+
}
1048+
return lastErr
1049+
}
1050+
1051+
func (c *Client) copyFromPodInContainerOnce(ctx context.Context, namespace, podName, container, remotePath, localPath string) error {
10371052
cs, cfg, err := c.clientset()
10381053
if err != nil {
10391054
return err
10401055
}
1056+
parent := filepath.Dir(remotePath)
1057+
base := filepath.Base(remotePath)
1058+
cmd := []string{"sh", "-lc", fmt.Sprintf("tar cf - -C %s %s", shellQuote(parent), shellQuote(base))}
1059+
tarFile, err := os.CreateTemp("", "okdev-copy-*.tar")
1060+
if err != nil {
1061+
return err
1062+
}
1063+
tarPath := tarFile.Name()
1064+
defer os.Remove(tarPath)
1065+
var errBuf bytes.Buffer
1066+
if err := c.execStream(ctx, cs, cfg, namespace, podName, container, cmd, nil, tarFile, &errBuf, false); err != nil {
1067+
_ = tarFile.Close()
1068+
return err
1069+
}
1070+
if _, err := tarFile.Seek(0, io.SeekStart); err != nil {
1071+
_ = tarFile.Close()
1072+
return err
1073+
}
1074+
if err := extractSingleFileFromTar(tarFile, localPath); err != nil {
1075+
_ = tarFile.Close()
1076+
return err
1077+
}
1078+
return tarFile.Close()
1079+
}
1080+
1081+
func createTempDownloadFile(localPath string) (*os.File, error) {
1082+
dir := filepath.Dir(localPath)
1083+
base := filepath.Base(localPath)
1084+
return os.CreateTemp(dir, base+".tmp-*")
1085+
}
1086+
1087+
func extractSingleFileFromTar(r io.Reader, localPath string) error {
10411088
if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil {
10421089
return err
10431090
}
@@ -1053,25 +1100,45 @@ func (c *Client) CopyFromPodInContainer(ctx context.Context, namespace, podName,
10531100
}
10541101
_ = os.Remove(tempPath)
10551102
}()
1056-
var errBuf bytes.Buffer
1057-
cmd := []string{"sh", "-lc", fmt.Sprintf("cat %s", shellQuote(remotePath))}
1058-
if err := c.execStream(ctx, cs, cfg, namespace, podName, container, cmd, nil, tempFile, &errBuf, false); err != nil {
1103+
1104+
tr := tar.NewReader(r)
1105+
foundFile := false
1106+
var fileMode os.FileMode = 0o644
1107+
for {
1108+
hdr, err := tr.Next()
1109+
if err == io.EOF {
1110+
break
1111+
}
1112+
if err != nil {
1113+
return err
1114+
}
1115+
switch hdr.Typeflag {
1116+
case tar.TypeDir, tar.TypeXHeader, tar.TypeXGlobalHeader, tar.TypeGNULongName, tar.TypeGNULongLink:
1117+
continue
1118+
case tar.TypeReg, tar.TypeRegA:
1119+
if foundFile {
1120+
return fmt.Errorf("tar archive contains multiple files")
1121+
}
1122+
if _, err := io.Copy(tempFile, tr); err != nil {
1123+
return err
1124+
}
1125+
fileMode = os.FileMode(hdr.Mode)
1126+
foundFile = true
1127+
default:
1128+
return fmt.Errorf("tar archive entry %q has unsupported type", hdr.Name)
1129+
}
1130+
}
1131+
if !foundFile {
1132+
return fmt.Errorf("tar archive contained no file")
1133+
}
1134+
if err := tempFile.Chmod(fileMode); err != nil {
10591135
return err
10601136
}
10611137
if err := tempFile.Close(); err != nil {
10621138
return err
10631139
}
10641140
closed = true
1065-
if err := os.Rename(tempPath, localPath); err != nil {
1066-
return err
1067-
}
1068-
return nil
1069-
}
1070-
1071-
func createTempDownloadFile(localPath string) (*os.File, error) {
1072-
dir := filepath.Dir(localPath)
1073-
base := filepath.Base(localPath)
1074-
return os.CreateTemp(dir, base+".tmp-*")
1141+
return os.Rename(tempPath, localPath)
10751142
}
10761143

10771144
func (c *Client) StreamFromPod(ctx context.Context, namespace, podName, script string, stdout io.Writer) error {
@@ -1185,6 +1252,21 @@ func (c *Client) CopyDirToPod(ctx context.Context, namespace, pod, container, lo
11851252

11861253
// CopyDirFromPod streams a remote directory as a tar archive and extracts it locally.
11871254
func (c *Client) CopyDirFromPod(ctx context.Context, namespace, pod, container, remoteDir, localDir string) error {
1255+
var lastErr error
1256+
for attempt := 0; attempt < 3; attempt++ {
1257+
lastErr = c.copyDirFromPodOnce(ctx, namespace, pod, container, remoteDir, localDir)
1258+
if lastErr == nil {
1259+
return nil
1260+
}
1261+
if !isRetryableCopyStreamError(lastErr) || attempt == 2 {
1262+
return lastErr
1263+
}
1264+
time.Sleep(time.Duration(attempt+1) * 500 * time.Millisecond)
1265+
}
1266+
return lastErr
1267+
}
1268+
1269+
func (c *Client) copyDirFromPodOnce(ctx context.Context, namespace, pod, container, remoteDir, localDir string) error {
11881270
cs, cfg, err := c.clientset()
11891271
if err != nil {
11901272
return err
@@ -1215,6 +1297,40 @@ func (c *Client) CopyDirFromPod(ctx context.Context, namespace, pod, container,
12151297
return tarFile.Close()
12161298
}
12171299

1300+
func isRetryableCopyStreamError(err error) bool {
1301+
if err == nil {
1302+
return false
1303+
}
1304+
msg := strings.ToLower(err.Error())
1305+
nonRetryable := []string{
1306+
"not found",
1307+
"no such file",
1308+
"permission denied",
1309+
"is a directory",
1310+
"tar archive contains multiple files",
1311+
"tar archive contained no file",
1312+
"unsupported type",
1313+
}
1314+
for _, s := range nonRetryable {
1315+
if strings.Contains(msg, s) {
1316+
return false
1317+
}
1318+
}
1319+
retryable := []string{
1320+
"unexpected eof",
1321+
"context deadline exceeded",
1322+
"unexpected error when reading response body",
1323+
"connection reset by peer",
1324+
"timeout",
1325+
}
1326+
for _, s := range retryable {
1327+
if strings.Contains(msg, s) {
1328+
return true
1329+
}
1330+
}
1331+
return false
1332+
}
1333+
12181334
// IsRemoteDir probes whether remotePath is a directory on the pod.
12191335
func (c *Client) IsRemoteDir(ctx context.Context, namespace, pod, container, remotePath string) (bool, error) {
12201336
cs, cfg, err := c.clientset()

internal/kube/client_copy_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,79 @@ func TestExtractTarToDir(t *testing.T) {
7676
t.Fatalf("d/g.txt: %v %q", err, data)
7777
}
7878
}
79+
80+
func TestExtractSingleFileFromTar(t *testing.T) {
81+
var buf bytes.Buffer
82+
tw := tar.NewWriter(&buf)
83+
if err := tw.WriteHeader(&tar.Header{Name: "f.txt", Mode: 0o640, Size: 3, Typeflag: tar.TypeReg}); err != nil {
84+
t.Fatal(err)
85+
}
86+
if _, err := tw.Write([]byte("abc")); err != nil {
87+
t.Fatal(err)
88+
}
89+
if err := tw.Close(); err != nil {
90+
t.Fatal(err)
91+
}
92+
93+
outPath := filepath.Join(t.TempDir(), "out.txt")
94+
if err := extractSingleFileFromTar(bytes.NewReader(buf.Bytes()), outPath); err != nil {
95+
t.Fatal(err)
96+
}
97+
98+
data, err := os.ReadFile(outPath)
99+
if err != nil {
100+
t.Fatal(err)
101+
}
102+
if string(data) != "abc" {
103+
t.Fatalf("unexpected file content %q", data)
104+
}
105+
info, err := os.Stat(outPath)
106+
if err != nil {
107+
t.Fatal(err)
108+
}
109+
if info.Mode().Perm() != 0o640 {
110+
t.Fatalf("unexpected file mode %o", info.Mode().Perm())
111+
}
112+
}
113+
114+
func TestExtractSingleFileFromTarRejectsMultipleFiles(t *testing.T) {
115+
var buf bytes.Buffer
116+
tw := tar.NewWriter(&buf)
117+
for _, name := range []string{"a.txt", "b.txt"} {
118+
if err := tw.WriteHeader(&tar.Header{Name: name, Mode: 0o644, Size: 1, Typeflag: tar.TypeReg}); err != nil {
119+
t.Fatal(err)
120+
}
121+
if _, err := tw.Write([]byte("x")); err != nil {
122+
t.Fatal(err)
123+
}
124+
}
125+
if err := tw.Close(); err != nil {
126+
t.Fatal(err)
127+
}
128+
129+
err := extractSingleFileFromTar(bytes.NewReader(buf.Bytes()), filepath.Join(t.TempDir(), "out.txt"))
130+
if err == nil || err.Error() != "tar archive contains multiple files" {
131+
t.Fatalf("expected multiple files error, got %v", err)
132+
}
133+
}
134+
135+
func TestExtractSingleFileFromTarRejectsTruncatedArchive(t *testing.T) {
136+
payload := bytes.Repeat([]byte("a"), 2048)
137+
var buf bytes.Buffer
138+
tw := tar.NewWriter(&buf)
139+
if err := tw.WriteHeader(&tar.Header{Name: "f.txt", Mode: 0o644, Size: int64(len(payload)), Typeflag: tar.TypeReg}); err != nil {
140+
t.Fatal(err)
141+
}
142+
if _, err := tw.Write(payload); err != nil {
143+
t.Fatal(err)
144+
}
145+
if err := tw.Close(); err != nil {
146+
t.Fatal(err)
147+
}
148+
149+
truncated := buf.Bytes()[:len(buf.Bytes())-1100]
150+
err := extractSingleFileFromTar(bytes.NewReader(truncated), filepath.Join(t.TempDir(), "out.txt"))
151+
if err == nil {
152+
t.Fatal("expected truncated tar extraction to fail")
153+
}
154+
}

0 commit comments

Comments
 (0)