@@ -15,17 +15,20 @@ func installTaskDBMocks() func() {
1515 origGet := taskDBGetBySessionAndSeq
1616 origUpdate := taskDBUpdate
1717 origUpdateCur := taskDBUpdateCur
18+ origUpdateTotal := taskDBUpdateTotal
1819 origUpdateFinish := taskDBUpdateFinish
1920
2021 taskDBGetBySessionAndSeq = func (string , uint32 ) (* models.Task , error ) { return nil , nil }
2122 taskDBUpdate = func (* clientpb.Task ) error { return nil }
2223 taskDBUpdateCur = func (string , int ) error { return nil }
24+ taskDBUpdateTotal = func (string , int ) error { return nil }
2325 taskDBUpdateFinish = func (string ) error { return nil }
2426
2527 return func () {
2628 taskDBGetBySessionAndSeq = origGet
2729 taskDBUpdate = origUpdate
2830 taskDBUpdateCur = origUpdateCur
31+ taskDBUpdateTotal = origUpdateTotal
2932 taskDBUpdateFinish = origUpdateFinish
3033 }
3134}
@@ -243,3 +246,123 @@ func TestTask_ClosedFieldRaceSafe(t *testing.T) {
243246 t .Fatal ("task should be closed after Close()" )
244247 }
245248}
249+
250+ func TestUpdateTotalPersistsOnlyTotalField (t * testing.T ) {
251+ cleanup := installTaskDBMocks ()
252+ defer cleanup ()
253+
254+ var (
255+ persistedID string
256+ persistedTotal int
257+ curWritten bool
258+ )
259+ taskDBUpdateTotal = func (taskID string , total int ) error {
260+ persistedID = taskID
261+ persistedTotal = total
262+ return nil
263+ }
264+ // Ensure taskDBUpdate (which writes cur+total) is NOT called.
265+ taskDBUpdate = func (pb * clientpb.Task ) error {
266+ curWritten = true
267+ return nil
268+ }
269+
270+ task := & Task {
271+ Id : 5 ,
272+ SessionId : "update-total-test" ,
273+ Cur : 3 ,
274+ Total : 1 ,
275+ }
276+
277+ task .UpdateTotal (10 )
278+
279+ if task .Total != 10 {
280+ t .Fatalf ("in-memory total = %d, want 10" , task .Total )
281+ }
282+ if persistedID != task .TaskID () {
283+ t .Fatalf ("persisted task id = %q, want %q" , persistedID , task .TaskID ())
284+ }
285+ if persistedTotal != 10 {
286+ t .Fatalf ("persisted total = %d, want 10" , persistedTotal )
287+ }
288+ if curWritten {
289+ t .Fatal ("UpdateTotal should not call taskDBUpdate (which writes cur)" )
290+ }
291+ }
292+
293+ func TestUpdateTotalDoesNotRaceWithDone (t * testing.T ) {
294+ cleanup := installTaskDBMocks ()
295+ defer cleanup ()
296+
297+ broker := newTestBroker ()
298+ oldBroker := EventBroker
299+ EventBroker = broker
300+ defer func () { EventBroker = oldBroker }()
301+
302+ sess := newTestSession ("update-total-race" )
303+ sess .Ctx , sess .Cancel = context .WithCancel (context .Background ())
304+ defer sess .Cancel ()
305+
306+ task := sess .NewTask ("download" , 1 )
307+
308+ // Run UpdateTotal and Done concurrently.
309+ var wg sync.WaitGroup
310+ wg .Add (2 )
311+ go func () {
312+ defer wg .Done ()
313+ task .UpdateTotal (10 )
314+ }()
315+ go func () {
316+ defer wg .Done ()
317+ task .Done (& implantpb.Spite {TaskId : task .Id }, "chunk" )
318+ }()
319+ wg .Wait ()
320+
321+ cur , total := task .Progress ()
322+ if total != 10 {
323+ t .Fatalf ("total = %d, want 10" , total )
324+ }
325+ if cur != 1 {
326+ t .Fatalf ("cur = %d, want 1" , cur )
327+ }
328+ }
329+
330+ func TestFinishDoesNotOverwritePositiveTotal (t * testing.T ) {
331+ cleanup := installTaskDBMocks ()
332+ defer cleanup ()
333+
334+ broker := newTestBroker ()
335+ oldBroker := EventBroker
336+ EventBroker = broker
337+ defer func () { EventBroker = oldBroker }()
338+
339+ sess := newTestSession ("finish-positive-total" )
340+ task := & Task {
341+ Id : 7 ,
342+ Type : "download" ,
343+ SessionId : sess .ID ,
344+ Session : sess ,
345+ Cur : 5 ,
346+ Total : 10 ,
347+ DoneCh : make (chan bool , 1 ),
348+ }
349+ task .Ctx , task .Cancel = context .WithCancel (context .Background ())
350+ defer task .Cancel ()
351+
352+ // taskDBUpdate should NOT be called when Total is already positive,
353+ // because the `if t.Total < 0` branch in Finish() won't trigger.
354+ var updateCalled bool
355+ taskDBUpdate = func (pb * clientpb.Task ) error {
356+ updateCalled = true
357+ return nil
358+ }
359+
360+ task .Finish (& implantpb.Spite {TaskId : task .Id }, "done" )
361+
362+ if task .Total != 10 {
363+ t .Fatalf ("total = %d, want 10 (should not be overwritten)" , task .Total )
364+ }
365+ if updateCalled {
366+ t .Fatal ("Finish should not call taskDBUpdate when Total is already positive" )
367+ }
368+ }
0 commit comments