Skip to content

Commit 1856743

Browse files
committed
fix(host): PickHost 改为加权随机
1 parent 555e12d commit 1856743

File tree

2 files changed

+182
-12
lines changed

2 files changed

+182
-12
lines changed

backend/biz/host/usecase/publichost.go

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package usecase
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"encoding/binary"
57
"fmt"
6-
"sync/atomic"
78

89
"github.com/samber/do"
910

@@ -17,9 +18,10 @@ import (
1718
type PublicHostUsecase struct {
1819
repo domain.PublicHostRepo
1920
taskflow taskflow.Clienter
20-
rr uint64
2121
}
2222

23+
var randUint64n = randomUint64n
24+
2325
func NewPublicHostUsecase(i *do.Injector) (domain.PublicHostUsecase, error) {
2426
return &PublicHostUsecase{
2527
repo: do.MustInvoke[domain.PublicHostRepo](i),
@@ -52,9 +54,18 @@ func (p *PublicHostUsecase) PickHost(ctx context.Context) (*domain.Host, error)
5254
return nil, errcode.ErrPublicHostNotFound.Wrap(fmt.Errorf("no online public hosts found"))
5355
}
5456

55-
weights := make([]uint64, len(onlines))
57+
selected, err := pickWeightedHost(onlines)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
return cvt.From(selected, &domain.Host{}), nil
63+
}
64+
65+
func pickWeightedHost(hosts []*db.Host) (*db.Host, error) {
66+
weights := make([]uint64, len(hosts))
5667
var totalWeight uint64
57-
for i, h := range onlines {
68+
for i, h := range hosts {
5869
w := h.Weight
5970
if w <= 0 {
6071
w = 1
@@ -66,19 +77,34 @@ func (p *PublicHostUsecase) PickHost(ctx context.Context) (*domain.Host, error)
6677
return nil, errcode.ErrPublicHostNotFound.Wrap(fmt.Errorf("no valid weights found"))
6778
}
6879

69-
idx := atomic.AddUint64(&p.rr, 1) - 1
70-
offset := idx % totalWeight
71-
var selected *db.Host
80+
offset, err := randUint64n(totalWeight)
81+
if err != nil {
82+
return nil, err
83+
}
7284
for i, w := range weights {
7385
if offset < w {
74-
selected = onlines[i]
75-
break
86+
return hosts[i], nil
7687
}
7788
offset -= w
7889
}
79-
if selected == nil {
80-
return nil, errcode.ErrPublicHostNotFound.Wrap(fmt.Errorf("failed to select public host"))
90+
91+
return nil, errcode.ErrPublicHostNotFound.Wrap(fmt.Errorf("failed to select public host"))
92+
}
93+
94+
func randomUint64n(n uint64) (uint64, error) {
95+
if n == 0 {
96+
return 0, fmt.Errorf("random upper bound must be positive")
8197
}
8298

83-
return cvt.From(selected, &domain.Host{}), nil
99+
limit := ^uint64(0) - (^uint64(0) % n)
100+
var buf [8]byte
101+
for {
102+
if _, err := rand.Read(buf[:]); err != nil {
103+
return 0, err
104+
}
105+
v := binary.BigEndian.Uint64(buf[:])
106+
if v < limit {
107+
return v % n, nil
108+
}
109+
}
84110
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package usecase
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/chaitin/MonkeyCode/backend/db"
8+
"github.com/chaitin/MonkeyCode/backend/pkg/taskflow"
9+
)
10+
11+
func TestPickHostSelectsHostByRandomOffset(t *testing.T) {
12+
repo := &publicHostRepoStub{
13+
hosts: []*db.Host{
14+
{ID: "host-a", Hostname: "a", Weight: 1},
15+
{ID: "host-b", Hostname: "b", Weight: 3},
16+
{ID: "host-c", Hostname: "c", Weight: 1},
17+
},
18+
}
19+
hoster := &publicHosterStub{
20+
onlineMap: map[string]bool{
21+
"host-a": true,
22+
"host-b": true,
23+
"host-c": true,
24+
},
25+
}
26+
u := &PublicHostUsecase{
27+
repo: repo,
28+
taskflow: &taskflowClientStub{hoster: hoster},
29+
}
30+
31+
offsets := []uint64{0, 1, 3, 4}
32+
limits := make([]uint64, 0, len(offsets))
33+
prevRandUint64n := randUint64n
34+
randUint64n = func(n uint64) (uint64, error) {
35+
limits = append(limits, n)
36+
v := offsets[0]
37+
offsets = offsets[1:]
38+
return v, nil
39+
}
40+
t.Cleanup(func() {
41+
randUint64n = prevRandUint64n
42+
})
43+
44+
got := make([]string, 0, 4)
45+
for range 4 {
46+
host, err := u.PickHost(context.Background())
47+
if err != nil {
48+
t.Fatalf("PickHost() error = %v", err)
49+
}
50+
got = append(got, host.ID)
51+
}
52+
53+
want := []string{"host-a", "host-b", "host-b", "host-c"}
54+
for i := range want {
55+
if got[i] != want[i] {
56+
t.Fatalf("PickHost() at %d = %q, want %q", i, got[i], want[i])
57+
}
58+
}
59+
60+
for _, limit := range limits {
61+
if limit != 5 {
62+
t.Fatalf("rand limit = %d, want 5", limit)
63+
}
64+
}
65+
}
66+
67+
func TestPickHostTreatsNonPositiveWeightsAsOne(t *testing.T) {
68+
u := &PublicHostUsecase{
69+
repo: &publicHostRepoStub{
70+
hosts: []*db.Host{
71+
{ID: "host-a", Hostname: "a", Weight: 0},
72+
{ID: "host-b", Hostname: "b", Weight: -2},
73+
},
74+
},
75+
taskflow: &taskflowClientStub{
76+
hoster: &publicHosterStub{
77+
onlineMap: map[string]bool{
78+
"host-a": true,
79+
"host-b": true,
80+
},
81+
},
82+
},
83+
}
84+
85+
prevRandUint64n := randUint64n
86+
randUint64n = func(n uint64) (uint64, error) {
87+
if n != 2 {
88+
t.Fatalf("rand limit = %d, want 2", n)
89+
}
90+
return 1, nil
91+
}
92+
t.Cleanup(func() {
93+
randUint64n = prevRandUint64n
94+
})
95+
96+
host, err := u.PickHost(context.Background())
97+
if err != nil {
98+
t.Fatalf("PickHost() error = %v", err)
99+
}
100+
if host.ID != "host-b" {
101+
t.Fatalf("PickHost() = %q, want %q", host.ID, "host-b")
102+
}
103+
}
104+
105+
type publicHostRepoStub struct {
106+
hosts []*db.Host
107+
err error
108+
}
109+
110+
func (s *publicHostRepoStub) All(context.Context) ([]*db.Host, error) {
111+
return s.hosts, s.err
112+
}
113+
114+
type publicHosterStub struct {
115+
onlineMap map[string]bool
116+
err error
117+
}
118+
119+
func (s *publicHosterStub) List(context.Context, string) (map[string]*taskflow.Host, error) {
120+
return nil, nil
121+
}
122+
123+
func (s *publicHosterStub) IsOnline(context.Context, *taskflow.IsOnlineReq[string]) (*taskflow.IsOnlineResp, error) {
124+
if s.err != nil {
125+
return nil, s.err
126+
}
127+
return &taskflow.IsOnlineResp{OnlineMap: s.onlineMap}, nil
128+
}
129+
130+
type taskflowClientStub struct {
131+
hoster taskflow.Hoster
132+
}
133+
134+
func (s *taskflowClientStub) VirtualMachiner() taskflow.VirtualMachiner { return nil }
135+
func (s *taskflowClientStub) Host() taskflow.Hoster { return s.hoster }
136+
func (s *taskflowClientStub) FileManager() taskflow.FileManager { return nil }
137+
func (s *taskflowClientStub) TaskManager() taskflow.TaskManager { return nil }
138+
func (s *taskflowClientStub) PortForwarder() taskflow.PortForwarder { return nil }
139+
func (s *taskflowClientStub) Stats(context.Context) (*taskflow.Stats, error) {
140+
return nil, nil
141+
}
142+
func (s *taskflowClientStub) TaskLive(context.Context, string, bool, func(*taskflow.TaskChunk) error) error {
143+
return nil
144+
}

0 commit comments

Comments
 (0)