Skip to content

Commit ec40335

Browse files
committed
Minor.
1 parent 7044bd4 commit ec40335

4 files changed

Lines changed: 94 additions & 2 deletions

File tree

keystore/atomicfile/write.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package atomicfile
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"os"
7+
"path/filepath"
8+
)
9+
10+
// WriteFile atomically writes the contents of r to the specified filepath with the default permission 0600.
11+
// This is a copy of https://github.com/natefinch/atomic/blob/master/atomic.go with minor modifications allowing
12+
// to set mode of the written file. If the file already exists, its mode is preserved.
13+
func WriteFile(filename string, r io.Reader, mode os.FileMode) (err error) {
14+
// write to a temp file first, then we'll atomically replace the target file
15+
// with the temp file.
16+
dir, file := filepath.Split(filename)
17+
if dir == "" {
18+
dir = "."
19+
}
20+
21+
f, err := os.CreateTemp(dir, file)
22+
if err != nil {
23+
return fmt.Errorf("cannot create temp file: %v", err)
24+
}
25+
defer func() {
26+
if err != nil {
27+
// Don't leave the temp file lying around on error.
28+
_ = os.Remove(f.Name()) // yes, ignore the error, not much we can do about it.
29+
}
30+
}()
31+
// ensure we always close f. Note that this does not conflict with the close below, as close is idempotent.
32+
defer f.Close()
33+
name := f.Name()
34+
if _, err := io.Copy(f, r); err != nil {
35+
return fmt.Errorf("cannot write data to tempfile %q: %v", name, err)
36+
}
37+
// fsync is important, otherwise os.Rename could rename a zero-length file
38+
if err := f.Sync(); err != nil {
39+
return fmt.Errorf("can't flush tempfile %q: %v", name, err)
40+
}
41+
if err := f.Close(); err != nil {
42+
return fmt.Errorf("can't close tempfile %q: %v", name, err)
43+
}
44+
45+
// get the file mode from the original file and use that for the replacement file, too.
46+
destInfo, err := os.Stat(filename)
47+
if os.IsNotExist(err) {
48+
// no original file
49+
if err := os.Chmod(name, mode); err != nil {
50+
return fmt.Errorf("can't set filemode on tempfile %q: %v", name, err)
51+
}
52+
} else if err != nil {
53+
return err
54+
} else {
55+
sourceInfo, err := os.Stat(name)
56+
if err != nil {
57+
return err
58+
}
59+
60+
if sourceInfo.Mode() != destInfo.Mode() {
61+
if err := os.Chmod(name, destInfo.Mode()); err != nil {
62+
return fmt.Errorf("can't set filemode on tempfile %q: %v", name, err)
63+
}
64+
}
65+
}
66+
if err := os.Rename(name, filename); err != nil {
67+
return fmt.Errorf("cannot replace %q with tempfile %q: %v", filename, name, err)
68+
}
69+
return nil
70+
}

keystore/atomicfile/write_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package atomicfile
2+
3+
import (
4+
"bytes"
5+
"os"
6+
"path/filepath"
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestWriteFile_WriteAndRead(t *testing.T) {
13+
path := filepath.Join(t.TempDir(), "out.txt")
14+
data := []byte("test")
15+
err := WriteFile(path, bytes.NewReader(data), 0600)
16+
require.NoError(t, err)
17+
readData, err := os.ReadFile(path)
18+
require.NoError(t, err)
19+
require.Equal(t, readData, data)
20+
}

keystore/file.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"context"
66
"os"
77

8-
"github.com/natefinch/atomic"
8+
"github.com/smartcontractkit/chainlink-common/keystore/atomicfile"
99
)
1010

1111
var _ Storage = &FileStorage{}
@@ -26,5 +26,5 @@ func (f *FileStorage) GetEncryptedKeystore(ctx context.Context) ([]byte, error)
2626
}
2727

2828
func (f *FileStorage) PutEncryptedKeystore(ctx context.Context, encryptedKeystore []byte) error {
29-
return atomic.WriteFile(f.name, bytes.NewReader(encryptedKeystore))
29+
return atomicfile.WriteFile(f.name, bytes.NewReader(encryptedKeystore), 0600)
3030
}

keystore/internal/raw_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ func TestRaw_nonprintable(t *testing.T) {
2323

2424
assert.Equal(t, exp, fmt.Sprintf("%#v", r))
2525

26+
assert.Equal(t, exp, fmt.Sprintf("%x", r))
27+
2628
assert.Equal(t, exp, fmt.Sprintf("%s", r)) //nolint:gosimple // S1025 deliberately testing formatting verbs
2729

2830
got, err := json.Marshal(r) //nolint:staticcheck // SA9005 deliberately testing marshalling

0 commit comments

Comments
 (0)