diff --git a/ctpolicy/ctpolicy.go b/ctpolicy/ctpolicy.go index 1ccf41c69b5..5a92c2d179e 100644 --- a/ctpolicy/ctpolicy.go +++ b/ctpolicy/ctpolicy.go @@ -111,7 +111,7 @@ func (ctp *CTPolicy) GetSCTs(ctx context.Context, cert core.CertDER, expiration // Identify the set of candidate logs whose temporal interval includes this // cert's expiry. Randomize the order of the logs so that we're not always // trying to submit to the same two. - logs := ctp.sctLogs.ForTime(expiration).Permute() + logs := ctp.sctLogs.ForTime(expiration).Shuffle() if len(logs) < 2 { return nil, berrors.MissingSCTsError("Insufficient CT logs available (%d)", len(logs)) } diff --git a/ctpolicy/loglist/loglist.go b/ctpolicy/loglist/loglist.go index be6f97bb49a..e4218ff3014 100644 --- a/ctpolicy/loglist/loglist.go +++ b/ctpolicy/loglist/loglist.go @@ -226,13 +226,25 @@ func (ll List) ForTime(expiry time.Time) List { return res } -// Permute returns a new log list containing the exact same logs, but in a -// randomly-shuffled order. -func (ll List) Permute() List { +// Shuffle returns a new log list containing the exact same logs, but in a +// nearly-randomly-shuffled order. If possible, it ensures that the first two +// logs consist of one tiled log and one non-tiled log, to boost the percentage +// of certs which include an SCT from a static log. +func (ll List) Shuffle() List { res := slices.Clone(ll) rand.Shuffle(len(res), func(i int, j int) { res[i], res[j] = res[j], res[i] }) + if len(res) > 2 && res[0].Tiled == res[1].Tiled { + // If we have more than two logs, and both are either tiled or not, + // try to bring another log to the front of the list. + for i := 2; i < len(res); i++ { + if res[0].Tiled != res[i].Tiled { + res[0], res[i] = res[i], res[0] + break + } + } + } return res } diff --git a/ctpolicy/loglist/loglist_test.go b/ctpolicy/loglist/loglist_test.go index 9eb1e6fa2b4..32c2938a48a 100644 --- a/ctpolicy/loglist/loglist_test.go +++ b/ctpolicy/loglist/loglist_test.go @@ -160,7 +160,7 @@ func TestForTime(t *testing.T) { test.AssertDeepEquals(t, actual, expected) } -func TestPermute(t *testing.T) { +func TestShuffle(t *testing.T) { input := List{ Log{Name: "Log A1"}, Log{Name: "Log A2"}, @@ -176,7 +176,7 @@ func TestPermute(t *testing.T) { } for range 100 { - actual := input.Permute() + actual := input.Shuffle() for index, log := range actual { foundIndices[log.Name][index]++ } @@ -191,6 +191,30 @@ func TestPermute(t *testing.T) { } } +func TestShufflePrefersTiled(t *testing.T) { + input := List{ + Log{Name: "Log A1"}, + Log{Name: "Log A2"}, + Log{Name: "Log T1", Tiled: true}, + } + + foundIndices := make(map[string]map[int]int) + for _, log := range input { + foundIndices[log.Name] = make(map[int]int) + } + + for range 100 { + actual := input.Shuffle() + for index, log := range actual { + foundIndices[log.Name][index]++ + } + } + + if foundIndices["Log T1"][2] != 0 { + t.Errorf("Tiled log should have always been pulled into first two indices") + } +} + func TestGetByID(t *testing.T) { input := List{ Log{Name: "Log A1", Id: "ID A1"},