Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions pkg/assembler/backends/ent/backend/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,17 @@ func upsertBulkPackage(ctx context.Context, tx *ent.Tx, pkgInputs []*model.IDorP
return nil, errors.Wrap(err, "bulk upsert pkgName node")
}

if err := tx.PackageVersion.CreateBulk(pkgVersionCreates...).
OnConflict(
sql.ConflictColumns(
packageversion.FieldHash,
packageversion.FieldNameID,
),
).
DoNothing().
Exec(ctx); err != nil && err != stdsql.ErrNoRows {

if err := retryOnFKViolation(ctx, func() error {
return tx.PackageVersion.CreateBulk(pkgVersionCreates...).
OnConflict(
sql.ConflictColumns(
packageversion.FieldHash,
packageversion.FieldNameID,
),
).
DoNothing().
Exec(ctx)
}); err != nil && err != stdsql.ErrNoRows {
return nil, errors.Wrap(err, "bulk upsert pkgVersion node")
}
}
Expand Down Expand Up @@ -293,18 +294,18 @@ func upsertPackage(ctx context.Context, tx *ent.Tx, pkg model.IDorPkgInput) (*mo

pkgVersionCreate := generatePackageVersionCreate(tx, &pkgVersionID, &pkgNameID, &pkg)

if err := pkgVersionCreate.
OnConflict(
sql.ConflictColumns(
packageversion.FieldHash,
packageversion.FieldNameID,
),
).
DoNothing().
Exec(ctx); err != nil {
if err != stdsql.ErrNoRows {
return nil, errors.Wrap(err, "upsert package version")
}
if err := retryOnFKViolation(ctx, func() error {
return pkgVersionCreate.
OnConflict(
sql.ConflictColumns(
packageversion.FieldHash,
packageversion.FieldNameID,
),
).
DoNothing().
Exec(ctx)
}); err != nil && err != stdsql.ErrNoRows {
return nil, errors.Wrap(err, "upsert package version")
}

return &model.PackageIDs{
Expand Down
73 changes: 73 additions & 0 deletions pkg/assembler/backends/ent/backend/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//
// Copyright 2026 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
"context"
"errors"
"time"

"github.com/lib/pq"
)

// pgForeignKeyViolationCode is PostgreSQL SQLSTATE 23503 (foreign_key_violation).
// See https://www.postgresql.org/docs/current/errcodes-appendix.html.
const pgForeignKeyViolationCode = "23503"

// isPGForeignKeyViolation reports whether err (possibly wrapped) is a PostgreSQL
// foreign-key-violation error. It intentionally does not match other constraint
// classes (unique, check, not-null) because those are not expected to become
// valid on retry.
func isPGForeignKeyViolation(err error) bool {
if err == nil {
return false
}
var pqErr *pq.Error
if !errors.As(err, &pqErr) {
return false
}
return string(pqErr.Code) == pgForeignKeyViolationCode
}

// fkRetryBackoffs controls sleep duration before each retry. Length determines
// max retries. Chosen to cover the ~1–2 s window in which rows committed by a
// sibling transaction typically become visible under production load.
var fkRetryBackoffs = []time.Duration{
500 * time.Millisecond,
1 * time.Second,
}

// retryOnFKViolation invokes fn and, on a PostgreSQL foreign-key violation,
// retries with bounded backoff. Non-FK errors propagate immediately.
// Honors ctx cancellation between attempts.
func retryOnFKViolation(ctx context.Context, fn func() error) error {
err := fn()
if !isPGForeignKeyViolation(err) {
return err
}
for _, backoff := range fkRetryBackoffs {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
err = fn()
if !isPGForeignKeyViolation(err) {
return err
}
}
return err
}
177 changes: 177 additions & 0 deletions pkg/assembler/backends/ent/backend/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//
// Copyright 2026 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/lib/pq"
)

func TestIsPGForeignKeyViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "nil error",
err: nil,
want: false,
},
{
name: "plain error",
err: errors.New("some other failure"),
want: false,
},
{
name: "pq foreign_key_violation by value",
err: &pq.Error{Code: "23503"},
want: true,
},
{
name: "pq unique_violation is not retryable here",
err: &pq.Error{Code: "23505"},
want: false,
},
{
name: "pq check_violation is not retryable here",
err: &pq.Error{Code: "23514"},
want: false,
},
{
name: "wrapped pq foreign_key_violation",
err: fmt.Errorf("bulk upsert pkgVersion node: %w", &pq.Error{Code: "23503"}),
want: true,
},
{
name: "doubly-wrapped pq foreign_key_violation",
err: fmt.Errorf("outer: %w", fmt.Errorf("inner: %w", &pq.Error{Code: "23503"})),
want: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := isPGForeignKeyViolation(tc.err); got != tc.want {
t.Fatalf("isPGForeignKeyViolation(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}

func TestRetryOnFKViolation_SuccessFirstAttempt(t *testing.T) {
calls := 0
err := retryOnFKViolation(context.Background(), func() error {
calls++
return nil
})
if err != nil {
t.Fatalf("want nil err, got %v", err)
}
if calls != 1 {
t.Fatalf("want 1 call, got %d", calls)
}
}

func TestRetryOnFKViolation_NonRetryableErrorPropagates(t *testing.T) {
calls := 0
sentinel := errors.New("not a pq error")
err := retryOnFKViolation(context.Background(), func() error {
calls++
return sentinel
})
if !errors.Is(err, sentinel) {
t.Fatalf("want sentinel err, got %v", err)
}
if calls != 1 {
t.Fatalf("want 1 call (no retry on non-FK), got %d", calls)
}
}

func TestRetryOnFKViolation_UniqueViolationNotRetried(t *testing.T) {
calls := 0
err := retryOnFKViolation(context.Background(), func() error {
calls++
return &pq.Error{Code: "23505"}
})
if err == nil {
t.Fatalf("want err, got nil")
}
if calls != 1 {
t.Fatalf("want 1 call (unique violation not retried), got %d", calls)
}
}

func TestRetryOnFKViolation_RecoversAfterTransientFK(t *testing.T) {
calls := 0
err := retryOnFKViolation(context.Background(), func() error {
calls++
if calls < 2 {
return &pq.Error{Code: "23503"}
}
return nil
})
if err != nil {
t.Fatalf("want nil err after recovery, got %v", err)
}
if calls != 2 {
t.Fatalf("want 2 calls (one retry), got %d", calls)
}
}

func TestRetryOnFKViolation_GivesUpAfterMaxAttempts(t *testing.T) {
calls := 0
err := retryOnFKViolation(context.Background(), func() error {
calls++
return &pq.Error{Code: "23503"}
})
if err == nil {
t.Fatalf("want err after exhausting retries")
}
if !isPGForeignKeyViolation(err) {
t.Fatalf("want final err to be FK violation, got %v", err)
}
if calls != 3 {
t.Fatalf("want 3 calls (initial + 2 retries), got %d", calls)
}
}

func TestRetryOnFKViolation_HonorsContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // pre-cancelled

calls := 0
start := time.Now()
err := retryOnFKViolation(ctx, func() error {
calls++
return &pq.Error{Code: "23503"}
})
elapsed := time.Since(start)

if err == nil {
t.Fatalf("want err from cancelled context")
}
if elapsed > 200*time.Millisecond {
t.Fatalf("cancellation should abort fast, elapsed=%v", elapsed)
}
if calls < 1 {
t.Fatalf("want at least one attempt, got %d", calls)
}
}
40 changes: 24 additions & 16 deletions pkg/assembler/backends/ent/backend/sbom.go
Original file line number Diff line number Diff line change
Expand Up @@ -896,10 +896,12 @@ func updateHasSBOMWithIncludePackageIDs(ctx context.Context, client *ent.Client,
batches := chunk(sortedPkgUUIDs, 10000)

for _, batchedPkgUUIDs := range batches {
err := client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedSoftwarePackageIDs(batchedPkgUUIDs...).
Exec(ctx)
err := retryOnFKViolation(ctx, func() error {
return client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedSoftwarePackageIDs(batchedPkgUUIDs...).
Exec(ctx)
})
if err != nil {
return fmt.Errorf("update for IncludedSoftwarePackageIDs hasSBOM node failed with error: %w", err)
}
Expand All @@ -911,10 +913,12 @@ func updateHasSBOMWithIncludeArtifacts(ctx context.Context, client *ent.Client,
batches := chunk(sortedArtUUIDs, 10000)

for _, batchedArtUUIDs := range batches {
err := client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedSoftwareArtifactIDs(batchedArtUUIDs...).
Exec(ctx)
err := retryOnFKViolation(ctx, func() error {
return client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedSoftwareArtifactIDs(batchedArtUUIDs...).
Exec(ctx)
})
if err != nil {
return fmt.Errorf("update for IncludedSoftwareArtifactIDs hasSBOM node failed with error: %w", err)
}
Expand All @@ -926,10 +930,12 @@ func updateHasSBOMWithIncludeDependencies(ctx context.Context, client *ent.Clien
batches := chunk(sortedIsDepUUIDs, 10000)

for _, batchedIsDepUUIDs := range batches {
err := client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedDependencyIDs(batchedIsDepUUIDs...).
Exec(ctx)
err := retryOnFKViolation(ctx, func() error {
return client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedDependencyIDs(batchedIsDepUUIDs...).
Exec(ctx)
})
if err != nil {
return fmt.Errorf("update for IncludedDependencyIDs hasSBOM node failed with error: %w", err)
}
Expand All @@ -941,10 +947,12 @@ func updateHasSBOMWithIncludeOccurrences(ctx context.Context, client *ent.Client
batches := chunk(sortedIsOccurrenceUUIDs, 10000)

for _, batchedIsOccurUUIDs := range batches {
err := client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedOccurrenceIDs(batchedIsOccurUUIDs...).
Exec(ctx)
err := retryOnFKViolation(ctx, func() error {
return client.BillOfMaterials.
UpdateOneID(hasSBOMID).
AddIncludedOccurrenceIDs(batchedIsOccurUUIDs...).
Exec(ctx)
})
if err != nil {
return fmt.Errorf("update for IncludedOccurrenceIDs hasSBOM node failed with error: %w", err)
}
Expand Down
Loading