Skip to content

Commit b9f6a03

Browse files
committed
fix(tasks): persist total to DB immediately when updated at runtime
Add Task.UpdateTotal() that writes only the total field to DB, avoiding race conditions with Done() which updates cur independently. Download handler now uses UpdateTotal() after discovering chunk count so that task progress queries return accurate Cur/Total during transfer.
1 parent 7f5e17e commit b9f6a03

3 files changed

Lines changed: 23 additions & 1 deletion

File tree

server/internal/core/task.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ var (
2626
taskDBUpdateCur = func(taskID string, cur int) error {
2727
return db.UpdateTaskCur(taskID, cur)
2828
}
29+
taskDBUpdateTotal = func(taskID string, total int) error {
30+
return db.UpdateTaskTotal(taskID, total)
31+
}
2932
taskDBUpdateFinish = func(taskID string) error {
3033
return db.UpdateTaskFinish(taskID)
3134
}
@@ -216,6 +219,18 @@ func (t *Task) Publish(op string, spite *implantpb.Spite, msg string) {
216219
Callee: t.Callee,
217220
})
218221
}
222+
223+
// UpdateTotal sets the task's total count and persists only the total field to DB.
224+
// This avoids racing with Done() which updates cur independently.
225+
func (t *Task) UpdateTotal(total int) {
226+
t.progressMu.Lock()
227+
t.Total = total
228+
t.progressMu.Unlock()
229+
if err := taskDBUpdateTotal(t.TaskID(), total); err != nil {
230+
logs.Log.Warnf("task %s: update total failed: %v", t.TaskID(), err)
231+
}
232+
}
233+
219234
func (t *Task) Done(spite *implantpb.Spite, msg string) {
220235
t.progressMu.Lock()
221236
t.Cur++

server/internal/db/session_helper.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,13 @@ func UpdateTaskCur(taskID string, cur int) error {
278278
return taskModel.UpdateCur(Session(), cur)
279279
}
280280

281+
func UpdateTaskTotal(taskID string, total int) error {
282+
taskModel := &models.Task{
283+
ID: taskID,
284+
}
285+
return taskModel.UpdateTotal(Session(), total)
286+
}
287+
281288
func UpdateTaskFinish(taskID string) error {
282289
task, err := NewTaskQuery().WhereID(taskID).First()
283290
if err != nil {

server/rpc/rpc-file.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func (rpc *Server) Download(ctx context.Context, req *implantpb.DownloadRequest)
307307
}
308308
total := downloadChunkCount(int(resp.GetDownloadResponse().Size), greq.Session.GetPacketLength())
309309
downloadAbs := resp.GetDownloadResponse()
310-
greq.Task.Total = total
310+
greq.Task.UpdateTotal(total)
311311

312312
finalPath, err := fileutils.SafeJoin(configs.ContextPath, filepath.Join(greq.Session.ID, consts.DownloadPath, downloadAbs.Checksum))
313313
if err != nil {

0 commit comments

Comments
 (0)