Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 66 additions & 27 deletions core/services/ocr2/plugins/vault/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}

blobPayloads := make([][]byte, 0, len(localQueueItems))
blobPayloadIDs := make([]string, 0, len(localQueueItems))
maxObservedLocalQueueItems := 0
for _, item := range localQueueItems {
// The item is already in the pending queue. We'll be processing it
Expand Down Expand Up @@ -502,6 +503,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}

blobPayloads = append(blobPayloads, itemb)
blobPayloadIDs = append(blobPayloadIDs, item.Id)

if len(blobPayloads) >= maxObservedLocalQueueItems {
r.lggr.Warnw("Observed local queue exceeds batch size limit, truncating",
Expand All @@ -512,35 +514,11 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}
}

observedLocalQueue := make([][]byte, len(blobPayloads))
// Broadcast pending-queue blobs in parallel to reduce Observation() latency.
// Shortening this phase helps the OCR round finish within DeltaProgress.
blobBroadcastStart := time.Now()
defer func() {
r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(blobPayloads), "elapsed", time.Since(blobBroadcastStart))
}()
g, broadcastCtx := errgroup.WithContext(ctx)
for i, payload := range blobPayloads {
g.Go(func() error {
blobHandle, ierr2 := blobBroadcastFetcher.BroadcastBlob(broadcastCtx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
if ierr2 != nil {
return fmt.Errorf("could not broadcast pending queue item as blob: %w", ierr2)
}

blobHandleBytes, ierr2 := r.marshalBlob(blobHandle)
if ierr2 != nil {
return fmt.Errorf("could not marshal blob handle to bytes: %w", ierr2)
}

observedLocalQueue[i] = blobHandleBytes
return nil
})
}
if err = g.Wait(); err != nil {
pendingQueueItems, err := r.broadcastBlobPayloads(ctx, blobBroadcastFetcher, seqNr, blobPayloads, blobPayloadIDs)
if err != nil {
return nil, err
}

obspb.PendingQueueItems = observedLocalQueue
obspb.PendingQueueItems = pendingQueueItems

// Second, generate a random nonce that we'll use to sort the observations.
// Each node generates a nonce idepedently, to be concatenated later on.
Expand All @@ -563,6 +541,67 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
return types.Observation(obsb), nil
}

// broadcastBlobPayloads broadcasts each payload as a blob in parallel to reduce
// Observation() latency (shortening this phase helps the OCR round finish within
// DeltaProgress). Individual broadcast failures are logged and skipped rather than
// aborting the entire observation, so that one problematic payload does not prevent
// the remaining items from being observed. Context cancellation/deadline errors are
// propagated immediately so that expired rounds fail fast.
func (r *ReportingPlugin) broadcastBlobPayloads(
ctx context.Context,
fetcher ocr3_1types.BlobBroadcastFetcher,
seqNr uint64,
payloads [][]byte,
requestIDs []string,
) ([][]byte, error) {
results := make([][]byte, len(payloads))

start := time.Now()
defer func() {
r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(payloads), "elapsed", time.Since(start))
}()

var g errgroup.Group
for i, payload := range payloads {
g.Go(func() error {
blobHandle, err := fetcher.BroadcastBlob(ctx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prashantkumar1982 The way I read this a single request that takes a long time will delay the whole batch, and could even cause it to fail since there's no actual timeout associated with the request (ctx will only be cancelled when the epoch changes)

Is it worth adding an explicit timeout for these requests?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yes if there's a reason to believe these calls can be stuck for a long time.
My understanding was that these were local calls, and unlikely to stall the whole observation phase for a long time.

if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
r.lggr.Warnw("failed to broadcast pending queue item as blob, skipping",
"seqNr", seqNr,
"requestID", requestIDs[i],
"err", err)
return nil
Comment thread
prashantkumar1982 marked this conversation as resolved.
}

blobHandleBytes, err := r.marshalBlob(blobHandle)
if err != nil {
r.lggr.Warnw("failed to marshal blob handle, skipping",
"seqNr", seqNr,
"requestID", requestIDs[i],
"err", err)
return nil
}

results[i] = blobHandleBytes
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}

filtered := make([][]byte, 0, len(results))
for _, item := range results {
if item != nil {
filtered = append(filtered, item)
}
}
return filtered, nil
}

func (r *ReportingPlugin) observeGetSecrets(ctx context.Context, reader ReadKVStore, req proto.Message, o *vaultcommon.Observation) {
tp := req.(*vaultcommon.GetSecretsRequest)
o.RequestType = vaultcommon.RequestType_GET_SECRETS
Expand Down
231 changes: 228 additions & 3 deletions core/services/ocr2/plugins/vault/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastsPendingQueueBlobsInPar
}

func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T) {
lggr := logger.TestLogger(t)
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
store := requests.NewStore[*vaulttypes.Request]()
r := &ReportingPlugin{
lggr: lggr,
Expand Down Expand Up @@ -803,8 +803,15 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T)
require.NoError(t, store.Add(&vaulttypes.Request{Payload: p, IDVal: "request-1"}))
rdr := &kv{m: make(map[string]response)}

_, err = r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
require.ErrorContains(t, err, "could not broadcast pending queue item as blob: boom")
obs, err := r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
require.NoError(t, err)
require.NotNil(t, obs)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 1, warnLogs.Len())
fields := warnLogs.All()[0].ContextMap()
assert.Equal(t, "request-1", fields["requestID"])
assert.Contains(t, fmt.Sprint(fields["err"]), "boom")
}

func TestPlugin_Observation_GetSecretsRequest_SecretIdentifierInvalid(t *testing.T) {
Expand Down Expand Up @@ -5166,6 +5173,21 @@ func mockMarshalBlob(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte{}, nil
}

type callbackBlobFetcher struct {
fn func(payload []byte) error
}

func (f *callbackBlobFetcher) BroadcastBlob(_ context.Context, payload []byte, _ ocr3_1types.BlobExpirationHint) (ocr3_1types.BlobHandle, error) {
if err := f.fn(payload); err != nil {
return ocr3_1types.BlobHandle{}, err
}
return ocr3_1types.BlobHandle{}, nil
}

func (f *callbackBlobFetcher) FetchBlob(context.Context, ocr3_1types.BlobHandle) ([]byte, error) {
panic("FetchBlob should not be called in broadcastBlobPayloads tests")
}

func TestPlugin_StateTransition_StoresPendingQueue(t *testing.T) {
lggr := logger.TestLogger(t)
store := requests.NewStore[*vaulttypes.Request]()
Expand Down Expand Up @@ -7108,3 +7130,206 @@ func TestLogUserErrorAware(t *testing.T) {
assert.Contains(t, fmt.Sprint(fields["error"]), "internal error")
})
}

func TestPlugin_broadcastBlobPayloads(t *testing.T) {
t.Run("empty payloads returns empty slice", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, nil, nil)
require.NoError(t, err)
assert.Empty(t, result)
})

t.Run("all payloads broadcast successfully", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Len(t, result, 3)
for _, item := range result {
assert.Equal(t, []byte("handle"), item)
}
})

t.Run("failed broadcast is skipped and logged", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
if string(payload) == "p2" {
return errors.New("broadcast error")
}
return nil
}}

payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 5, payloads, ids)
require.NoError(t, err)
assert.Len(t, result, 2)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 1, warnLogs.Len())
fields := warnLogs.All()[0].ContextMap()
assert.Equal(t, "req-2", fields["requestID"])
assert.Equal(t, uint64(5), fields["seqNr"])
assert.Contains(t, fmt.Sprint(fields["err"]), "broadcast error")
})

t.Run("all broadcasts fail returns empty slice", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &errorBlobBroadcastFetcher{err: errors.New("network down")}
payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Empty(t, result)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 2, warnLogs.Len())
})

t.Run("marshal blob failure skips item and logs warning", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return nil, errors.New("marshal error")
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Empty(t, result)

warnLogs := observed.FilterMessage("failed to marshal blob handle, skipping")
assert.Equal(t, 2, warnLogs.Len())
})

t.Run("mix of broadcast and marshal failures", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)

marshalCallCount := atomic.Int32{}
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
n := marshalCallCount.Add(1)
if n == 1 {
return nil, errors.New("marshal error")
}
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
if string(payload) == "p1" {
return errors.New("broadcast error")
}
return nil
}}

payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)

broadcastWarns := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
marshalWarns := observed.FilterMessage("failed to marshal blob handle, skipping")
assert.Equal(t, 1, broadcastWarns.Len())
assert.Equal(t, 1, marshalWarns.Len())
assert.Len(t, result, 1)
})

t.Run("context cancellation propagates error", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

ctx, cancel := context.WithCancel(t.Context())
cancel()

fetcher := &callbackBlobFetcher{fn: func([]byte) error {
return ctx.Err()
}}

payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
assert.Nil(t, result)
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("context deadline exceeded propagates error", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

ctx, cancel := context.WithTimeout(t.Context(), 0)
defer cancel()
<-ctx.Done()

fetcher := &callbackBlobFetcher{fn: func([]byte) error {
return ctx.Err()
}}

payloads := [][]byte{[]byte("p1")}
ids := []string{"req-1"}

result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
assert.Nil(t, result)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}
Loading