Skip to content

Commit d5e0add

Browse files
committed
[Prism] Fix SDF splits and checkpoints in ElementManager.ReturnResiduals
Fixes a bug where splittable DoFN (SDF) splits and checkpoints would completely lose their deferred residual restrictions, causing the pipeline to terminate prematurely and downstream assertions to fail due to missing elements (e.g., test_register_finalizations). Root Cause: - The previous patch in v2 fixed a double-counting bug of livePending by removing the decoding and rescheduling of split residuals (stage.AddPending and em.addPending) from ReturnResiduals(), assuming they were already placed back by stage.splitBundle(). - This holds true for normal non-SDF channel splits, where stage.splitBundle already puts the unprocessed original elements back. - However, when a splittable DoFn (SDF) checkpoints itself, the active element splits on its restriction rather than simple unprocessed channel elements. In this case, the original remaining elements (res) in splitBundle() has a length of 0, but the SDK worker returns a new restriction in the residuals.Data (e.g. unprocessedElements length 1). - Because the previous patch completely removed rescheduling from ReturnResiduals, this new residual restriction was completely lost and never added back to the pending queue. Solution: - Inside ReturnResiduals(), we dynamically calculate the original remaining elements in the bundle: originalRemainingCount := len(completed.es) - firstRsIndex. - We compare the total returned residuals (unprocessedElements) against originalRemainingCount. - If len(unprocessedElements) > originalRemainingCount, the difference represents the new SDF residual restrictions. We selectively add ONLY these new residuals back to the stage pending heap and safely increment em.livePending by this difference. - This elegantly preserves the fix for normal channel splits (preventing double-counting), while ensuring SDF checkpoint residuals are correctly scheduled. - Also includes detailed slog.Info logging during execution to track livePending state changes accurately in addPending, ReturnResiduals, and splitBundle.
1 parent 76e656e commit d5e0add

2 files changed

Lines changed: 75 additions & 45 deletions

File tree

sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,10 @@ type ElementManager struct {
240240
}
241241

242242
func (em *ElementManager) addPending(v int) {
243+
prev := em.livePending.Load()
243244
em.livePending.Add(int64(v))
244245
em.pendingElements.Add(v)
246+
slog.Info("em.addPending", "delta", v, "prev", prev, "current", em.livePending.Load())
245247
}
246248

247249
// LinkID represents a fully qualified input or output.
@@ -1091,16 +1093,25 @@ func (em *ElementManager) FailBundle(rb RunBundle) {
10911093
em.markChangedAndClearBundle(rb.StageID, rb.BundleID, nil)
10921094
}
10931095

1094-
// ReturnResiduals is called after a successful split, so the remaining work
1095-
// can be re-assigned to a new bundle.
10961096
func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, inputInfo PColInfo, residuals Residuals) {
10971097
stage := em.stages[rb.StageID]
10981098

1099+
slog.Info("ElementManager.ReturnResiduals start", "bundle", rb, "firstRsIndex", firstRsIndex)
1100+
1101+
stage.mu.Lock()
1102+
completed := stage.inprogress[rb.BundleID]
1103+
originalRemainingCount := len(completed.es) - firstRsIndex
1104+
stage.mu.Unlock()
1105+
10991106
stage.splitBundle(rb, firstRsIndex, em)
11001107
unprocessedElements := reElementResiduals(residuals.Data, inputInfo, rb)
1101-
if len(unprocessedElements) > 0 {
1102-
slog.Debug("ReturnResiduals: unprocessed elements", "bundle", rb, "count", len(unprocessedElements))
1108+
if len(unprocessedElements) > originalRemainingCount {
1109+
newResiduals := unprocessedElements[originalRemainingCount:]
1110+
slog.Info("ReturnResiduals: new residuals added back", "bundle", rb, "count", len(newResiduals))
1111+
count := stage.AddPending(em, newResiduals)
1112+
em.addPending(count)
11031113
}
1114+
slog.Info("ElementManager.ReturnResiduals end", "bundle", rb, "unprocessedCount", len(unprocessedElements), "livePending", em.livePending.Load())
11041115
em.markStagesAsChanged(singleSet(rb.StageID))
11051116
}
11061117

@@ -2185,7 +2196,7 @@ func (ss *stageState) splitBundle(rb RunBundle, firstResidual int, em *ElementMa
21852196
defer ss.mu.Unlock()
21862197

21872198
es := ss.inprogress[rb.BundleID]
2188-
slog.Debug("split elements", "bundle", rb, "elem count", len(es.es), "res", firstResidual)
2199+
slog.Info("splitBundle start", "bundle", rb, "elem count", len(es.es), "firstResidual", firstResidual, "livePending", em.livePending.Load())
21892200

21902201
prim := es.es[:firstResidual]
21912202
res := es.es[firstResidual:]
@@ -2205,6 +2216,7 @@ func (ss *stageState) splitBundle(rb RunBundle, firstResidual int, em *ElementMa
22052216
// we don't need to increment pending count in em, since it is already pending
22062217
ss.kind.addPending(ss, em, res)
22072218
ss.inprogress[rb.BundleID] = es
2219+
slog.Info("splitBundle completed", "bundle", rb, "primaryCount", len(prim), "residualCount", len(res), "livePending", em.livePending.Load())
22082220
}
22092221

22102222
// minimumPendingTimestamp returns the minimum pending timestamp from all pending elements,

sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -686,53 +686,71 @@ func TestElementManager_OnWindowExpiration(t *testing.T) {
686686
}
687687

688688
func TestElementManager_ReturnResidualsPendingCount(t *testing.T) {
689-
em := NewElementManager(Config{})
690-
em.AddStage("impulse", nil, []string{"input"}, nil)
691-
em.AddStage("dofn", []string{"input"}, nil, nil)
692-
em.Impulse("impulse")
693-
694-
stage := em.stages["dofn"]
695-
info := PColInfo{
696-
GlobalID: "generic_info",
697-
WDec: exec.MakeWindowDecoder(coder.NewGlobalWindow()),
698-
WEnc: exec.MakeWindowEncoder(coder.NewGlobalWindow()),
699-
EDec: func(r io.Reader) []byte {
700-
b, _ := io.ReadAll(r)
701-
return b
689+
tests := []struct {
690+
name string
691+
firstRsIndex int
692+
wantFinalPending int64
693+
}{
694+
{
695+
name: "ChannelSplit",
696+
firstRsIndex: 0,
697+
wantFinalPending: 1,
698+
},
699+
{
700+
name: "SDFCheckpoint",
701+
firstRsIndex: 1,
702+
wantFinalPending: 2, // Incremented by 1 because the active portion (index 0) is still in progress and will be completed/decremented in PersistBundle.
702703
},
703704
}
704705

705-
// Initial state should have 1 pending element from impulse
706-
if got, want := em.livePending.Load(), int64(1); got != want {
707-
t.Fatalf("initial livePending = %v, want %v", got, want)
708-
}
706+
for _, test := range tests {
707+
t.Run(test.name, func(t *testing.T) {
708+
em := NewElementManager(Config{})
709+
em.AddStage("impulse", nil, []string{"input"}, nil)
710+
em.AddStage("dofn", []string{"input"}, nil, nil)
711+
em.Impulse("impulse")
712+
713+
stage := em.stages["dofn"]
714+
info := PColInfo{
715+
GlobalID: "generic_info",
716+
WDec: exec.MakeWindowDecoder(coder.NewGlobalWindow()),
717+
WEnc: exec.MakeWindowEncoder(coder.NewGlobalWindow()),
718+
EDec: func(r io.Reader) []byte {
719+
b, _ := io.ReadAll(r)
720+
return b
721+
},
722+
}
709723

710-
// Start a bundle
711-
bundID, ok, _, _ := stage.startEventTimeBundle(mtime.MaxTimestamp, func() string { return "inst0" })
712-
if !ok {
713-
t.Fatalf("failed to start bundle")
714-
}
724+
// Initial state should have 1 pending element from impulse
725+
if got, want := em.livePending.Load(), int64(1); got != want {
726+
t.Fatalf("initial livePending = %v, want %v", got, want)
727+
}
715728

716-
// Waitgroup/livePending shouldn't change on starting a bundle (it's still pending)
717-
if got, want := em.livePending.Load(), int64(1); got != want {
718-
t.Fatalf("livePending after startEventTimeBundle = %v, want %v", got, want)
719-
}
729+
// Start a bundle
730+
bundID, ok, _, _ := stage.startEventTimeBundle(mtime.MaxTimestamp, func() string { return "inst0" })
731+
if !ok {
732+
t.Fatalf("failed to start bundle")
733+
}
720734

721-
// Prepare residuals
722-
residBytes := []byte{127, 223, 59, 100, 90, 28, 172, 9, 0, 0, 0, 1, 15, 3, 65, 66, 67} // windowed value header + ABC
723-
residuals := Residuals{
724-
Data: []Residual{{Element: residBytes}},
725-
}
735+
// Waitgroup/livePending shouldn't change on starting a bundle (it's still pending)
736+
if got, want := em.livePending.Load(), int64(1); got != want {
737+
t.Fatalf("livePending after startEventTimeBundle = %v, want %v", got, want)
738+
}
739+
740+
// Prepare residuals
741+
residBytes := []byte{127, 223, 59, 100, 90, 28, 172, 9, 0, 0, 0, 1, 15, 3, 65, 66, 67} // windowed value header + ABC
742+
residuals := Residuals{
743+
Data: []Residual{{Element: residBytes}},
744+
}
745+
746+
rb := RunBundle{StageID: "dofn", BundleID: bundID}
726747

727-
rb := RunBundle{StageID: "dofn", BundleID: bundID}
728-
729-
// Return residuals (Simulates splitting)
730-
em.ReturnResiduals(rb, 0, info, residuals)
748+
// Return residuals (Simulates splitting)
749+
em.ReturnResiduals(rb, test.firstRsIndex, info, residuals)
731750

732-
// Since we split the bundle (0 primary completed, 1 residual returned),
733-
// the element remains pending. The pending count MUST still be exactly 1!
734-
if got, want := em.livePending.Load(), int64(1); got != want {
735-
t.Errorf("BUG DETECTED: livePending after ReturnResiduals = %v, want 1! (Elements counted twice)", got)
751+
if got, want := em.livePending.Load(), test.wantFinalPending; got != want {
752+
t.Errorf("livePending after ReturnResiduals = %v, want %v", got, want)
753+
}
754+
})
736755
}
737756
}
738-

0 commit comments

Comments
 (0)