Commit 8bbe89a

localai-bot and mudler authored
fix(distributed): route per request across loaded replicas + cache probeHealth (#9968)
* refactor(distributed): extract PickBestReplica from FindAndLockNodeWithModel

  Lifts the replica-selection policy (in_flight ASC, last_used ASC, available_vram DESC) out of the SQL ORDER BY into a pure Go function in the new replicapicker.go. The SQL clause keeps its FOR UPDATE atomicity and remains the production path used by SmartRouter; PickBestReplica is the canonical implementation that the future per-frontend rotating-replica cache (TODO referenced from pkg/model) will call against an in-memory snapshot without paying a DB round-trip per inference.

  A new registry_test mirror spec seeds a multi-tier scenario and asserts that both layers pick the same replica, so any future tweak to either side fails the test until the other side is updated.

  No behavior change.

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(distributed): route per inference request and cache probeHealth

  Two related fixes that together restore load balancing across loaded replicas of the same model.

  1. ModelLoader.Load and LoadModel bypass the local *Model cache when modelRouter is set. The cached *Model wraps an InFlightTrackingClient bound to a single (nodeID, replicaIndex); reusing it pinned every subsequent request to whichever node won the very first pick, so FindAndLockNodeWithModel's round-robin never got a chance to run, even after the reconciler scaled the model out to a second node. In distributed mode SmartRouter.Route now runs per request, and PickBestReplica picks the least-loaded replica each time. SmartRouter has its own coalescing (an advisory DB lock for first-time loads plus singleflight on the backend.install RPC), so concurrent first requests for a not-yet-loaded model still produce a single worker-side install.

  2. SmartRouter.probeHealth memoizes successful gRPC HealthCheck results in a new probeCache (probe_cache.go) with a 30s TTL. With per-request routing every inference call hits probeHealth, and llama.cpp-style backends serialize HealthCheck behind active Predict, so a burst of incoming requests stalled on the probe to a node already mid-stream, tripping the 2s timeout and falling through to the install path. singleflight collapses N concurrent first-time probes for the same (node, addr) into one round-trip, failed probes invalidate the entry so the staleness-recovery path still triggers, and the TTL matches pkg/model/model.go's healthCheckTTL so the single-process and distributed paths share a staleness budget. The background HealthMonitor still reaps actually-dead backends within ~45s.

  The bypass introduces one short FindAndLockNodeWithModel transaction per inference. A TODO in pkg/model/loader.go documents the future per-modelID rotating-replica cache that would reuse PickBestReplica against an in-memory snapshot and skip the DB round-trip for hot paths.

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  Assisted-by: Claude:claude-opus-4-7 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
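
The cache bypass described in item 1 of the commit message can be illustrated with a minimal sketch. Everything here is hypothetical and simplified for illustration: `Client`, `Router`, `Loader`, and `alternatingRouter` are stand-ins invented for this example, not the actual types in pkg/model or core/services/nodes.

```go
package main

import "fmt"

// Client is a hypothetical stand-in for a backend client bound to one replica.
type Client struct {
	NodeID string
}

// Router is a hypothetical stand-in for SmartRouter: it picks a replica per call.
type Router interface {
	Route(modelID string) (*Client, error)
}

// Loader is a hypothetical, simplified stand-in for ModelLoader.
type Loader struct {
	router Router             // nil in single-process mode
	cache  map[string]*Client // local cache, consulted only when router is nil
	loadFn func(modelID string) (*Client, error)
}

// Load returns a client for modelID. When a router is set it skips the local
// cache, so every request is routed afresh instead of reusing a client that
// is pinned to whichever replica served the first request.
func (l *Loader) Load(modelID string) (*Client, error) {
	if l.router != nil {
		return l.router.Route(modelID)
	}
	if c, ok := l.cache[modelID]; ok {
		return c, nil
	}
	c, err := l.loadFn(modelID)
	if err != nil {
		return nil, err
	}
	l.cache[modelID] = c
	return c, nil
}

// alternatingRouter is a toy Router that cycles through a fixed node list.
type alternatingRouter struct {
	nodes []string
	next  int
}

func (r *alternatingRouter) Route(string) (*Client, error) {
	c := &Client{NodeID: r.nodes[r.next%len(r.nodes)]}
	r.next++
	return c, nil
}

func main() {
	l := &Loader{
		router: &alternatingRouter{nodes: []string{"node-a", "node-b"}},
		cache:  map[string]*Client{},
	}
	for i := 0; i < 4; i++ {
		c, _ := l.Load("some-model")
		fmt.Println(c.NodeID) // node-a, node-b, node-a, node-b
	}
}
```

With the router unset, the same `Load` call returns the cached client every time, which is the pinning behavior the commit removes for distributed mode.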
1 parent dcc5599 commit 8bbe89a

9 files changed

Lines changed: 592 additions & 20 deletions

core/services/nodes/probe_cache.go

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+package nodes
+
+import (
+	"sync"
+	"time"
+
+	"golang.org/x/sync/singleflight"
+)
+
+// probeCache memoizes recent successful gRPC HealthCheck results for
+// (nodeID, addr) tuples so SmartRouter.probeHealth doesn't pay a round-trip
+// on every inference request.
+//
+// Why this exists: with per-request routing (see pkg/model/loader.go), every
+// inference call goes through SmartRouter.Route, which probes the backend
+// before returning a client. Many gRPC backends (notably llama.cpp's server)
+// serialize HealthCheck against active Predict on a shared goroutine, so a
+// burst of new requests can stall behind a single long-running stream —
+// exactly the "queue stalling" symptom observed in distributed clusters.
+//
+// The background HealthMonitor (perModelHealthCheck) is still the cluster-wide
+// source of truth that reaps actually-dead backends within ~45s; this cache
+// only saves the per-request hot path from re-asking when nothing has changed.
+//
+// TTL matches healthCheckTTL in pkg/model/model.go so the single-process
+// IsRecentlyHealthy path and this distributed-mode path share the same
+// staleness budget.
+type probeCache struct {
+	ttl    time.Duration
+	mu     sync.Mutex
+	seen   map[string]time.Time // key → last successful probe
+	flight singleflight.Group   // coalesces concurrent probes for the same key
+}
+
+// newProbeCache returns a probeCache with the given TTL. Zero TTL disables
+// caching: every call to DoOrCached invokes the probe.
+func newProbeCache(ttl time.Duration) *probeCache {
+	return &probeCache{
+		ttl:  ttl,
+		seen: make(map[string]time.Time),
+	}
+}
+
+// IsFresh reports whether key was successfully probed within TTL.
+func (c *probeCache) IsFresh(key string) bool {
+	if c.ttl <= 0 {
+		return false
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	last, ok := c.seen[key]
+	return ok && time.Since(last) < c.ttl
+}
+
+// markFresh records key as successfully probed at the current time.
+func (c *probeCache) markFresh(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.seen[key] = time.Now()
+}
+
+// Invalidate drops any cached freshness for key. Used after a probe failure
+// (or any other signal that the backend may not be alive) so the next call
+// will re-probe instead of trusting stale state.
+func (c *probeCache) Invalidate(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.seen, key)
+}
+
+// DoOrCached returns true if key is fresh; otherwise it runs probe (coalescing
+// concurrent callers via singleflight) and caches a successful result. Failed
+// probes invalidate the cache, so a transient miss doesn't pin every
+// subsequent request to a re-probe.
+func (c *probeCache) DoOrCached(key string, probe func() bool) bool {
+	if c.IsFresh(key) {
+		return true
+	}
+	v, _, _ := c.flight.Do(key, func() (any, error) {
+		// Double-check after potentially waiting: another caller in this
+		// flight may have just populated the cache.
+		if c.IsFresh(key) {
+			return true, nil
+		}
+		ok := probe()
+		if ok {
+			c.markFresh(key)
+		} else {
+			c.Invalidate(key)
+		}
+		return ok, nil
+	})
+	return v.(bool)
+}
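
How a caller such as SmartRouter.probeHealth might wrap a health check with this cache can be sketched as follows. This is a stdlib-only illustration under stated assumptions: `freshness` is a simplified stand-in for probeCache without singleflight, and the `probeKey` separator, the `healthCheck` signature, and `cachedProbeDemo` are hypothetical. The 30s TTL and 2s timeout are the values named in the commit message.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// freshness is a simplified, stdlib-only stand-in for probeCache (no
// singleflight): it remembers when each key last probed successfully.
type freshness struct {
	ttl  time.Duration
	mu   sync.Mutex
	seen map[string]time.Time
}

// doOrCached returns true for a fresh key; otherwise it runs probe, records a
// success, and drops the entry on failure so the next call probes again.
func (f *freshness) doOrCached(key string, probe func() bool) bool {
	f.mu.Lock()
	last, ok := f.seen[key]
	f.mu.Unlock()
	if ok && time.Since(last) < f.ttl {
		return true
	}
	healthy := probe()
	f.mu.Lock()
	defer f.mu.Unlock()
	if healthy {
		f.seen[key] = time.Now()
	} else {
		delete(f.seen, key)
	}
	return healthy
}

// probeKey builds the cache key for a (nodeID, addr) tuple.
// The "|" separator is an assumption made for this sketch.
func probeKey(nodeID, addr string) string { return nodeID + "|" + addr }

// healthCheck is a hypothetical signature standing in for the gRPC HealthCheck call.
type healthCheck func(ctx context.Context, addr string) bool

// probeHealth consults the cache first and only dials the backend, bounded by
// a 2s timeout, when the entry is missing or stale.
func probeHealth(ctx context.Context, c *freshness, nodeID, addr string, check healthCheck) bool {
	return c.doOrCached(probeKey(nodeID, addr), func() bool {
		probeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
		defer cancel()
		return check(probeCtx, addr)
	})
}

// cachedProbeDemo issues three back-to-back probes for one key and returns
// how many times the underlying health check actually ran.
func cachedProbeDemo() int {
	cache := &freshness{ttl: 30 * time.Second, seen: map[string]time.Time{}}
	calls := 0
	check := func(context.Context, string) bool { calls++; return true }
	for i := 0; i < 3; i++ {
		probeHealth(context.Background(), cache, "node-a", "10.0.0.1:50051", check)
	}
	return calls
}

func main() {
	fmt.Println("health checks sent:", cachedProbeDemo()) // health checks sent: 1
}
```

Only the first of the three calls reaches the backend; the other two return from the fresh entry, which is what keeps a burst of requests from queueing behind a backend that is busy streaming.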
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+package nodes
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("probeCache", func() {
+	It("invokes the probe on a cold cache and caches success", func() {
+		c := newProbeCache(time.Minute)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		// Cached: probe ran once.
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)))
+	})
+
+	It("re-probes after the TTL expires", func() {
+		// 1 ms TTL means the second call is virtually guaranteed to see an
+		// expired entry without flaking on scheduler jitter.
+		c := newProbeCache(time.Millisecond)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		time.Sleep(5 * time.Millisecond)
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+
+	It("does not cache failed probes — next call re-probes", func() {
+		c := newProbeCache(time.Minute)
+		var calls int32
+		var result atomic.Bool
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return result.Load()
+		}
+
+		// First probe fails — must NOT be cached.
+		result.Store(false)
+		Expect(c.DoOrCached("k", probe)).To(BeFalse())
+		Expect(c.IsFresh("k")).To(BeFalse())
+
+		// Recover: second probe succeeds and is cached.
+		result.Store(true)
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.IsFresh("k")).To(BeTrue())
+
+		// Third call short-circuits on the fresh entry.
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+
+	It("coalesces concurrent probes via singleflight", func() {
+		// Models the "6 chat completions arrive simultaneously for a
+		// not-yet-cached backend" scenario. Without singleflight every caller
+		// would dial the backend, defeating the purpose of the cache.
+		c := newProbeCache(time.Minute)
+		var calls int32
+		start := make(chan struct{})
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			// Stall briefly so the test reliably has all goroutines parked
+			// inside flight.Do at the same time.
+			time.Sleep(50 * time.Millisecond)
+			return true
+		}
+
+		const N = 8
+		var wg sync.WaitGroup
+		results := make([]bool, N)
+		for i := 0; i < N; i++ {
+			wg.Add(1)
+			go func(i int) {
+				defer wg.Done()
+				<-start
+				results[i] = c.DoOrCached("k", probe)
+			}(i)
+		}
+
+		close(start)
+		wg.Wait()
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)),
+			"singleflight must collapse %d concurrent probes into one", N)
+		for i, got := range results {
+			Expect(got).To(BeTrue(), "goroutine %d saw a different result", i)
+		}
+	})
+
+	It("treats different keys independently", func() {
+		c := newProbeCache(time.Minute)
+		var aCalls, bCalls int32
+		Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
+		Expect(c.DoOrCached("b", func() bool { atomic.AddInt32(&bCalls, 1); return true })).To(BeTrue())
+		Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&aCalls)).To(Equal(int32(1)))
+		Expect(atomic.LoadInt32(&bCalls)).To(Equal(int32(1)))
+	})
+
+	It("disables caching when TTL is zero", func() {
+		c := newProbeCache(0)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(3)))
+	})
+
+	It("Invalidate forces the next call to re-probe", func() {
+		c := newProbeCache(time.Hour)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		c.Invalidate("k")
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+})

core/services/nodes/registry.go

Lines changed: 24 additions & 13 deletions
@@ -668,10 +668,21 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
 	return nodes, nil
 }
 
-// FindAndLockNodeWithModel atomically finds the least-loaded node with the given
-// model loaded and increments its in-flight counter within a single transaction.
-// The SELECT FOR UPDATE row lock prevents concurrent eviction from removing the
-// NodeModel row between the find and increment operations.
+// FindAndLockNodeWithModel atomically finds the best loaded replica of the
+// given model and increments its in-flight counter within a single
+// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction
+// from removing the NodeModel row between the find and increment operations,
+// and serializes contending routers so concurrent picks distribute across
+// replicas instead of all landing on the same row.
+//
+// **Policy:** the SQL ORDER BY below MUST mirror PickBestReplica
+// (replicapicker.go). PickBestReplica is the canonical Go implementation of
+// the same rule — the per-frontend rotating-replica cache (TODO, see
+// pkg/model/loader.go) will eventually use it against in-memory snapshots so
+// hot inference requests don't pay this DB round-trip. If you change the
+// ordering here, change both sides; the TestFindAndLockNodeWithModelMirror
+// spec ("agrees with PickBestReplica on a seeded dataset") fails fast if they
+// drift.
 //
 // When candidateNodeIDs is non-empty, only nodes in that set are considered.
 // Pass nil (or empty) to consider any node. This lets callers pre-filter by
@@ -683,16 +694,16 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
 	var node BackendNode
 
 	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
-		// Order by in_flight ASC (least busy replica), then by last_used ASC
-		// (round-robin between equally-loaded replicas — oldest used wins, and
-		// every successful pick refreshes last_used below, so the "oldest" naturally
-		// rotates through the candidate set). available_vram DESC is the final
-		// tiebreaker for cold starts where last_used is identical.
+		// Mirror of PickBestReplica's policy (see replicapicker.go):
+		//   1. in_flight ASC — least busy replica.
+		//   2. last_used ASC — round-robin between equally-loaded replicas.
+		//      Every successful pick refreshes last_used below, so the
+		//      "oldest" tier naturally rotates through the candidate set.
+		//      Without this tier, in_flight ties collapsed to "fattest GPU
+		//      wins every time" and one node took nearly all the load.
+		//   3. available_vram DESC — final tiebreaker for cold starts where
+		//      last_used is identical across replicas.
 		//
-		// Without the last_used tier, a tie on in_flight (the common case at low
-		// to moderate concurrency where requests don't overlap) collapses to
-		// "biggest GPU wins every time" and one node ends up taking nearly all
-		// the load while replicas on other nodes sit idle.
 		// Filter on backend_nodes.status = healthy in the inner JOIN itself,
 		// not only in the later node-fetch step. The previous version picked
 		// a (node_id, replica) pair purely on node_models state, then bailed
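
The three-tier ordering listed in the comment above can be written as a plain comparator. The sketch below is an illustration of that rule only, not the contents of replicapicker.go: the `candidate` type and the `better`, `pick`, and `rotation` names are hypothetical. It also simulates sequential, non-overlapping requests to show how refreshing last_used after each pick rotates the choice between equally loaded replicas.

```go
package main

import (
	"fmt"
	"time"
)

// candidate carries only the fields the policy compares (illustrative type).
type candidate struct {
	NodeID        string
	InFlight      int
	LastUsed      time.Time
	AvailableVRAM uint64
}

// better reports whether a should be preferred over b:
// in_flight ASC, then last_used ASC, then available_vram DESC.
func better(a, b candidate) bool {
	if a.InFlight != b.InFlight {
		return a.InFlight < b.InFlight
	}
	if !a.LastUsed.Equal(b.LastUsed) {
		return a.LastUsed.Before(b.LastUsed)
	}
	return a.AvailableVRAM > b.AvailableVRAM
}

// pick returns the index of the preferred candidate, or -1 for an empty slice.
func pick(cs []candidate) int {
	best := -1
	for i := range cs {
		if best < 0 || better(cs[i], cs[best]) {
			best = i
		}
	}
	return best
}

// rotation simulates n sequential requests against two idle replicas that
// start with identical last_used, refreshing last_used after every pick.
func rotation(n int) []string {
	base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	cs := []candidate{
		{NodeID: "small", LastUsed: base, AvailableVRAM: 8_000_000_000},
		{NodeID: "big", LastUsed: base, AvailableVRAM: 24_000_000_000},
	}
	order := make([]string, 0, n)
	for i := 1; i <= n; i++ {
		j := pick(cs)
		order = append(order, cs[j].NodeID)
		cs[j].LastUsed = base.Add(time.Duration(i) * time.Second)
	}
	return order
}

func main() {
	fmt.Println(rotation(4)) // [big small big small]
}
```

The first pick falls through to the VRAM tiebreaker because both timestamps are equal; every later pick is decided by last_used, so the load alternates instead of always landing on the larger GPU.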

core/services/nodes/registry_test.go

Lines changed: 74 additions & 0 deletions
@@ -3,6 +3,7 @@ package nodes
 import (
 	"context"
 	"runtime"
+	"time"
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -357,6 +358,79 @@ var _ = Describe("NodeRegistry", func() {
 			_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID})
 			Expect(err).To(HaveOccurred())
 		})
+
+		It("agrees with PickBestReplica on a seeded dataset (policy mirror)", func() {
+			// Guard against drift between the SQL ORDER BY in
+			// FindAndLockNodeWithModel and the canonical Go implementation in
+			// PickBestReplica. The two layers will eventually diverge in
+			// caller (DB-backed atomic pick vs in-memory snapshot pick for the
+			// per-frontend rotating cache), but the policy itself must stay
+			// the single source of truth. If this test fails, update *both*
+			// sides — never just one.
+			//
+			// Scenario exercises all three tiers:
+			//   - "loser-busy" has the most VRAM but in_flight=2 — loses tier 1.
+			//   - "loser-recent" ties at in_flight=0 but its last_used is the
+			//     newest of the in_flight=0 group — loses tier 2.
+			//   - "winner-mid" and "winner-fat" both tie at in_flight=0 and
+			//     share the oldest last_used — tier 3 decides: fattest wins.
+			loserBusy := makeNode("mirror-loser-busy", "10.0.0.70:50051", 32_000_000_000)
+			loserRecent := makeNode("mirror-loser-recent", "10.0.0.71:50051", 8_000_000_000)
+			winnerMid := makeNode("mirror-winner-mid", "10.0.0.72:50051", 16_000_000_000)
+			winnerFat := makeNode("mirror-winner-fat", "10.0.0.73:50051", 24_000_000_000)
+			for _, n := range []*BackendNode{loserBusy, loserRecent, winnerMid, winnerFat} {
+				Expect(registry.Register(context.Background(), n, true)).To(Succeed())
+				Expect(registry.SetNodeModel(context.Background(), n.ID, "mirror-model", 0, "loaded", "", 0)).To(Succeed())
+			}
+
+			// Force in_flight=2 on the "busy" node so tier 1 disqualifies it.
+			Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
+			Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
+
+			// Slam last_used to known values so the test is deterministic
+			// regardless of clock resolution between the helpers above.
+			base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
+			set := func(id string, t time.Time) {
+				Expect(db.Model(&NodeModel{}).
+					Where("node_id = ? AND model_name = ?", id, "mirror-model").
+					Update("last_used", t).Error).To(Succeed())
+			}
+			set(loserBusy.ID, base) // newest doesn't matter — already disqualified by tier 1
+			set(loserRecent.ID, base.Add(time.Hour))
+			set(winnerMid.ID, base)
+			set(winnerFat.ID, base)
+
+			// Pull the same dataset both pickers will operate on. The Go
+			// picker is a faithful representation of the policy; the SQL is
+			// the production path.
+			var rows []NodeModel
+			Expect(db.Where("model_name = ? AND state = ?", "mirror-model", "loaded").
+				Find(&rows).Error).To(Succeed())
+			candidates := make([]ReplicaCandidate, 0, len(rows))
+			for _, nm := range rows {
+				var bn BackendNode
+				Expect(db.First(&bn, "id = ? AND status = ?", nm.NodeID, StatusHealthy).Error).To(Succeed())
+				candidates = append(candidates, ReplicaCandidate{
+					NodeID:        nm.NodeID,
+					Address:       bn.Address,
+					ReplicaIndex:  nm.ReplicaIndex,
+					InFlight:      nm.InFlight,
+					LastUsed:      nm.LastUsed,
+					AvailableVRAM: bn.AvailableVRAM,
+				})
+			}
+			goPick := PickBestReplica(candidates)
+			Expect(goPick).ToNot(BeNil())
+
+			sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil)
+			Expect(err).ToNot(HaveOccurred())
+
+			Expect(sqlNode.ID).To(Equal(goPick.NodeID),
+				"SQL ORDER BY picked %s; PickBestReplica picked %s — policy has drifted",
+				sqlNode.ID, goPick.NodeID)
+			// Sanity check: the policy says winner-fat wins on tier 3.
+			Expect(goPick.NodeID).To(Equal(winnerFat.ID))
+		})
 	})
 
 	Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
