Skip to content

Commit c909ff9

Browse files
committed
refactor(host): 简化鉴权依赖注入
1 parent 5e3db4e commit c909ff9

3 files changed

Lines changed: 80 additions & 33 deletions

File tree

backend/biz/host/handler/v1/internal.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,16 @@ import (
2828
"github.com/chaitin/MonkeyCode/backend/pkg/ws"
2929
)
3030

31-
type internalHostRepo interface {
32-
UpsertHost(context.Context, *taskflow.Host) error
33-
UpsertVirtualMachine(context.Context, *taskflow.VirtualMachine) error
34-
GetVirtualMachine(context.Context, string) (*db.VirtualMachine, error)
35-
UpdateVirtualMachine(context.Context, string, func(*db.VirtualMachineUpdateOne) error) error
36-
GetByID(context.Context, string) (*db.Host, error)
37-
GetVirtualMachineByEnvID(context.Context, string) (*db.VirtualMachine, error)
38-
GetGitCredentialByTask(context.Context, string) (*domain.GitCredentialInfo, error)
39-
}
40-
4131
// InternalHostHandler 处理 taskflow 回调的 host/VM 相关接口
4232
type InternalHostHandler struct {
4333
logger *slog.Logger
44-
repo internalHostRepo
34+
repo domain.HostRepo
4535
teamRepo domain.TeamHostRepo
4636
redis *redis.Client
4737
getAgentToken agentTokenGetter
4838
limiter vmDeleteLimiter
4939
vmDeleter vmDeleter
5040
skipSoftDelete func(context.Context) context.Context
51-
runAsync asyncRunner
5241
cache *cache.Cache
5342
taskLifecycle *lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]
5443
hostUsecase domain.HostUsecase
@@ -71,7 +60,6 @@ func NewInternalHostHandler(i *do.Injector) (*InternalHostHandler, error) {
7160
limiter: rdb,
7261
vmDeleter: tf.VirtualMachiner(),
7362
skipSoftDelete: entx.SkipSoftDelete,
74-
runAsync: defaultAsyncRunner,
7563
cache: cache.New(15*time.Minute, 10*time.Minute),
7664
taskLifecycle: do.MustInvoke[*lifecycle.Manager[uuid.UUID, consts.TaskStatus, lifecycle.TaskMetadata]](i),
7765
hostUsecase: do.MustInvoke[domain.HostUsecase](i),

backend/biz/host/handler/v1/internal_auth.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,6 @@ type vmDeleter interface {
2727
Delete(ctx context.Context, req *taskflow.DeleteVirtualMachineReq) error
2828
}
2929

30-
type asyncRunner func(func())
31-
32-
func defaultAsyncRunner(fn func()) {
33-
go fn()
34-
}
35-
3630
type agentTokenGetter func(ctx context.Context, key string) (string, error)
3731

3832
func defaultAgentTokenGetter(rdb *redis.Client) agentTokenGetter {
@@ -59,7 +53,7 @@ return nil
5953
}
6054

6155
func (h *InternalHostHandler) tryRecycledVMDelete(ctx context.Context, vm *db.VirtualMachine, machineID string) {
62-
if h.limiter == nil || h.vmDeleter == nil || h.runAsync == nil {
56+
if h.limiter == nil || h.vmDeleter == nil {
6357
h.logger.WarnContext(ctx, "skip recycled vm delete retry", "vm_id", vm.ID, "machine_id", machineID, "error", "missing dependency")
6458
return
6559
}
@@ -71,7 +65,7 @@ func (h *InternalHostHandler) tryRecycledVMDelete(ctx context.Context, vm *db.Vi
7165
return
7266
}
7367

74-
h.runAsync(func() {
68+
go func() {
7569
deleteCtx, cancel := context.WithTimeout(context.Background(), recycledDeleteTimeout)
7670
defer cancel()
7771

@@ -85,5 +79,5 @@ func (h *InternalHostHandler) tryRecycledVMDelete(ctx context.Context, vm *db.Vi
8579
return
8680
}
8781
h.logger.InfoContext(deleteCtx, "reissue recycled vm delete success", "vm_id", vm.ID, "machine_id", machineID)
88-
})
82+
}()
8983
}

backend/biz/host/handler/v1/internal_auth_test.go

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
func TestAgentAuthRecycledVMTriggersDeleteOnce(t *testing.T) {
20-
vmClient := &vmDeleterStub{}
20+
vmClient := &vmDeleterStub{ch: make(chan struct{}, 1)}
2121
handler := &InternalHostHandler{
2222
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
2323
getAgentToken: func(context.Context, string) (string, error) { return "", redis.Nil },
@@ -34,23 +34,23 @@ func TestAgentAuthRecycledVMTriggersDeleteOnce(t *testing.T) {
3434
vmDeleter: vmClient,
3535
limiter: &setNXLimiterStub{result: true},
3636
skipSoftDelete: func(ctx context.Context) context.Context { return ctx },
37-
runAsync: func(fn func()) { fn() },
3837
}
3938

4039
_, err := handler.agentAuth(context.Background(), "agent_1", "machine-1")
4140
if !errors.Is(err, errAgentVMRecycled) {
4241
t.Fatalf("agent auth error = %v, want %v", err, errAgentVMRecycled)
4342
}
44-
if len(vmClient.reqs) != 1 {
43+
reqs := vmClient.waitReqs(t, time.Second)
44+
if len(reqs) != 1 {
4545
t.Fatalf("delete calls = %d, want 1", len(vmClient.reqs))
4646
}
47-
if vmClient.reqs[0].ID != "env_1" {
48-
t.Fatalf("delete env id = %q, want env_1", vmClient.reqs[0].ID)
47+
if reqs[0].ID != "env_1" {
48+
t.Fatalf("delete env id = %q, want env_1", reqs[0].ID)
4949
}
5050
}
5151

5252
func TestAgentAuthRecycledVMLimitedSkipsDelete(t *testing.T) {
53-
vmClient := &vmDeleterStub{}
53+
vmClient := &vmDeleterStub{ch: make(chan struct{}, 1)}
5454
handler := &InternalHostHandler{
5555
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
5656
getAgentToken: func(context.Context, string) (string, error) { return "", redis.Nil },
@@ -67,20 +67,19 @@ func TestAgentAuthRecycledVMLimitedSkipsDelete(t *testing.T) {
6767
vmDeleter: vmClient,
6868
limiter: &setNXLimiterStub{result: false},
6969
skipSoftDelete: func(ctx context.Context) context.Context { return ctx },
70-
runAsync: func(fn func()) { fn() },
7170
}
7271

7372
_, err := handler.agentAuth(context.Background(), "agent_2", "machine-2")
7473
if !errors.Is(err, errAgentVMRecycled) {
7574
t.Fatalf("agent auth error = %v, want %v", err, errAgentVMRecycled)
7675
}
77-
if len(vmClient.reqs) != 0 {
76+
if vmClient.hasReqWithin(50 * time.Millisecond) {
7877
t.Fatalf("delete calls = %d, want 0", len(vmClient.reqs))
7978
}
8079
}
8180

8281
func TestAgentAuthSoftDeletedRecycledVMStillTriggersDelete(t *testing.T) {
83-
vmClient := &vmDeleterStub{}
82+
vmClient := &vmDeleterStub{ch: make(chan struct{}, 1)}
8483
skipCalled := false
8584
type testSkipMarkerKey struct{}
8685
markerKey := testSkipMarkerKey{}
@@ -107,7 +106,6 @@ func TestAgentAuthSoftDeletedRecycledVMStillTriggersDelete(t *testing.T) {
107106
skipCalled = true
108107
return context.WithValue(ctx, markerKey, markerValue)
109108
},
110-
runAsync: func(fn func()) { fn() },
111109
}
112110

113111
_, err := handler.agentAuth(context.Background(), "agent_deleted", "machine-deleted")
@@ -117,7 +115,7 @@ func TestAgentAuthSoftDeletedRecycledVMStillTriggersDelete(t *testing.T) {
117115
if !skipCalled {
118116
t.Fatal("expected skipSoftDelete to be called")
119117
}
120-
if len(vmClient.reqs) != 1 {
118+
if len(vmClient.waitReqs(t, time.Second)) != 1 {
121119
t.Fatalf("delete calls = %d, want 1", len(vmClient.reqs))
122120
}
123121
}
@@ -129,6 +127,14 @@ type internalHostRepoStub struct {
129127
skipMarkerValue string
130128
}
131129

130+
func (s *internalHostRepoStub) List(context.Context, uuid.UUID) ([]*db.Host, error) {
131+
return nil, errors.New("not implemented")
132+
}
133+
134+
func (s *internalHostRepoStub) GetHost(context.Context, uuid.UUID, string) (*domain.Host, error) {
135+
return nil, errors.New("not implemented")
136+
}
137+
132138
func (s *internalHostRepoStub) UpsertHost(context.Context, *taskflow.Host) error {
133139
return nil
134140
}
@@ -162,6 +168,38 @@ func (s *internalHostRepoStub) GetVirtualMachineByEnvID(context.Context, string)
162168
return nil, errors.New("vm not found")
163169
}
164170

171+
func (s *internalHostRepoStub) GetVirtualMachineWithUser(context.Context, uuid.UUID, string) (*db.VirtualMachine, error) {
172+
return nil, errors.New("vm not found")
173+
}
174+
175+
func (s *internalHostRepoStub) CreateVirtualMachine(context.Context, *domain.User, *domain.CreateVMReq, func(context.Context) (string, error), func(*db.Model, *db.Image) (*domain.VirtualMachine, error)) (*domain.VirtualMachine, error) {
176+
return nil, errors.New("not implemented")
177+
}
178+
179+
func (s *internalHostRepoStub) PastHourVirtualMachine(context.Context) ([]*db.VirtualMachine, error) {
180+
return nil, errors.New("not implemented")
181+
}
182+
183+
func (s *internalHostRepoStub) AllCountDownVirtualMachine(context.Context) ([]*db.VirtualMachine, error) {
184+
return nil, errors.New("not implemented")
185+
}
186+
187+
func (s *internalHostRepoStub) DeleteVirtualMachine(context.Context, uuid.UUID, string, string, func(*db.VirtualMachine) error) error {
188+
return errors.New("not implemented")
189+
}
190+
191+
func (s *internalHostRepoStub) DeleteHost(context.Context, uuid.UUID, string) error {
192+
return errors.New("not implemented")
193+
}
194+
195+
func (s *internalHostRepoStub) UpdateHost(context.Context, uuid.UUID, *domain.UpdateHostReq) error {
196+
return errors.New("not implemented")
197+
}
198+
199+
func (s *internalHostRepoStub) UpdateVM(context.Context, domain.UpdateVMReq, func(*db.VirtualMachine) error) (*db.VirtualMachine, int64, error) {
200+
return nil, 0, errors.New("not implemented")
201+
}
202+
165203
func (s *internalHostRepoStub) GetGitCredentialByTask(context.Context, string) (*domain.GitCredentialInfo, error) {
166204
return nil, errors.New("task not found")
167205
}
@@ -182,10 +220,37 @@ func (s *setNXLimiterStub) SetNX(_ context.Context, key string, _ interface{}, t
182220
type vmDeleterStub struct {
183221
reqs []*taskflow.DeleteVirtualMachineReq
184222
err error
223+
ch chan struct{}
185224
}
186225

187226
func (s *vmDeleterStub) Delete(_ context.Context, req *taskflow.DeleteVirtualMachineReq) error {
188227
cp := *req
189228
s.reqs = append(s.reqs, &cp)
229+
if s.ch != nil {
230+
select {
231+
case s.ch <- struct{}{}:
232+
default:
233+
}
234+
}
190235
return s.err
191236
}
237+
238+
func (s *vmDeleterStub) waitReqs(t *testing.T, timeout time.Duration) []*taskflow.DeleteVirtualMachineReq {
239+
t.Helper()
240+
select {
241+
case <-s.ch:
242+
return s.reqs
243+
case <-time.After(timeout):
244+
t.Fatal("timed out waiting for delete call")
245+
return nil
246+
}
247+
}
248+
249+
func (s *vmDeleterStub) hasReqWithin(timeout time.Duration) bool {
250+
select {
251+
case <-s.ch:
252+
return true
253+
case <-time.After(timeout):
254+
return false
255+
}
256+
}

0 commit comments

Comments
 (0)