Skip to content

Commit 2517e24

Browse files
committed
Fix bugs related to overwriting in install
1 parent ae5d11f commit 2517e24

2 files changed

Lines changed: 119 additions & 4 deletions

File tree

internal/cli/commands/install.go

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,21 @@ func InstallPackages(packageFiles []string, packagesDir string, overwriteExistin
179179
return reporter.Error("IO_ERROR", err.Error(), nil, err)
180180
}
181181

182+
promotedDirs, err := promoteOverwrittenPackageDirs(packagesDir, packages)
183+
if err != nil {
184+
rollbackPromotedPackageDirs(promotedDirs)
185+
rollbackInstallDirs(createdDirs)
186+
return reporter.Error("IO_ERROR", err.Error(), nil, err)
187+
}
188+
182189
if err := db.Exec("COMMIT").Error; err != nil {
190+
rollbackPromotedPackageDirs(promotedDirs)
183191
rollbackInstallDirs(createdDirs)
184192
return reporter.Error("IO_ERROR", fmt.Sprintf("commit install transaction: %v", err), nil, err)
185193
}
186194
committed = true
187195

188-
removeOverwrittenPackageDirs(packagesDir, packages)
196+
cleanupPromotedPackageDirBackups(promotedDirs)
189197
installed, overwritten, skipped := 0, 0, 0
190198
for _, pkg := range packages {
191199
switch pkg.Action {
@@ -494,6 +502,15 @@ func ensureInstallDirectoriesAvailable(packagesDir string, packages []installPac
494502
}
495503
destination := installedPackageDir(packagesDir, pkg.Inspection.ID, pkg.Inspection.Version.String())
496504
if _, err := os.Stat(destination); err == nil {
505+
if pkg.Action == installActionOverwrite {
506+
staging := stagedInstallPackageDir(packagesDir, pkg)
507+
if _, err := os.Stat(staging); err == nil {
508+
return fmt.Errorf("package staging directory already exists: %s", staging)
509+
} else if !os.IsNotExist(err) {
510+
return fmt.Errorf("stat package staging directory %s: %w", staging, err)
511+
}
512+
continue
513+
}
497514
return fmt.Errorf("package directory already exists: %s", destination)
498515
} else if !os.IsNotExist(err) {
499516
return fmt.Errorf("stat package directory %s: %w", destination, err)
@@ -512,7 +529,7 @@ func extractInstallPackages(ctx context.Context, packagesDir string, packages []
512529
if pkg.Action == installActionSkip {
513530
continue
514531
}
515-
destination := installedPackageDir(packagesDir, pkg.Inspection.ID, pkg.Inspection.Version.String())
532+
destination := installExtractDestination(packagesDir, pkg)
516533
mu.Lock()
517534
createdDirs = append(createdDirs, destination)
518535
mu.Unlock()
@@ -537,6 +554,21 @@ func extractInstallPackages(ctx context.Context, packagesDir string, packages []
537554
return createdDirs, nil
538555
}
539556

557+
func installExtractDestination(packagesDir string, pkg installPackage) string {
558+
if pkg.Action == installActionOverwrite {
559+
return stagedInstallPackageDir(packagesDir, pkg)
560+
}
561+
return installedPackageDir(packagesDir, pkg.Inspection.ID, pkg.Inspection.Version.String())
562+
}
563+
564+
func stagedInstallPackageDir(packagesDir string, pkg installPackage) string {
565+
hashPrefix := pkg.Hash
566+
if len(hashPrefix) > 16 {
567+
hashPrefix = hashPrefix[:16]
568+
}
569+
return filepath.Join(packagesDir, "."+installedPackageDirName(pkg.Inspection.ID, pkg.Inspection.Version.String())+".new-"+hashPrefix)
570+
}
571+
540572
func writeInstalledPackages(db *gorm.DB, packages []installPackage) error {
541573
now := time.Now().UnixMilli()
542574
for _, pkg := range packages {
@@ -795,13 +827,60 @@ func rollbackInstallDirs(dirs []string) {
795827
}
796828
}
797829

798-
func removeOverwrittenPackageDirs(packagesDir string, packages []installPackage) {
830+
type promotedPackageDir struct {
831+
Final string
832+
Backup string
833+
}
834+
835+
func promoteOverwrittenPackageDirs(packagesDir string, packages []installPackage) ([]promotedPackageDir, error) {
836+
promoted := make([]promotedPackageDir, 0, len(packages))
799837
for _, pkg := range packages {
800838
if pkg.Action != installActionOverwrite || pkg.Existing == nil || pkg.Existing.Hash == "" || pkg.Existing.Hash == pkg.Hash {
801839
continue
802840
}
803-
_ = os.RemoveAll(installedPackageDir(packagesDir, pkg.Existing.ID, pkg.Existing.Version))
841+
final := installedPackageDir(packagesDir, pkg.Existing.ID, pkg.Existing.Version)
842+
staging := stagedInstallPackageDir(packagesDir, pkg)
843+
backup := backupInstallPackageDir(packagesDir, pkg)
844+
if _, err := os.Stat(backup); err == nil {
845+
return promoted, fmt.Errorf("package backup directory already exists: %s", backup)
846+
} else if !os.IsNotExist(err) {
847+
return promoted, fmt.Errorf("stat package backup directory %s: %w", backup, err)
848+
}
849+
if err := os.Rename(final, backup); err != nil {
850+
return promoted, fmt.Errorf("backup existing package directory %s: %w", final, err)
851+
}
852+
swap := promotedPackageDir{Final: final, Backup: backup}
853+
promoted = append(promoted, swap)
854+
if err := os.Rename(staging, final); err != nil {
855+
if restoreErr := os.Rename(backup, final); restoreErr != nil {
856+
return promoted, fmt.Errorf("promote package directory %s: %w; restore backup: %v", final, err, restoreErr)
857+
}
858+
promoted = promoted[:len(promoted)-1]
859+
return promoted, fmt.Errorf("promote package directory %s: %w", final, err)
860+
}
861+
}
862+
return promoted, nil
863+
}
864+
865+
func rollbackPromotedPackageDirs(promoted []promotedPackageDir) {
866+
for i := len(promoted) - 1; i >= 0; i-- {
867+
_ = os.RemoveAll(promoted[i].Final)
868+
_ = os.Rename(promoted[i].Backup, promoted[i].Final)
869+
}
870+
}
871+
872+
func cleanupPromotedPackageDirBackups(promoted []promotedPackageDir) {
873+
for _, dir := range promoted {
874+
_ = os.RemoveAll(dir.Backup)
875+
}
876+
}
877+
878+
func backupInstallPackageDir(packagesDir string, pkg installPackage) string {
879+
hashPrefix := pkg.Hash
880+
if len(hashPrefix) > 16 {
881+
hashPrefix = hashPrefix[:16]
804882
}
883+
return filepath.Join(packagesDir, "."+installedPackageDirName(pkg.Inspection.ID, pkg.Inspection.Version.String())+".old-"+hashPrefix)
805884
}
806885

807886
func installPackageKey(id string, version string) string {

internal/cli/commands/install_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,42 @@ func TestInstallPackagesDryRunDoesNotExtractOrWriteDatabase(t *testing.T) {
7878
}
7979
}
8080

81+
func TestInstallPackagesOverwriteExistingReplacesDirectory(t *testing.T) {
82+
packagesDir := t.TempDir()
83+
tempDir := t.TempDir()
84+
first := filepath.Join(tempDir, "first.dspk")
85+
second := filepath.Join(tempDir, "second.dspk")
86+
if err := os.WriteFile(first, makeInstallTestArchive(t, "vendor/simple", "1.0", map[string]string{"old.txt": "old"}), 0o644); err != nil {
87+
t.Fatalf("write first package file: %v", err)
88+
}
89+
if err := os.WriteFile(second, makeInstallTestArchive(t, "vendor/simple", "1.0", map[string]string{"new.txt": "new"}), 0o644); err != nil {
90+
t.Fatalf("write second package file: %v", err)
91+
}
92+
93+
var output bytes.Buffer
94+
if err := InstallPackages([]string{first}, packagesDir, false, false, false, &output); err != nil {
95+
t.Fatalf("InstallPackages(first) error = %v\n%s", err, output.String())
96+
}
97+
output.Reset()
98+
if err := InstallPackages([]string{second}, packagesDir, true, false, false, &output); err != nil {
99+
t.Fatalf("InstallPackages(second) error = %v\n%s", err, output.String())
100+
}
101+
102+
packageDir := installedPackageDir(packagesDir, "vendor/simple", "1.0.0.0")
103+
if _, err := os.Stat(filepath.Join(packageDir, "new.txt")); err != nil {
104+
t.Fatalf("new package file was not installed: %v", err)
105+
}
106+
if _, err := os.Stat(filepath.Join(packageDir, "old.txt")); !os.IsNotExist(err) {
107+
t.Fatalf("old package file still exists, err = %v", err)
108+
}
109+
if strings.Contains(output.String(), "package directory already exists") {
110+
t.Fatalf("overwrite output reported existing directory:\n%s", output.String())
111+
}
112+
if !strings.Contains(output.String(), "overwritten") {
113+
t.Fatalf("overwrite output missing overwritten result:\n%s", output.String())
114+
}
115+
}
116+
81117
func TestInstallPackagesRejectsDuplicatePackageIdentityInSameCommand(t *testing.T) {
82118
packagesDir := t.TempDir()
83119
tempDir := t.TempDir()

0 commit comments

Comments
 (0)