diff --git a/model/permission/permissions.go b/model/permission/permissions.go index 5a5b1e324a8..86301ddaa97 100644 --- a/model/permission/permissions.go +++ b/model/permission/permissions.go @@ -3,19 +3,24 @@ package permission import ( + "context" "encoding/json" "fmt" "net/http" + "slices" "strings" "time" build "github.com/cozy/cozy-stack/pkg/config" + "github.com/cozy/cozy-stack/pkg/config/config" "github.com/cozy/cozy-stack/pkg/consts" "github.com/cozy/cozy-stack/pkg/couchdb" "github.com/cozy/cozy-stack/pkg/couchdb/mango" "github.com/cozy/cozy-stack/pkg/crypto" + "github.com/cozy/cozy-stack/pkg/logger" "github.com/cozy/cozy-stack/pkg/metadata" "github.com/cozy/cozy-stack/pkg/prefixer" + "github.com/cozy/cozy-stack/pkg/utils" "github.com/labstack/echo/v4" ) @@ -242,9 +247,206 @@ func GetForSharePreview(db prefixer.Prefixer, sharingID string) (*Permission, er } // GetForShareInteract retrieves the Permission doc for a given sharing to -// read/write a note +// read/write a note. It may repair legacy duplicate share-interact permission +// docs as part of the read by creating/updating the canonical permission doc +// and deleting duplicate legacy docs. func GetForShareInteract(db prefixer.Prefixer, sharingID string) (*Permission, error) { - return getFromSource(db, TypeShareInteract, consts.Sharings, sharingID) + return getOrRepairShareInteractPermissions(db, sharingID) +} + +// ShareInteractPermissionID returns the canonical permission document ID for a +// share-interact permission set. +func ShareInteractPermissionID(sharingID string) string { + return TypeShareInteract + "-" + sharingID +} + +func shareInteractLockName(sharingID string) string { + return "permissions/share-interact/" + sharingID +} + +func getShareInteractPermissions(db prefixer.Prefixer, sharingID string) ([]Permission, error) { + var res []Permission + req := couchdb.FindRequest{ + UseIndex: "by-source-and-type", + Selector: mango.And( + mango.Equal("type", TypeShareInteract), + mango.Equal("source_id", consts.Sharings+"/"+sharingID), + ), + Limit: 1000, + } + err := couchdb.FindDocs(db, consts.Permissions, &req, &res) + if err != nil { + // With a cluster of couchdb, we can have a race condition where we + // query an index before it has been updated for a doc that has just + // been created. Keep the same fallback as getFromSource. + time.Sleep(1 * time.Second) + err = couchdb.FindDocs(db, consts.Permissions, &req, &res) + if err != nil { + return nil, err + } + } + return res, nil +} + +func getOrRepairShareInteractPermissions(db prefixer.Prefixer, sharingID string) (*Permission, error) { + perms, err := getShareInteractPermissions(db, sharingID) + if err != nil { + return nil, err + } + if canonical, needsRepair, err := shareInteractRepairState(perms, sharingID); err != nil || !needsRepair { + return canonical, err + } + + return repairShareInteractPermissionsWithLock(db, sharingID) +} + +func repairShareInteractPermissionsWithLock(db prefixer.Prefixer, sharingID string) (*Permission, error) { + mu := config.Lock().ReadWrite(db, shareInteractLockName(sharingID)) + if err := mu.Lock(); err != nil { + return nil, err + } + defer mu.Unlock() + + return utils.RetryWithBackoffValue(context.Background(), shareInteractRetryOptions(), func() (*Permission, error) { + return getOrRepairShareInteractPermissionsOnce(db, sharingID) + }) +} + +// getOrRepairShareInteractPermissionsOnce performs a single read+repair pass. +// The caller must hold the share-interact write lock for sharingID. +func getOrRepairShareInteractPermissionsOnce(db prefixer.Prefixer, sharingID string) (*Permission, error) { + perms, err := getShareInteractPermissions(db, sharingID) + if err != nil { + return nil, err + } + if canonical, needsRepair, err := shareInteractRepairState(perms, sharingID); err != nil || !needsRepair { + return canonical, err + } + + return repairShareInteractPermissions(db, sharingID, perms) +} + +func shareInteractRetryOptions() utils.RetryOptions { + return utils.RetryOptions{ + Attempts: 5, + Delay: 10 * time.Millisecond, + MaxDelay: 100 * time.Millisecond, + JitterFactor: 0.25, + ShouldRetry: func(err error) bool { + return couchdb.IsConflictError(err) || couchdb.IsFileExists(err) + }, + } +} + +func shareInteractRepairState(perms []Permission, sharingID string) (*Permission, bool, error) { + if len(perms) == 0 { + return nil, false, &couchdb.Error{ + StatusCode: http.StatusNotFound, + Name: "not_found", + Reason: fmt.Sprintf("no permission doc for %v", sharingID), + } + } + + canonicalID := ShareInteractPermissionID(sharingID) + var canonical *Permission + usable := 0 + hasDuplicate := false + for i := range perms { + p := &perms[i] + isCanonical := p.ID() == canonicalID + if !isCanonical { + hasDuplicate = true + } + if p.Expired() { + continue + } + usable++ + if isCanonical { + canonical = p + } + } + if usable == 0 { + return nil, false, ErrExpiredToken + } + if usable == 1 && canonical != nil && !hasDuplicate { + return canonical, false, nil + } + return nil, true, nil +} + +// repairShareInteractPermissions merges usable share-interact permission docs +// into the canonical doc and deletes non-canonical duplicates. The caller must +// hold the share-interact write lock for sharingID. +func repairShareInteractPermissions(db prefixer.Prefixer, sharingID string, perms []Permission) (*Permission, error) { + canonicalID := ShareInteractPermissionID(sharingID) + merged := &Permission{ + PID: canonicalID, + Type: TypeShareInteract, + SourceID: consts.Sharings + "/" + sharingID, + Codes: make(map[string]string), + } + hasUsablePermission := false + hasCanonical := false + duplicates := make([]*Permission, 0, len(perms)) + for i := range perms { + p := &perms[i] + if p.ID() == canonicalID { + hasCanonical = true + // Keep the canonical doc revision and then overwrite the document with content rebuilt from non-expired docs. + merged.SetRev(p.Rev()) + } else { + duplicates = append(duplicates, p) + } + if p.Expired() { + continue + } + + hasUsablePermission = true + + if len(merged.Permissions) == 0 && len(p.Permissions) > 0 { + merged.Permissions = slices.Clone(p.Permissions) + } + if merged.Metadata == nil && p.Metadata != nil { + merged.Metadata = p.Metadata.Clone() + } + if merged.ExpiresAt == nil && p.ExpiresAt != nil { + merged.ExpiresAt = p.ExpiresAt + } + + for key, code := range p.Codes { + if key == "" { + continue + } + if existing, ok := merged.Codes[key]; ok && existing != code { + logger.WithDomain(db.DomainName()).WithNamespace("permissions"). + Warnf("conflicting share-interact code for %s in sharing %s", key, sharingID) + continue + } + merged.Codes[key] = code + } + } + if !hasUsablePermission { + return nil, ErrExpiredToken + } + + if !hasCanonical { + if err := couchdb.CreateNamedDoc(db, merged); err != nil { + return nil, err + } + } else if err := couchdb.UpdateDoc(db, merged); err != nil { + return nil, err + } + + for _, p := range duplicates { + if err := couchdb.DeleteDoc(db, p); err != nil { + if couchdb.IsNotFoundError(err) { + continue + } + return nil, err + } + } + + return merged, nil } func getFromSource(db prefixer.Prefixer, permType, docType, slug string) (*Permission, error) { @@ -605,21 +807,72 @@ func CreateSharePreviewSet(db prefixer.Prefixer, sharingID string, codes, shortc return doc, nil } -// CreateShareInteractSet creates a Permission doc for reading/writing a note -// inside a sharing +// CreateShareInteractSet creates or updates the Permission doc for reading and +// writing a note inside a sharing. When subdoc.Permissions is not empty, it +// replaces the existing permission rules with that full set. func CreateShareInteractSet(db prefixer.Prefixer, sharingID string, codes map[string]string, subdoc Permission) (*Permission, error) { doc := &Permission{ + PID: ShareInteractPermissionID(sharingID), Type: TypeShareInteract, Permissions: subdoc.Permissions, Codes: codes, SourceID: consts.Sharings + "/" + sharingID, Metadata: subdoc.Metadata, } - err := couchdb.CreateDoc(db, doc) - if err != nil { + + mu := config.Lock().ReadWrite(db, shareInteractLockName(sharingID)) + if err := mu.Lock(); err != nil { return nil, err } - return doc, nil + defer mu.Unlock() + + return utils.RetryWithBackoffValue(context.Background(), shareInteractRetryOptions(), func() (*Permission, error) { + existing, err := GetPermissionByIDIncludingExpired(db, ShareInteractPermissionID(sharingID)) + if err != nil { + if !couchdb.IsNotFoundError(err) { + return nil, err + } + err = couchdb.CreateNamedDoc(db, doc) + if err == nil { + return doc, nil + } + return nil, err + } + + merged := existing.Clone().(*Permission) + merged.Type = TypeShareInteract + merged.SourceID = consts.Sharings + "/" + sharingID + merged.SetID(ShareInteractPermissionID(sharingID)) + merged.ExpiresAt = nil + if len(subdoc.Permissions) > 0 { + // Callers pass the complete interact rule set, so replace instead of merging rules. + merged.Permissions = subdoc.Permissions + } + if merged.Metadata == nil && subdoc.Metadata != nil { + merged.Metadata = subdoc.Metadata.Clone() + } + if merged.Metadata != nil { + merged.Metadata.ChangeUpdatedAt() + } + if merged.Codes == nil { + merged.Codes = make(map[string]string) + } + for key, code := range codes { + if key == "" { + continue + } + if existingCode, ok := merged.Codes[key]; ok && existingCode != code { + logger.WithDomain(db.DomainName()).WithNamespace("permissions"). + Warnf("keeping existing share-interact code for %s in sharing %s", key, sharingID) + continue + } + merged.Codes[key] = code + } + if err := couchdb.UpdateDoc(db, merged); err != nil { + return nil, err + } + return merged, nil + }) } // ForceWebapp creates or updates a Permission doc for a given webapp diff --git a/model/permission/permissions_test.go b/model/permission/permissions_test.go index 6a8a17f3e8c..24974148672 100644 --- a/model/permission/permissions_test.go +++ b/model/permission/permissions_test.go @@ -5,11 +5,23 @@ import ( "encoding/json" "strings" "testing" + "time" + "github.com/cozy/cozy-stack/pkg/config/config" + "github.com/cozy/cozy-stack/pkg/consts" + "github.com/cozy/cozy-stack/pkg/couchdb" + "github.com/cozy/cozy-stack/pkg/metadata" + "github.com/cozy/cozy-stack/pkg/prefixer" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func testPrefix(t *testing.T) prefixer.Prefixer { + t.Helper() + return prefixer.NewPrefixer(0, "test", t.Name()) +} + func TestCheckDoctypeName(t *testing.T) { assert.NoError(t, CheckDoctypeName("io.cozy.files", false)) assert.NoError(t, CheckDoctypeName("io.cozy.account_types", false)) @@ -423,6 +435,254 @@ func TestShareSetPermissions(t *testing.T) { assert.Error(t, err) } +func TestGetForShareInteractRepairsDuplicateDocs(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + db := testPrefix(t) + require.NoError(t, couchdb.ResetDB(db, consts.Permissions)) + + const sharingID = "sharing-duplicate-interact-permissions" + perms := Permission{ + Permissions: Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: ALL, + }}, + } + + err := couchdb.CreateDoc(db, &Permission{ + Type: TypeShareInteract, + Permissions: perms.Permissions, + Codes: map[string]string{ + "alice@example.test": "alice-token", + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + err = couchdb.CreateDoc(db, &Permission{ + Type: TypeShareInteract, + Permissions: perms.Permissions, + Codes: map[string]string{ + "bob@example.test": "bob-token", + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + + repaired, err := GetForShareInteract(db, sharingID) + require.NoError(t, err) + require.Equal(t, ShareInteractPermissionID(sharingID), repaired.ID()) + require.Equal(t, map[string]string{ + "alice@example.test": "alice-token", + "bob@example.test": "bob-token", + }, repaired.Codes) + + all, err := getShareInteractPermissions(db, sharingID) + require.NoError(t, err) + require.Len(t, all, 1) + require.Equal(t, ShareInteractPermissionID(sharingID), all[0].ID()) + + repaired, err = GetForShareInteract(db, sharingID) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "alice@example.test": "alice-token", + "bob@example.test": "bob-token", + }, repaired.Codes) + + all, err = getShareInteractPermissions(db, sharingID) + require.NoError(t, err) + require.Len(t, all, 1) +} + +func TestGetForShareInteractRepairsExpiredDuplicateDocs(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + db := testPrefix(t) + require.NoError(t, couchdb.ResetDB(db, consts.Permissions)) + + const sharingID = "sharing-expired-duplicate-interact-permissions" + rules := Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: ALL, + }} + expiredAt := time.Now().Add(-time.Hour).Format(time.RFC3339) + + err := couchdb.CreateNamedDoc(db, &Permission{ + PID: ShareInteractPermissionID(sharingID), + Type: TypeShareInteract, + Permissions: rules, + Codes: map[string]string{ + "alice@example.test": "alice-token", + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + err = couchdb.CreateDoc(db, &Permission{ + Type: TypeShareInteract, + Permissions: rules, + Codes: map[string]string{ + "expired@example.test": "expired-token", + }, + ExpiresAt: expiredAt, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + + repaired, err := GetForShareInteract(db, sharingID) + require.NoError(t, err) + require.Equal(t, ShareInteractPermissionID(sharingID), repaired.ID()) + require.Equal(t, map[string]string{ + "alice@example.test": "alice-token", + }, repaired.Codes) + + all, err := getShareInteractPermissions(db, sharingID) + require.NoError(t, err) + require.Len(t, all, 1) + require.Equal(t, ShareInteractPermissionID(sharingID), all[0].ID()) +} + +func TestCreateShareInteractSetUsesCanonicalDoc(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + db := testPrefix(t) + require.NoError(t, couchdb.ResetDB(db, consts.Permissions)) + + const sharingID = "sharing-canonical-interact-permissions" + md := metadata.New() + md.UpdatedAt = time.Now().Add(-time.Hour) + perms := Permission{ + Permissions: Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: ALL, + }}, + Metadata: md, + } + + first, err := CreateShareInteractSet(db, sharingID, map[string]string{ + "alice@example.test": "alice-token", + }, perms) + require.NoError(t, err) + require.Equal(t, ShareInteractPermissionID(sharingID), first.ID()) + + second, err := CreateShareInteractSet(db, sharingID, map[string]string{ + "bob@example.test": "bob-token", + }, perms) + require.NoError(t, err) + require.Equal(t, ShareInteractPermissionID(sharingID), second.ID()) + require.Equal(t, map[string]string{ + "alice@example.test": "alice-token", + "bob@example.test": "bob-token", + }, second.Codes) + require.NotNil(t, first.Metadata) + require.NotNil(t, second.Metadata) + require.True(t, second.Metadata.UpdatedAt.After(first.Metadata.UpdatedAt)) + + all, err := getShareInteractPermissions(db, sharingID) + require.NoError(t, err) + require.Len(t, all, 1) + require.Equal(t, ShareInteractPermissionID(sharingID), all[0].ID()) +} + +func TestCreateShareInteractSetReactivatesExpiredCanonicalDoc(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + db := testPrefix(t) + require.NoError(t, couchdb.ResetDB(db, consts.Permissions)) + + const sharingID = "sharing-expired-canonical-interact-permissions" + rules := Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: ALL, + }} + expiredAt := time.Now().Add(-time.Hour).Format(time.RFC3339) + + err := couchdb.CreateNamedDoc(db, &Permission{ + PID: ShareInteractPermissionID(sharingID), + Type: TypeShareInteract, + Permissions: rules, + Codes: map[string]string{ + "alice@example.test": "alice-token", + }, + ExpiresAt: expiredAt, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + + updated, err := CreateShareInteractSet(db, sharingID, map[string]string{ + "bob@example.test": "bob-token", + }, Permission{Permissions: rules}) + require.NoError(t, err) + require.Nil(t, updated.ExpiresAt) + require.Equal(t, map[string]string{ + "alice@example.test": "alice-token", + "bob@example.test": "bob-token", + }, updated.Codes) + + stored, err := GetForShareInteract(db, sharingID) + require.NoError(t, err) + require.Nil(t, stored.ExpiresAt) +} + +func TestCreateShareInteractSetDoesNotRepairLegacyDuplicates(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + db := testPrefix(t) + require.NoError(t, couchdb.ResetDB(db, consts.Permissions)) + + const sharingID = "sharing-create-with-legacy-duplicate" + rules := Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: ALL, + }} + + err := couchdb.CreateDoc(db, &Permission{ + Type: TypeShareInteract, + Permissions: rules, + Codes: map[string]string{ + "alice@example.test": "alice-token", + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + + created, err := CreateShareInteractSet(db, sharingID, map[string]string{ + "bob@example.test": "bob-token", + }, Permission{Permissions: rules}) + require.NoError(t, err) + require.Equal(t, ShareInteractPermissionID(sharingID), created.ID()) + require.Equal(t, map[string]string{ + "bob@example.test": "bob-token", + }, created.Codes) + + all, err := getShareInteractPermissions(db, sharingID) + require.NoError(t, err) + require.Len(t, all, 2) +} + func TestCreateShareSetBlocklist(t *testing.T) { s := Set{Rule{Type: "io.cozy.notifications"}} subdoc := Permission{ diff --git a/model/sharing/member_test.go b/model/sharing/member_test.go new file mode 100644 index 00000000000..b1d15b3ced1 --- /dev/null +++ b/model/sharing/member_test.go @@ -0,0 +1,153 @@ +package sharing + +import ( + "fmt" + "os" + "sync" + "testing" + + "github.com/cozy/cozy-stack/model/permission" + "github.com/cozy/cozy-stack/pkg/config/config" + "github.com/cozy/cozy-stack/pkg/consts" + "github.com/cozy/cozy-stack/pkg/couchdb" + "github.com/cozy/cozy-stack/pkg/couchdb/mango" + "github.com/cozy/cozy-stack/tests/testutils" + "github.com/stretchr/testify/require" +) + +func TestFindMemberByInteractCodeWithDuplicatePermissions(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + testutils.NeedCouchdb(t) + setup := testutils.NewSetup(t, t.Name()) + inst := setup.GetTestInstance() + require.NoError(t, couchdb.ResetDB(inst, consts.Permissions)) + + const sharingID = "sharing-duplicate-interact-permissions" + const aliceEmail = "alice@example.test" + const bobEmail = "bob@example.test" + + aliceToken := "alice-interact-token" + bobToken := "bob-interact-token" + perms := permission.Permission{ + Permissions: permission.Set{{ + Title: "Shared drive", + Type: consts.Files, + Values: []string{"shared-drive-root"}, + Verbs: permission.ALL, + }}, + } + + err := couchdb.CreateDoc(inst, &permission.Permission{ + Type: permission.TypeShareInteract, + Permissions: perms.Permissions, + Codes: map[string]string{ + aliceEmail: aliceToken, + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + err = couchdb.CreateDoc(inst, &permission.Permission{ + Type: permission.TypeShareInteract, + Permissions: perms.Permissions, + Codes: map[string]string{ + bobEmail: bobToken, + }, + SourceID: consts.Sharings + "/" + sharingID, + }) + require.NoError(t, err) + + targetEmail := bobEmail + targetToken := bobToken + + s := Sharing{ + SID: sharingID, + Members: []Member{ + {Email: "owner@example.test", Status: MemberStatusOwner}, + {Email: aliceEmail, Status: MemberStatusReady}, + {Email: bobEmail, Status: MemberStatusReady}, + }, + } + member, err := s.FindMemberByInteractCode(inst, targetToken) + require.NoError(t, err) + require.Equal(t, targetEmail, member.Email) +} + +func TestGetInteractCodeConcurrentCalls(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + testutils.NeedCouchdb(t) + setup := testutils.NewSetup(t, t.Name()) + inst := setup.GetTestInstance() + require.NoError(t, couchdb.ResetDB(inst, consts.Permissions)) + + calls := 100 + if os.Getenv("COZY_STRESS_TESTS") == "1" { + calls = 1000 + } + sharingID := "sharing-concurrent-interact-permissions" + s := Sharing{ + SID: sharingID, + AppSlug: "drive", + Rules: []Rule{{ + Title: "Shared drive", + DocType: consts.Files, + Values: []string{"shared-drive-root"}, + }}, + } + + members := make([]Member, calls) + codes := make([]string, calls) + errs := make(chan error, calls) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(calls) + + for i := 0; i < calls; i++ { + members[i] = Member{Email: fmt.Sprintf("member-%04d@example.test", i)} + go func(i int) { + defer wg.Done() + <-start + code, err := s.GetInteractCode(inst, &members[i], i+1) + if err != nil { + errs <- err + return + } + codes[i] = code + }(i) + } + + close(start) + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } + + interact, err := permission.GetForShareInteract(inst, sharingID) + require.NoError(t, err) + require.Len(t, interact.Codes, calls) + for i, member := range members { + require.NotEmpty(t, codes[i]) + require.Equal(t, codes[i], interact.Codes[member.Email]) + } + + var perms []permission.Permission + req := couchdb.FindRequest{ + UseIndex: "by-source-and-type", + Selector: mango.And( + mango.Equal("type", permission.TypeShareInteract), + mango.Equal("source_id", consts.Sharings+"/"+sharingID), + ), + Limit: calls, + } + require.NoError(t, couchdb.FindDocs(inst, consts.Permissions, &req, &perms)) + require.Len(t, perms, 1) + require.Equal(t, permission.ShareInteractPermissionID(sharingID), perms[0].ID()) +} diff --git a/model/sharing/sharing.go b/model/sharing/sharing.go index 845d969352c..626b2971ce2 100644 --- a/model/sharing/sharing.go +++ b/model/sharing/sharing.go @@ -342,7 +342,7 @@ func (s *Sharing) GetInteractCode(inst *instance.Instance, member *Member, membe interact, err := permission.GetForShareInteract(inst, s.ID()) if err != nil { if couchdb.IsNotFoundError(err) { - return s.CreateInteractPermissions(inst, member) + return s.createInteractPermissions(inst, member) } return "", err } @@ -364,7 +364,11 @@ func (s *Sharing) GetInteractCode(inst *instance.Instance, member *Member, membe } if key == member.Instance || key == member.Email || key == indexKey { if needUpdate { - if err := couchdb.UpdateDoc(inst, interact); err != nil { + _, err := permission.CreateShareInteractSet(inst, s.SID, nil, permission.Permission{ + Permissions: set, + Metadata: interact.Metadata, + }) + if err != nil { return "", err } } @@ -385,19 +389,21 @@ func (s *Sharing) GetInteractCode(inst *instance.Instance, member *Member, membe if err != nil { return "", err } - if interact.Codes == nil { - interact.Codes = make(map[string]string) - } - interact.Codes[key] = code - if err := couchdb.UpdateDoc(inst, interact); err != nil { + + updated, err := permission.CreateShareInteractSet(inst, s.SID, map[string]string{key: code}, permission.Permission{ + Permissions: set, + Metadata: interact.Metadata, + }) + if err != nil { return "", err } - return code, nil + if stored, ok := updated.Codes[key]; ok { + return stored, nil + } + return "", fmt.Errorf("share-interact code for %s was not stored in sharing %s", key, s.SID) } -// CreateInteractPermissions creates the permissions doc for reading and -// writing a note inside this sharing. -func (s *Sharing) CreateInteractPermissions(inst *instance.Instance, m *Member) (string, error) { +func (s *Sharing) createInteractPermissions(inst *instance.Instance, m *Member) (string, error) { key := m.Email if key == "" { key = m.Instance @@ -417,10 +423,13 @@ func (s *Sharing) CreateInteractPermissions(inst *instance.Instance, m *Member) Metadata: md, } - _, err = permission.CreateShareInteractSet(inst, s.SID, codes, doc) + interact, err := permission.CreateShareInteractSet(inst, s.SID, codes, doc) if err != nil { return "", err } + if stored, ok := interact.Codes[key]; ok { + return stored, nil + } return code, nil } diff --git a/pkg/utils/retry.go b/pkg/utils/retry.go index 85218b6bcff..f7502c66610 100644 --- a/pkg/utils/retry.go +++ b/pkg/utils/retry.go @@ -1,23 +1,171 @@ package utils -import "time" +import ( + "context" + "math/rand/v2" + "time" +) + +const maxRetryDelay = time.Duration(1<<63 - 1) + +// RetryOptions configures RetryWithBackoff. +type RetryOptions struct { + // Attempts is the maximum number of calls to fn. Values lower than 1 are + // treated as 1. + Attempts int + // Delay is the wait duration before the first retry. It doubles after each + // failed attempt. + Delay time.Duration + // MaxDelay caps the exponential backoff delay before jitter is applied. + MaxDelay time.Duration + // JitterFactor adds a random delay between 0 and the current backoff delay + // multiplied by JitterFactor. This jitter is one-sided: it can only extend + // the delay, never shorten it. For example, 0.25 adds up to 25% jitter. + // Values lower than or equal to 0 disable jitter. + JitterFactor float64 + // ShouldRetry can be used to retry only some errors. When nil, every + // non-nil error is retried until Attempts is exhausted. + ShouldRetry func(error) bool + // OnRetry is called after a failed attempt and before sleeping. The attempt + // argument is the number of attempts already made, starting at 1. + OnRetry func(attempt int, err error, delay time.Duration) +} + +// RetryWithBackoff calls fn until it succeeds, ShouldRetry rejects the returned +// error, the context is done, or the maximum number of attempts is reached. +func RetryWithBackoff(ctx context.Context, opts RetryOptions, fn func() error) error { + _, err := RetryWithBackoffValue(ctx, opts, func() (struct{}, error) { + return struct{}{}, fn() + }) + return err +} + +// RetryWithBackoffValue calls fn until it succeeds, ShouldRetry rejects the +// returned error, the context is done, or the maximum number of attempts is +// reached. It returns the value returned by the successful call. +func RetryWithBackoffValue[T any](ctx context.Context, opts RetryOptions, fn func() (T, error)) (T, error) { + var zero T + if ctx == nil { + ctx = context.Background() + } + + attempts := opts.Attempts + if attempts < 1 { + attempts = 1 + } + + var err error + var value T + for attempt := 0; attempt < attempts; attempt++ { + select { + case <-ctx.Done(): + return zero, ctx.Err() + default: + } + + value, err = fn() + if err == nil { + return value, nil + } + if attempt == attempts-1 { + return zero, err + } + if opts.ShouldRetry != nil && !opts.ShouldRetry(err) { + return zero, err + } + + delay := retryDelay(opts.Delay, attempt, opts.MaxDelay, opts.JitterFactor) + if opts.OnRetry != nil { + opts.OnRetry(attempt+1, err, delay) + } + if err := sleepWithContext(ctx, delay); err != nil { + return zero, err + } + } + + return zero, err +} // RetryWithExpBackoff can be used to call several times a function until it // returns no error or the maximum count of calls has been reached. Between two // calls, it will wait, first by the given delay, and after that, the delay // will double after each failure. func RetryWithExpBackoff(count int, delay time.Duration, fn func() error) error { - err := fn() - if err == nil { - return nil + return RetryWithBackoff(context.Background(), RetryOptions{ + Attempts: count, + Delay: delay, + }, fn) +} + +func retryDelay(initial time.Duration, attempt int, maxDelay time.Duration, jitterFactor float64) time.Duration { + delay := backoffDelay(initial, attempt, maxDelay) + return addJitter(delay, jitterFactor) +} + +func backoffDelay(initial time.Duration, attempt int, maxDelay time.Duration) time.Duration { + if initial <= 0 { + return 0 } - for i := 1; i < count; i++ { - time.Sleep(delay) + + delay := initial + for i := 0; i < attempt; i++ { + if delay > maxRetryDelay/2 { + return cappedDelay(maxRetryDelay, maxDelay) + } delay *= 2 - err = fn() - if err == nil { + } + return cappedDelay(delay, maxDelay) +} + +func cappedDelay(delay time.Duration, maxDelay time.Duration) time.Duration { + if maxDelay > 0 && delay > maxDelay { + return maxDelay + } + return delay +} + +func addJitter(delay time.Duration, jitterFactor float64) time.Duration { + if delay <= 0 || jitterFactor <= 0 { + return delay + } + + maxJitterFloat := float64(delay) * jitterFactor + if maxJitterFloat < 1 { + return delay + } + var maxJitter time.Duration + if maxJitterFloat >= float64(maxRetryDelay) { + maxJitter = maxRetryDelay + } else { + maxJitter = time.Duration(maxJitterFloat) + } + if maxJitter <= 0 { + return delay + } + jitter := time.Duration(rand.Int64N(int64(maxJitter))) + if maxRetryDelay-delay < jitter { + return maxRetryDelay + } + return delay + jitter +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: return nil } } - return err + + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } } diff --git a/pkg/utils/retry_test.go b/pkg/utils/retry_test.go new file mode 100644 index 00000000000..d3d29f846df --- /dev/null +++ b/pkg/utils/retry_test.go @@ -0,0 +1,145 @@ +package utils + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRetryWithBackoffSucceedsAfterRetry(t *testing.T) { + errTemporary := errors.New("temporary") + calls := 0 + + err := RetryWithBackoff(context.Background(), RetryOptions{ + Attempts: 3, + }, func() error { + calls++ + if calls < 3 { + return errTemporary + } + return nil + }) + + require.NoError(t, err) + assert.Equal(t, 3, calls) +} + +func TestRetryWithBackoffValueReturnsSuccessfulValue(t *testing.T) { + errTemporary := errors.New("temporary") + calls := 0 + + value, err := RetryWithBackoffValue(context.Background(), RetryOptions{ + Attempts: 3, + }, func() (string, error) { + calls++ + if calls < 2 { + return "", errTemporary + } + return "done", nil + }) + + require.NoError(t, err) + assert.Equal(t, "done", value) + assert.Equal(t, 2, calls) +} + +func TestRetryWithBackoffStopsOnNonRetryableError(t *testing.T) { + errRetryable := errors.New("retryable") + errFatal := errors.New("fatal") + calls := 0 + + err := RetryWithBackoff(context.Background(), RetryOptions{ + Attempts: 5, + ShouldRetry: func(err error) bool { + return errors.Is(err, errRetryable) + }, + }, func() error { + calls++ + if calls == 1 { + return errRetryable + } + return errFatal + }) + + require.ErrorIs(t, err, errFatal) + assert.Equal(t, 2, calls) +} + +func TestRetryWithBackoffCapsDelay(t *testing.T) { + errTemporary := errors.New("temporary") + var delays []time.Duration + + err := RetryWithBackoff(context.Background(), RetryOptions{ + Attempts: 4, + Delay: time.Nanosecond, + MaxDelay: 2 * time.Nanosecond, + OnRetry: func(_ int, _ error, delay time.Duration) { + delays = append(delays, delay) + }, + }, func() error { + return errTemporary + }) + + require.ErrorIs(t, err, errTemporary) + assert.Equal(t, []time.Duration{ + time.Nanosecond, + 2 * time.Nanosecond, + 2 * time.Nanosecond, + }, delays) +} + +func TestRetryDelaySaturatesOnOverflow(t *testing.T) { + assert.Equal(t, maxRetryDelay, backoffDelay(maxRetryDelay, 1, 0)) + assert.Equal(t, time.Hour, backoffDelay(maxRetryDelay, 1, time.Hour)) +} + +func TestRetryWithBackoffStopsWhenContextIsCanceled(t *testing.T) { + errTemporary := errors.New("temporary") + ctx, cancel := context.WithCancel(context.Background()) + calls := 0 + + err := RetryWithBackoff(ctx, RetryOptions{ + Attempts: 3, + Delay: time.Hour, + OnRetry: func(_ int, _ error, _ time.Duration) { + cancel() + }, + }, func() error { + calls++ + return errTemporary + }) + + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, 1, calls) +} + +func TestRetryDelayWithJitter(t *testing.T) { + base := 4 * time.Second + hasJitter := false + + for i := 0; i < 100; i++ { + delay := retryDelay(base, 0, 0, 0.25) + + assert.GreaterOrEqual(t, delay, base) + assert.Less(t, delay, base+time.Second) + hasJitter = hasJitter || delay > base + } + assert.True(t, hasJitter) +} + +func TestRetryWithExpBackoffRunsAtLeastOnce(t *testing.T) { + errTemporary := errors.New("temporary") + calls := 0 + + err := RetryWithExpBackoff(0, time.Nanosecond, func() error { + calls++ + return errTemporary + }) + + require.ErrorIs(t, err, errTemporary) + assert.Equal(t, 1, calls) +}