Skip to content

Commit 63b33da

Browse files
committed
test(tasks): add unit tests for UpdateTotal and its interaction with Done/Finish
- Verify UpdateTotal persists only total field without touching cur - Verify UpdateTotal and Done do not race on concurrent access - Verify Finish does not overwrite a positive Total set by UpdateTotal
1 parent b9f6a03 commit 63b33da

1 file changed

Lines changed: 123 additions & 0 deletions

File tree

server/internal/core/task_runtime_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,20 @@ func installTaskDBMocks() func() {
1515
origGet := taskDBGetBySessionAndSeq
1616
origUpdate := taskDBUpdate
1717
origUpdateCur := taskDBUpdateCur
18+
origUpdateTotal := taskDBUpdateTotal
1819
origUpdateFinish := taskDBUpdateFinish
1920

2021
taskDBGetBySessionAndSeq = func(string, uint32) (*models.Task, error) { return nil, nil }
2122
taskDBUpdate = func(*clientpb.Task) error { return nil }
2223
taskDBUpdateCur = func(string, int) error { return nil }
24+
taskDBUpdateTotal = func(string, int) error { return nil }
2325
taskDBUpdateFinish = func(string) error { return nil }
2426

2527
return func() {
2628
taskDBGetBySessionAndSeq = origGet
2729
taskDBUpdate = origUpdate
2830
taskDBUpdateCur = origUpdateCur
31+
taskDBUpdateTotal = origUpdateTotal
2932
taskDBUpdateFinish = origUpdateFinish
3033
}
3134
}
@@ -243,3 +246,123 @@ func TestTask_ClosedFieldRaceSafe(t *testing.T) {
243246
t.Fatal("task should be closed after Close()")
244247
}
245248
}
249+
250+
func TestUpdateTotalPersistsOnlyTotalField(t *testing.T) {
251+
cleanup := installTaskDBMocks()
252+
defer cleanup()
253+
254+
var (
255+
persistedID string
256+
persistedTotal int
257+
curWritten bool
258+
)
259+
taskDBUpdateTotal = func(taskID string, total int) error {
260+
persistedID = taskID
261+
persistedTotal = total
262+
return nil
263+
}
264+
// Ensure taskDBUpdate (which writes cur+total) is NOT called.
265+
taskDBUpdate = func(pb *clientpb.Task) error {
266+
curWritten = true
267+
return nil
268+
}
269+
270+
task := &Task{
271+
Id: 5,
272+
SessionId: "update-total-test",
273+
Cur: 3,
274+
Total: 1,
275+
}
276+
277+
task.UpdateTotal(10)
278+
279+
if task.Total != 10 {
280+
t.Fatalf("in-memory total = %d, want 10", task.Total)
281+
}
282+
if persistedID != task.TaskID() {
283+
t.Fatalf("persisted task id = %q, want %q", persistedID, task.TaskID())
284+
}
285+
if persistedTotal != 10 {
286+
t.Fatalf("persisted total = %d, want 10", persistedTotal)
287+
}
288+
if curWritten {
289+
t.Fatal("UpdateTotal should not call taskDBUpdate (which writes cur)")
290+
}
291+
}
292+
293+
func TestUpdateTotalDoesNotRaceWithDone(t *testing.T) {
294+
cleanup := installTaskDBMocks()
295+
defer cleanup()
296+
297+
broker := newTestBroker()
298+
oldBroker := EventBroker
299+
EventBroker = broker
300+
defer func() { EventBroker = oldBroker }()
301+
302+
sess := newTestSession("update-total-race")
303+
sess.Ctx, sess.Cancel = context.WithCancel(context.Background())
304+
defer sess.Cancel()
305+
306+
task := sess.NewTask("download", 1)
307+
308+
// Run UpdateTotal and Done concurrently.
309+
var wg sync.WaitGroup
310+
wg.Add(2)
311+
go func() {
312+
defer wg.Done()
313+
task.UpdateTotal(10)
314+
}()
315+
go func() {
316+
defer wg.Done()
317+
task.Done(&implantpb.Spite{TaskId: task.Id}, "chunk")
318+
}()
319+
wg.Wait()
320+
321+
cur, total := task.Progress()
322+
if total != 10 {
323+
t.Fatalf("total = %d, want 10", total)
324+
}
325+
if cur != 1 {
326+
t.Fatalf("cur = %d, want 1", cur)
327+
}
328+
}
329+
330+
func TestFinishDoesNotOverwritePositiveTotal(t *testing.T) {
331+
cleanup := installTaskDBMocks()
332+
defer cleanup()
333+
334+
broker := newTestBroker()
335+
oldBroker := EventBroker
336+
EventBroker = broker
337+
defer func() { EventBroker = oldBroker }()
338+
339+
sess := newTestSession("finish-positive-total")
340+
task := &Task{
341+
Id: 7,
342+
Type: "download",
343+
SessionId: sess.ID,
344+
Session: sess,
345+
Cur: 5,
346+
Total: 10,
347+
DoneCh: make(chan bool, 1),
348+
}
349+
task.Ctx, task.Cancel = context.WithCancel(context.Background())
350+
defer task.Cancel()
351+
352+
// taskDBUpdate should NOT be called when Total is already positive,
353+
// because the `if t.Total < 0` branch in Finish() won't trigger.
354+
var updateCalled bool
355+
taskDBUpdate = func(pb *clientpb.Task) error {
356+
updateCalled = true
357+
return nil
358+
}
359+
360+
task.Finish(&implantpb.Spite{TaskId: task.Id}, "done")
361+
362+
if task.Total != 10 {
363+
t.Fatalf("total = %d, want 10 (should not be overwritten)", task.Total)
364+
}
365+
if updateCalled {
366+
t.Fatal("Finish should not call taskDBUpdate when Total is already positive")
367+
}
368+
}

0 commit comments

Comments
 (0)