Skip to content
Open
41 changes: 41 additions & 0 deletions examples/new-scorer-values.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
payloadProcessor:
listModels:
- facebook/opt-125m
- facebook/opt-350m
customConfig:
plugins:
- type: body-field-to-header
parameters:
fieldName: model
headerName: X-Gateway-Model-Name
- type: base-model-to-header
- type: model-selector
- type: avg-ttft-scorer
parameters:
decayWeight: 1.0
stalenessThreshold: "10s"
inflightWeight: 1.0
maxIdleProbes: 2
- type: max-score-picker
- type: request-metadata-extractor
parameters:
emaAlpha: 0.1
intervalDuration: 5s
- type: model-config-datasource
parameters:
modelsPath: /config/models.json
profiles:
- name: default
plugins:
request:
- pluginRef: model-selector
- pluginRef: avg-ttft-scorer
weight: 1.0
- pluginRef: max-score-picker
- pluginRef: body-field-to-header
- pluginRef: base-model-to-header
datalayer:
extractors:
- pluginRef: request-metadata-extractor
datasources:
- pluginRef: model-config-datasource
40 changes: 0 additions & 40 deletions pkg/framework/interface/datalayer/pricing/cost_digest_test.go

This file was deleted.

9 changes: 7 additions & 2 deletions pkg/framework/plugins/datalayer/requestmetadata/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ func ExtractorFactory(name string, parameters json.RawMessage, h plugin.Handle)
}

// ModelMetrics holds per-model metadata: in-flight request count and
// EMA estimates for TTFT and TPOT.
// EMA estimates for TTFT, TPOT, and inflight count.
type ModelMetrics struct {
Requests int64
AvgRequests float64
AvgTTFT float64
AvgTPOT float64
LastObservedAt int64 // Unix nanoseconds of the last TTFT EMA update; 0 if never observed.
LastObservedAt int64
}

func (r ModelMetrics) Clone() datalayer.Cloneable { return r }
Expand All @@ -112,13 +113,17 @@ type modelIntervalAccumulator struct {

// flush averages the accumulated interval observations into the EMA, emits Prometheus gauges, and resets the interval.
func (s *modelIntervalAccumulator) flush(now time.Time, model string, alpha float64) {
// Always update AvgRequests — it samples the current inflight count each interval
// regardless of whether any responses arrived.
s.AvgRequests = ema(s.AvgRequests, float64(s.Requests), alpha)
if s.ttftN > 0 {
s.AvgTTFT = ema(s.AvgTTFT, s.ttftSum/float64(s.ttftN), alpha)
s.LastObservedAt = now.UnixNano()
metrics.RecordModelAvgTTFT(model, s.AvgTTFT)
}
if s.tpotN > 0 {
s.AvgTPOT = ema(s.AvgTPOT, s.tpotSum/float64(s.tpotN), alpha)
s.LastObservedAt = now.UnixNano()
metrics.RecordModelAvgTPOT(model, s.AvgTPOT)
}
s.intervalStart = now
Expand Down
83 changes: 83 additions & 0 deletions pkg/framework/plugins/datalayer/requestmetadata/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,89 @@ func TestAvgTPOTZeroCompletionTokensIgnored(t *testing.T) {
}
}

// TestAvgRequestsTracksInflight verifies that AvgRequests is sampled at flush time
// and blended into an EMA. With intervalDuration=0 every response triggers a flush.
//
// Sequence: two requests arrive (Requests=2), one response arrives (Requests=1 after decrement).
// Flush samples Requests=1 as the first observation → AvgRequests=1.0.
func TestAvgRequestsTracksInflight(t *testing.T) {
ext, ds := newRequestMetadataTest(t)

if err := ext.Extract(context.Background(), []dlsrc.Event{
makeRequestEvent("m1"),
makeRequestEvent("m1"),
makeResponseEvent(0), // Requests: 2→1 then flush → AvgRequests = 1.0 (first observation)
}); err != nil {
t.Fatalf("Extract failed: %v", err)
}

rc := getRequestMetadata(t, ds, "m1")
if rc.AvgRequests != 1.0 {
t.Errorf("expected AvgRequests=1.0 (first observation), got %f", rc.AvgRequests)
}
}

// TestAvgRequestsEMABlend verifies that successive flushes blend AvgRequests with the EMA.
//
// Flush 1: Requests=1 → AvgRequests = 1.0 (first observation, no blend)
// Flush 2: Requests=0 → AvgRequests = 0.1×0 + 0.9×1.0 = 0.9
func TestAvgRequestsEMABlend(t *testing.T) {
ext, ds := newRequestMetadataTest(t)

// First flush: one request in flight at flush time.
if err := ext.Extract(context.Background(), []dlsrc.Event{
makeRequestEvent("m1"),
makeRequestEvent("m1"),
makeResponseEvent(0), // Requests: 2→1 at flush
}); err != nil {
t.Fatalf("first Extract failed: %v", err)
}

// Second flush: no requests in flight at flush time.
if err := ext.Extract(context.Background(), []dlsrc.Event{
makeResponseEvent(0), // Requests: 1→0 at flush
}); err != nil {
t.Fatalf("second Extract failed: %v", err)
}

rc := getRequestMetadata(t, ds, "m1")
want := 0.1*0.0 + 0.9*1.0 // 0.9
if rc.AvgRequests != want {
t.Errorf("expected AvgRequests=%f, got %f", want, rc.AvgRequests)
}
}

// TestAvgRequestsUpdatesWithoutTTFT verifies that AvgRequests is updated on flush
// even when no TTFT observations arrived in the interval (AvgTTFT stays unchanged).
func TestAvgRequestsUpdatesWithoutTTFT(t *testing.T) {
ext, ds := newRequestMetadataTest(t)

// Seed a TTFT value.
if err := ext.Extract(context.Background(), []dlsrc.Event{
makeResponseEventWithTTFT(0, 500*time.Millisecond),
}); err != nil {
t.Fatalf("seed Extract failed: %v", err)
}

// Two requests arrive; response decrements to 1 so flush sees Requests=1.
// No TTFT on this response — AvgTTFT must stay unchanged while AvgRequests updates.
if err := ext.Extract(context.Background(), []dlsrc.Event{
makeRequestEvent("m1"),
makeRequestEvent("m1"),
makeResponseEvent(0), // no TTFT field; Requests: 2→1 at flush
}); err != nil {
t.Fatalf("second Extract failed: %v", err)
}

rc := getRequestMetadata(t, ds, "m1")
if rc.AvgTTFT != 0.5 {
t.Errorf("expected AvgTTFT=0.5 (unchanged), got %f", rc.AvgTTFT)
}
if rc.AvgRequests == 0 {
t.Errorf("expected AvgRequests to be updated even without TTFT, got 0")
}
}

func TestExtractorFactoryWiresDatastore(t *testing.T) {
ds := datastore.NewFakeDataStore()
h := &fakeHandle{ds: ds}
Expand Down
71 changes: 62 additions & 9 deletions pkg/framework/plugins/modelselector/scorer/avgtpot/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package avgtpot
import (
"context"
"encoding/json"
"fmt"
"math"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

Expand All @@ -29,28 +31,66 @@ import (
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin"
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/requesthandling"
requestmetadata "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/plugins/datalayer/requestmetadata"
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/plugins/modelselector/scorer/internal/decay"
)

const PluginType = "avg-tpot-scorer"
const (
PluginType = "avg-tpot-scorer"

defaultDecayWeight = 1.0
defaultStalenessThreshold = 30 * time.Second
)

// compile-time interface assertion
var _ modelselector.Scorer = &AvgTPOTScorer{}

// AvgTPOTScorerConfig holds the scorer's JSON parameters.
type AvgTPOTScorerConfig struct {
// DecayWeight scales the staleness decay in [0,1]; 0 disables. Default 1.0.
DecayWeight *float64 `json:"decayWeight,omitempty"`
// StalenessThreshold is the elapsed time for full staleness (e.g. "30s"). Default "30s".
StalenessThreshold string `json:"stalenessThreshold,omitempty"`
}

// AvgTPOTScorer scores models based on their exponential moving average TPOT.
// The model with the lowest AvgTPOT scores 1.0; the highest scores 0.0.
// Models with no observed TPOT yet (AvgTPOT == 0) are treated as idle and score 1.0.
// If all models have the same AvgTPOT, all score 1.0.
// Stale EMAs are decayed toward zero (see the decay package); set DecayWeight=0 to disable.
type AvgTPOTScorer struct {
typedName plugin.TypedName
decayCfg decay.Config
}

func ScorerFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
return NewAvgTPOTScorer().WithName(name), nil
func ScorerFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
config := AvgTPOTScorerConfig{
StalenessThreshold: defaultStalenessThreshold.String(),
}
if len(parameters) > 0 {
if err := json.Unmarshal(parameters, &config); err != nil {
return nil, fmt.Errorf("failed to parse parameters for plugin %q: %w", name, err)
}
}
weight := defaultDecayWeight
if config.DecayWeight != nil {
weight = *config.DecayWeight
}
if weight < 0 || weight > 1 {
return nil, fmt.Errorf("invalid decayWeight %v for plugin %q: must be in [0, 1]", weight, name)
}
threshold, err := time.ParseDuration(config.StalenessThreshold)
if err != nil {
return nil, fmt.Errorf("invalid stalenessThreshold %q for plugin %q: %w", config.StalenessThreshold, name, err)
}
return NewAvgTPOTScorer().
WithName(name).
WithDecay(decay.Config{Weight: weight, Threshold: threshold}), nil
}

func NewAvgTPOTScorer() *AvgTPOTScorer {
return &AvgTPOTScorer{
typedName: plugin.TypedName{Type: PluginType, Name: PluginType},
decayCfg: decay.Config{Weight: defaultDecayWeight, Threshold: defaultStalenessThreshold},
}
}

Expand All @@ -61,15 +101,21 @@ func (s *AvgTPOTScorer) WithName(name string) *AvgTPOTScorer {
return s
}

// Score returns a score in [0,1] for each model.
// Formula: score = (max - avgTPOT) / (max - min)
// WithDecay overrides the decay configuration.
func (s *AvgTPOTScorer) WithDecay(cfg decay.Config) *AvgTPOTScorer {
s.decayCfg = cfg
return s
}

// Score returns score = (max - avgTPOT)/(max - min) per model, with optional staleness decay.
func (s *AvgTPOTScorer) Score(ctx context.Context, _ *plugin.CycleState, _ *requesthandling.InferenceRequest, models []datalayer.Model) map[datalayer.Model]float64 {
now := time.Now()
tpots := make(map[datalayer.Model]float64, len(models))
minTPOT := math.MaxFloat64
maxTPOT := 0.0

for _, model := range models {
v := avgTPOT(model)
v := s.avgTPOT(model, now)
tpots[model] = v
if v > maxTPOT {
maxTPOT = v
Expand Down Expand Up @@ -97,8 +143,8 @@ func (s *AvgTPOTScorer) Score(ctx context.Context, _ *plugin.CycleState, _ *requ
return scores
}

// avgTPOT returns the AvgTPOT for a model, or 0 if not yet observed.
func avgTPOT(model datalayer.Model) float64 {
// avgTPOT returns the decay-adjusted AvgTPOT, or 0 if unobserved.
func (s *AvgTPOTScorer) avgTPOT(model datalayer.Model, now time.Time) float64 {
val, ok := model.GetAttributes().Get(requestmetadata.RequestMetadataAttributeKey)
if !ok {
return 0
Expand All @@ -107,5 +153,12 @@ func avgTPOT(model datalayer.Model) float64 {
if !ok {
return 0
}
return rc.AvgTPOT
if rc.AvgTPOT == 0 {
return 0
}
var lastObservedAt time.Time
if rc.LastObservedAt > 0 {
lastObservedAt = time.Unix(0, rc.LastObservedAt)
}
return decay.Apply(rc.AvgTPOT, lastObservedAt, rc.Requests, now, s.decayCfg)
}
Loading
Loading