Skip to content

Commit 2211849

Browse files
committed
feat(task): 增加任务最近活跃时间
1 parent 3cf3f2c commit 2211849

20 files changed

Lines changed: 505 additions & 38 deletions

backend/biz/task/handler/v1/task.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type TaskHandler struct {
4646
taskConns *ws.TaskConn
4747
controlConns *ws.ControlConn
4848
taskSummary *service.TaskSummaryService
49+
taskActivity service.TaskActivityRefresher
4950
idleRefresher vmidle.VMIdleRefresher
5051
activeRepo domain.UserActiveRepo
5152
}
@@ -64,6 +65,7 @@ func NewTaskHandler(i *do.Injector) (*TaskHandler, error) {
6465
tc := do.MustInvoke[*ws.TaskConn](i)
6566
cc := do.MustInvoke[*ws.ControlConn](i)
6667
ts := do.MustInvoke[*service.TaskSummaryService](i)
68+
ta := do.MustInvoke[service.TaskActivityRefresher](i)
6769
ir := do.MustInvoke[vmidle.VMIdleRefresher](i)
6870

6971
// Optional deps
@@ -91,6 +93,7 @@ func NewTaskHandler(i *do.Injector) (*TaskHandler, error) {
9193
taskConns: tc,
9294
controlConns: cc,
9395
taskSummary: ts,
96+
taskActivity: ta,
9497
idleRefresher: ir,
9598
activeRepo: activeRepo,
9699
}

backend/biz/task/handler/v1/task_control.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"golang.org/x/sync/errgroup"
1111

1212
"github.com/GoYoko/web"
13+
"github.com/google/uuid"
1314

15+
"github.com/chaitin/MonkeyCode/backend/biz/task/service"
1416
"github.com/chaitin/MonkeyCode/backend/consts"
1517
"github.com/chaitin/MonkeyCode/backend/domain"
1618
"github.com/chaitin/MonkeyCode/backend/middleware"
@@ -130,6 +132,9 @@ func (h *TaskHandler) Control(c *web.Context, req domain.TaskControlReq) error {
130132

131133
logger := h.logger.With("task_id", task.ID, "fn", "task.control")
132134
taskID := task.ID.String()
135+
if err := h.taskActivity.Refresh(c.Request().Context(), task.ID); err != nil {
136+
logger.WarnContext(c.Request().Context(), "failed to refresh task last active on control connect", "error", err)
137+
}
133138

134139
// 连接建立:刷新空闲计时器
135140
if vm := task.VirtualMachine; vm != nil {
@@ -182,7 +187,7 @@ func (h *TaskHandler) Control(c *web.Context, req domain.TaskControlReq) error {
182187
// 定期刷新空闲计时器,保持 VM 活跃
183188
if vm := task.VirtualMachine; vm != nil {
184189
g.Go(func() error {
185-
return h.controlKeepAlive(ctx, vm.ID)
190+
return h.controlKeepAlive(ctx, task.ID, vm.ID)
186191
})
187192
}
188193

@@ -211,17 +216,23 @@ func (h *TaskHandler) controlPing(ctx context.Context, wsConn *ws.WebsocketManag
211216
}
212217

213218
// controlKeepAlive 定期刷新空闲计时器,防止 VM 被误判空闲
214-
func (h *TaskHandler) controlKeepAlive(ctx context.Context, vmID string) error {
215-
ticker := time.NewTicker(1 * time.Minute)
216-
defer ticker.Stop()
219+
func (h *TaskHandler) controlKeepAlive(ctx context.Context, taskID uuid.UUID, vmID string) error {
220+
idleTicker := time.NewTicker(1 * time.Minute)
221+
activityTicker := time.NewTicker(service.TaskActivityRefreshInterval)
222+
defer idleTicker.Stop()
223+
defer activityTicker.Stop()
217224
for {
218225
select {
219226
case <-ctx.Done():
220227
return ctx.Err()
221-
case <-ticker.C:
228+
case <-idleTicker.C:
222229
if err := h.idleRefresher.Refresh(ctx, vmID); err != nil {
223230
h.logger.WarnContext(ctx, "keepalive refresh failed", "vmID", vmID, "error", err)
224231
}
232+
case <-activityTicker.C:
233+
if err := h.taskActivity.Refresh(ctx, taskID); err != nil {
234+
h.logger.WarnContext(ctx, "task activity refresh failed", "taskID", taskID, "error", err)
235+
}
225236
}
226237
}
227238
}

backend/biz/task/register.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
func ProvideTask(i *do.Injector) {
1414
do.Provide(i, usecase.NewTaskUsecase)
1515
do.Provide(i, usecase.NewGitTaskUsecase)
16+
do.Provide(i, service.NewTaskActivityRefresher)
1617
do.Provide(i, service.NewTaskSummaryService)
1718
do.Provide(i, v1.NewTaskHandler)
1819
do.Provide(i, repo.NewTaskRepo)

backend/biz/task/repo/task.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,15 @@ func (t *TaskRepo) Update(ctx context.Context, _ *domain.User, id uuid.UUID, fn
261261
})
262262
}
263263

264+
func (t *TaskRepo) RefreshLastActiveAt(ctx context.Context, id uuid.UUID, at time.Time, minInterval time.Duration) error {
265+
up := t.db.Task.Update().Where(task.ID(id))
266+
if minInterval > 0 {
267+
up = up.Where(task.LastActiveAtLT(at.Add(-minInterval)))
268+
}
269+
_, err := up.SetLastActiveAt(at).Save(ctx)
270+
return err
271+
}
272+
264273
// Delete implements domain.TaskRepo.
265274
func (t *TaskRepo) Delete(ctx context.Context, user *domain.User, id uuid.UUID) error {
266275
_, err := t.db.Task.Delete().
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package service
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/google/uuid"
8+
"github.com/samber/do"
9+
10+
"github.com/chaitin/MonkeyCode/backend/domain"
11+
)
12+
13+
const TaskActivityRefreshInterval = 5 * time.Minute
14+
15+
type TaskActivityRefresher interface {
16+
Refresh(ctx context.Context, taskID uuid.UUID) error
17+
ForceRefresh(ctx context.Context, taskID uuid.UUID) error
18+
}
19+
20+
type taskActivityRefresher struct {
21+
repo taskActivityRepo
22+
clock func() time.Time
23+
}
24+
25+
type taskActivityRepo interface {
26+
RefreshLastActiveAt(ctx context.Context, id uuid.UUID, at time.Time, minInterval time.Duration) error
27+
}
28+
29+
func NewTaskActivityRefresher(i *do.Injector) (TaskActivityRefresher, error) {
30+
return &taskActivityRefresher{
31+
repo: do.MustInvoke[domain.TaskRepo](i),
32+
clock: time.Now,
33+
}, nil
34+
}
35+
36+
func (r *taskActivityRefresher) Refresh(ctx context.Context, taskID uuid.UUID) error {
37+
return r.repo.RefreshLastActiveAt(ctx, taskID, r.clock(), TaskActivityRefreshInterval)
38+
}
39+
40+
func (r *taskActivityRefresher) ForceRefresh(ctx context.Context, taskID uuid.UUID) error {
41+
return r.repo.RefreshLastActiveAt(ctx, taskID, r.clock(), 0)
42+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package service
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
)
10+
11+
func TestTaskActivityRefresherRefreshUsesThrottleInterval(t *testing.T) {
12+
taskID := uuid.MustParse("11111111-1111-1111-1111-111111111111")
13+
now := time.Unix(1_700_000_000, 0).UTC()
14+
repo := &taskActivityRepoStub{}
15+
refresher := &taskActivityRefresher{
16+
repo: repo,
17+
clock: func() time.Time { return now },
18+
}
19+
20+
if err := refresher.Refresh(context.Background(), taskID); err != nil {
21+
t.Fatalf("Refresh() error = %v", err)
22+
}
23+
24+
if repo.taskID != taskID {
25+
t.Fatalf("task id = %s, want %s", repo.taskID, taskID)
26+
}
27+
if !repo.at.Equal(now) {
28+
t.Fatalf("refresh time = %s, want %s", repo.at, now)
29+
}
30+
if repo.minInterval != TaskActivityRefreshInterval {
31+
t.Fatalf("min interval = %s, want %s", repo.minInterval, TaskActivityRefreshInterval)
32+
}
33+
}
34+
35+
func TestTaskActivityRefresherForceRefreshBypassesThrottle(t *testing.T) {
36+
taskID := uuid.MustParse("22222222-2222-2222-2222-222222222222")
37+
now := time.Unix(1_700_000_100, 0).UTC()
38+
repo := &taskActivityRepoStub{}
39+
refresher := &taskActivityRefresher{
40+
repo: repo,
41+
clock: func() time.Time { return now },
42+
}
43+
44+
if err := refresher.ForceRefresh(context.Background(), taskID); err != nil {
45+
t.Fatalf("ForceRefresh() error = %v", err)
46+
}
47+
48+
if repo.taskID != taskID {
49+
t.Fatalf("task id = %s, want %s", repo.taskID, taskID)
50+
}
51+
if !repo.at.Equal(now) {
52+
t.Fatalf("refresh time = %s, want %s", repo.at, now)
53+
}
54+
if repo.minInterval != 0 {
55+
t.Fatalf("min interval = %s, want 0", repo.minInterval)
56+
}
57+
}
58+
59+
type taskActivityRepoStub struct {
60+
taskID uuid.UUID
61+
at time.Time
62+
minInterval time.Duration
63+
}
64+
65+
func (s *taskActivityRepoStub) RefreshLastActiveAt(_ context.Context, taskID uuid.UUID, at time.Time, minInterval time.Duration) error {
66+
s.taskID = taskID
67+
s.at = at
68+
s.minInterval = minInterval
69+
return nil
70+
}

backend/biz/task/usecase/task.go

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"github.com/samber/do"
2020

2121
gituc "github.com/chaitin/MonkeyCode/backend/biz/git/usecase"
22+
"github.com/chaitin/MonkeyCode/backend/biz/task/service"
23+
vmidle "github.com/chaitin/MonkeyCode/backend/biz/vmidle/usecase"
2224
"github.com/chaitin/MonkeyCode/backend/config"
2325
"github.com/chaitin/MonkeyCode/backend/consts"
2426
"github.com/chaitin/MonkeyCode/backend/db"
@@ -36,40 +38,44 @@ import (
3638

3739
// TaskUsecase 任务业务逻辑实现
3840
type TaskUsecase struct {
39-
cfg *config.Config
40-
repo domain.TaskRepo
41-
modelRepo domain.ModelRepo
42-
logger *slog.Logger
43-
taskflow taskflow.Clienter
44-
loki *loki.Client
45-
redis *redis.Client
46-
notifyDispatcher *dispatcher.Dispatcher
47-
taskHook domain.TaskHook
48-
privilegeChecker domain.PrivilegeChecker
49-
modelHook domain.ModelHook
50-
taskLifecycle *lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]
51-
vmLifecycle *lifecycle.Manager[string, lifecycle.VMState, lifecycle.VMMetadata]
52-
girepo domain.GitIdentityRepo
53-
tokenProvider *gituc.TokenProvider
54-
projectRepo domain.ProjectRepo
41+
cfg *config.Config
42+
repo domain.TaskRepo
43+
modelRepo domain.ModelRepo
44+
logger *slog.Logger
45+
taskflow taskflow.Clienter
46+
loki *loki.Client
47+
redis *redis.Client
48+
notifyDispatcher *dispatcher.Dispatcher
49+
taskHook domain.TaskHook
50+
privilegeChecker domain.PrivilegeChecker
51+
modelHook domain.ModelHook
52+
taskLifecycle *lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]
53+
vmLifecycle *lifecycle.Manager[string, lifecycle.VMState, lifecycle.VMMetadata]
54+
girepo domain.GitIdentityRepo
55+
tokenProvider *gituc.TokenProvider
56+
projectRepo domain.ProjectRepo
57+
taskActivityRefresher service.TaskActivityRefresher
58+
idleRefresher vmidle.VMIdleRefresher
5559
}
5660

5761
// NewTaskUsecase 创建任务业务逻辑实例
5862
func NewTaskUsecase(i *do.Injector) (domain.TaskUsecase, error) {
5963
u := &TaskUsecase{
60-
cfg: do.MustInvoke[*config.Config](i),
61-
repo: do.MustInvoke[domain.TaskRepo](i),
62-
modelRepo: do.MustInvoke[domain.ModelRepo](i),
63-
logger: do.MustInvoke[*slog.Logger](i).With("module", "usecase.TaskUsecase"),
64-
taskflow: do.MustInvoke[taskflow.Clienter](i),
65-
loki: do.MustInvoke[*loki.Client](i),
66-
redis: do.MustInvoke[*redis.Client](i),
67-
notifyDispatcher: do.MustInvoke[*dispatcher.Dispatcher](i),
68-
taskLifecycle: do.MustInvoke[*lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]](i),
69-
vmLifecycle: do.MustInvoke[*lifecycle.Manager[string, lifecycle.VMState, lifecycle.VMMetadata]](i),
70-
girepo: do.MustInvoke[domain.GitIdentityRepo](i),
71-
tokenProvider: do.MustInvoke[*gituc.TokenProvider](i),
72-
projectRepo: do.MustInvoke[domain.ProjectRepo](i),
64+
cfg: do.MustInvoke[*config.Config](i),
65+
repo: do.MustInvoke[domain.TaskRepo](i),
66+
modelRepo: do.MustInvoke[domain.ModelRepo](i),
67+
logger: do.MustInvoke[*slog.Logger](i).With("module", "usecase.TaskUsecase"),
68+
taskflow: do.MustInvoke[taskflow.Clienter](i),
69+
loki: do.MustInvoke[*loki.Client](i),
70+
redis: do.MustInvoke[*redis.Client](i),
71+
notifyDispatcher: do.MustInvoke[*dispatcher.Dispatcher](i),
72+
taskLifecycle: do.MustInvoke[*lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]](i),
73+
vmLifecycle: do.MustInvoke[*lifecycle.Manager[string, lifecycle.VMState, lifecycle.VMMetadata]](i),
74+
girepo: do.MustInvoke[domain.GitIdentityRepo](i),
75+
tokenProvider: do.MustInvoke[*gituc.TokenProvider](i),
76+
projectRepo: do.MustInvoke[domain.ProjectRepo](i),
77+
taskActivityRefresher: do.MustInvoke[service.TaskActivityRefresher](i),
78+
idleRefresher: do.MustInvoke[vmidle.VMIdleRefresher](i),
7379
}
7480

7581
// 可选注入 TaskHook
@@ -450,6 +456,11 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
450456
if err := a.IncrUserInputCount(ctx, user.ID, pt.Edges.Task.ID); err != nil {
451457
a.logger.WarnContext(ctx, "failed to incr user input count on create", "error", err)
452458
}
459+
vmID := ""
460+
if createdVm != nil {
461+
vmID = createdVm.ID
462+
}
463+
a.refreshCreatedTaskState(ctx, pt.TaskID, vmID)
453464

454465
result := cvt.From(pt, &domain.ProjectTask{})
455466

@@ -463,6 +474,18 @@ func (a *TaskUsecase) Create(ctx context.Context, user *domain.User, req domain.
463474
return result, nil
464475
}
465476

477+
func (a *TaskUsecase) refreshCreatedTaskState(ctx context.Context, taskID uuid.UUID, vmID string) {
478+
if err := a.taskActivityRefresher.ForceRefresh(ctx, taskID); err != nil {
479+
a.logger.WarnContext(ctx, "failed to refresh task last active on create", "task_id", taskID, "error", err)
480+
}
481+
if vmID == "" {
482+
return
483+
}
484+
if err := a.idleRefresher.Refresh(ctx, vmID); err != nil {
485+
a.logger.WarnContext(ctx, "failed to refresh vm idle timers on create", "task_id", taskID, "vm_id", vmID, "error", err)
486+
}
487+
}
488+
466489
func (a *TaskUsecase) buildMCPConfigs(taskID uuid.UUID, token string) []taskflow.McpServerConfig {
467490
mcps := []taskflow.McpServerConfig{
468491
{
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package usecase
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"log/slog"
8+
"testing"
9+
10+
"github.com/google/uuid"
11+
)
12+
13+
func TestRefreshCreatedTaskStateAlwaysRefreshesIdleTimer(t *testing.T) {
14+
taskID := uuid.MustParse("11111111-1111-1111-1111-111111111111")
15+
vmID := "vm-1"
16+
taskRefresher := &taskActivityRefresherStub{err: errors.New("db write failed")}
17+
idleRefresher := &vmIdleRefresherStub{}
18+
u := &TaskUsecase{
19+
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
20+
taskActivityRefresher: taskRefresher,
21+
idleRefresher: idleRefresher,
22+
}
23+
24+
u.refreshCreatedTaskState(context.Background(), taskID, vmID)
25+
26+
if !taskRefresher.forceCalled {
27+
t.Fatal("expected task activity refresher to be called")
28+
}
29+
if taskRefresher.taskID != taskID {
30+
t.Fatalf("task id = %s, want %s", taskRefresher.taskID, taskID)
31+
}
32+
if !idleRefresher.called {
33+
t.Fatal("expected vm idle refresher to be called")
34+
}
35+
if idleRefresher.vmID != vmID {
36+
t.Fatalf("vm id = %s, want %s", idleRefresher.vmID, vmID)
37+
}
38+
}
39+
40+
type taskActivityRefresherStub struct {
41+
taskID uuid.UUID
42+
forceCalled bool
43+
err error
44+
}
45+
46+
func (s *taskActivityRefresherStub) Refresh(context.Context, uuid.UUID) error {
47+
return nil
48+
}
49+
50+
func (s *taskActivityRefresherStub) ForceRefresh(_ context.Context, taskID uuid.UUID) error {
51+
s.taskID = taskID
52+
s.forceCalled = true
53+
return s.err
54+
}
55+
56+
type vmIdleRefresherStub struct {
57+
vmID string
58+
called bool
59+
err error
60+
}
61+
62+
func (s *vmIdleRefresherStub) Refresh(_ context.Context, vmID string) error {
63+
s.vmID = vmID
64+
s.called = true
65+
return s.err
66+
}

0 commit comments

Comments
 (0)