diff --git a/pkg/framework/interface/datalayer/pricing/cost_digest_test.go b/pkg/framework/interface/datalayer/pricing/cost_digest_test.go deleted file mode 100644 index e18c85e..0000000 --- a/pkg/framework/interface/datalayer/pricing/cost_digest_test.go +++ /dev/null @@ -1,40 +0,0 @@ -/* -Copyright 2026 The llm-d Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package pricing - -import ( - "testing" - - "github.com/caio/go-tdigest/v5" -) - -// newDigest builds a tdigest with the proposal's default compression and -// adds the given samples. It fails the test on any library error so that -// the call sites stay free of error handling. 
-func newDigest(t *testing.T, samples ...float64) *tdigest.TDigest { - t.Helper() - d, err := tdigest.New(tdigest.Compression(200)) - if err != nil { - t.Fatalf("tdigest.New: %v", err) - } - for _, s := range samples { - if err := d.Add(s); err != nil { - t.Fatalf("tdigest.Add(%v): %v", s, err) - } - } - return d -} diff --git a/pkg/framework/plugins/datalayer/requestcostmetadata/README.md b/pkg/framework/plugins/datalayer/requestcostmetadata/README.md new file mode 100644 index 0000000..21ff8be --- /dev/null +++ b/pkg/framework/plugins/datalayer/requestcostmetadata/README.md @@ -0,0 +1,69 @@ +# Request Cost Metadata Extractor + +## What it is + +A datasource extractor that turns each completed inference response into a per-request +cost sample and folds it into a per-model [t-digest](https://github.com/caio/go-tdigest) +stored on the Model's `AttributeMap`. It is registered as type +`model-cost-extractor` (`PluginType` in `plugin.go`) and runs on the same response-event loop as +`request-metadata-extractor`. It is a building block for the CostGuard scorer +(see [docs/proposals/050-costguard-scorer/README.md](../../../../../docs/proposals/050-costguard-scorer/README.md)). + +## What it does + +1. Ignores `RequestEventType` events. Cost is observed only after a response. +2. On each `ResponseEventType` event: + - Reads the model name from the request body's `model` field. + - Reads `prompt_tokens` and `completion_tokens` from the response's `usage` block. + Skips the sample (with a debug log) if either is absent or non-positive. + - Reads the model's `*pricing.TokenPrices` from the AttributeMap under + `pricing.TokenPricesAttributeKey`. Skips the sample (with a debug log) if absent — + a model with no declared pricing has no defined cost. A model declared with + `TokenPrices{0,0}` (a free model) is *not* skipped: it records `cost=0`. + - Computes + `cost = prompt_tokens * InputTokenPrice + completion_tokens * OutputTokenPrice` + and adds the value to the model's running t-digest. 
+3. At the end of each `Extract` batch, for every model whose digest was updated and + whose flush interval has elapsed since the last publish, writes a *clone* of the + digest to the Model's AttributeMap under `pricing.CostDigestAttributeKey`. The + stored value is a `*pricing.CostDigest`. + +This extractor does not freeze and replace the digest at epoch boundaries — the +digest accumulates without bound. Epoch handling lands in a follow-up PR. + +## Inputs consumed + +- `dlsrc.ResponsePayload.Request.Body["model"]` — the model name (string). +- `dlsrc.ResponsePayload.Response.Body["usage"]` — a `map[string]any` containing + `prompt_tokens` and `completion_tokens` as `float64`. +- `pricing.TokenPricesAttributeKey` on the Model's AttributeMap — populated by the + `modelconfigcollector` plugin at startup and on config-file changes. + +## Configuration + +```json +{ + "compression": 200, + "flushIntervalDuration": "5s" +} +``` + +- `compression` (optional, default `200`): t-digest compression. Higher values + trade memory for quantile accuracy. Must be `> 0`. +- `flushIntervalDuration` (optional, default `"5s"`): aggregation window before a + per-model digest snapshot is published to the AttributeMap. Set to `"0s"` to + publish on every event (used in unit tests). Must be `>= 0`. + +## Known limitations + +- **Side-effect creation of empty Models for unconfigured names.** When a + response arrives for a model name that the operator never declared (i.e. a + model with no `pricing.TokenPrices` attribute), this extractor's lookup + goes through `Datastore.GetOrCreateModel`, which registers an empty Model + in the datastore as a side effect. The cost sample is correctly skipped, + but the model name leaks into `Datastore.Models()` and becomes visible to + every other plugin that enumerates the store. This is a limitation of the + current `Datastore` interface, which has no read-only `GetModel(name)` + method. 
A follow-up PR will add `GetModel` to the interface and migrate + this extractor to use it; once that lands, responses for unconfigured + models will be skipped without any datastore mutation. diff --git a/pkg/framework/plugins/datalayer/requestcostmetadata/plugin.go b/pkg/framework/plugins/datalayer/requestcostmetadata/plugin.go new file mode 100644 index 0000000..1281260 --- /dev/null +++ b/pkg/framework/plugins/datalayer/requestcostmetadata/plugin.go @@ -0,0 +1,321 @@ +/* +Copyright 2026 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcostmetadata implements a datasource extractor that +// turns each completed inference response into a per-request cost sample +// and folds it into a per-model t-digest stored on the Model's +// AttributeMap under pricing.CostDigestAttributeKey. It is a building +// block for the CostGuard scorer; see the package README for behavioral +// intent and configuration. 
+package requestcostmetadata + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/caio/go-tdigest/v5" + "sigs.k8s.io/controller-runtime/pkg/log" + + logutil "github.com/llm-d/llm-d-inference-payload-processor/pkg/common/observability/logging" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer" + dlsrc "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer/datasource" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer/pricing" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin" +) + +const ( + // PluginType is the identifier used when registering this extractor. + PluginType = "model-cost-extractor" + + // defaultCompression matches the t-digest compression value used in the + // CostGuard proposal (docs/proposals/050-costguard-scorer/README.md). + defaultCompression = 200.0 + + // defaultFlushIntervalDuration is the aggregation window before a per-model + // digest snapshot is published to the AttributeMap. Mirrors the pattern in + // the requestmetadata extractor. + defaultFlushIntervalDuration = 5 * time.Second +) + +// compile-time interface assertion +var _ dlsrc.Extractor = &RequestCostMetadataExtractor{} + +// RequestCostMetadataExtractorConfig holds the JSON-configurable parameters +// for the extractor. +type RequestCostMetadataExtractorConfig struct { + // Compression is the t-digest compression. Higher values trade memory + // for quantile accuracy. Must be > 0. Defaults to 200 if not specified. + Compression float64 `json:"compression,omitempty"` + + // FlushIntervalDuration is the aggregation window before a per-model digest + // snapshot is published to the AttributeMap (e.g. "5s", "1m"). Set to "0s" + // to publish on every event (used in unit tests). Defaults to "5s". 
+ FlushIntervalDuration string `json:"flushIntervalDuration,omitempty"` +} + +// ExtractorFactory creates a RequestCostMetadataExtractor wired to the shared +// Datastore. +func ExtractorFactory(name string, parameters json.RawMessage, h plugin.Handle) (plugin.Plugin, error) { + config := RequestCostMetadataExtractorConfig{ + Compression: defaultCompression, + FlushIntervalDuration: defaultFlushIntervalDuration.String(), + } + if len(parameters) > 0 { + if err := json.Unmarshal(parameters, &config); err != nil { + return nil, fmt.Errorf("failed to parse parameters for plugin %q: %w", name, err) + } + } + + if config.Compression <= 0 { + return nil, fmt.Errorf("invalid compression %v for plugin %q: must be > 0", config.Compression, name) + } + + flushInterval, err := time.ParseDuration(config.FlushIntervalDuration) + if err != nil { + return nil, fmt.Errorf("invalid flushIntervalDuration %q for plugin %q: %w", config.FlushIntervalDuration, name, err) + } + if flushInterval < 0 { + return nil, fmt.Errorf("invalid flushIntervalDuration %q for plugin %q: must be >= 0", config.FlushIntervalDuration, name) + } + + return NewRequestCostMetadataExtractor(h.Datastore()). + WithName(name). + WithCompression(config.Compression). + WithFlushInterval(flushInterval), nil +} + +// modelCostAccumulator holds the running t-digest for a single model and the +// timestamp of its last flush, so the extractor can decide when to publish a +// snapshot to the AttributeMap. +type modelCostAccumulator struct { + digest *tdigest.TDigest + lastFlush time.Time +} + +// TODO: in a separate PR, add a request-handling plugin that sets +// stream_options.include_usage=true on the request body so that streamed +// responses always carry the usage block this extractor consumes. + +// RequestCostMetadataExtractor accumulates per-model cost samples derived from +// response usage metadata and pricing attributes, and publishes a t-digest +// snapshot to the Model's AttributeMap on flush. 
+// +// Extract is assumed to be called from a single goroutine (the +// NotificationSource event loop). +// Note: If parallel dispatch is introduced, add a +// sync.Mutex around state and the Datastore writes. +type RequestCostMetadataExtractor struct { + typedName plugin.TypedName + ds datalayer.Datastore + state map[string]*modelCostAccumulator + compression float64 + flushInterval time.Duration +} + +// NewRequestCostMetadataExtractor constructs an extractor wired to ds with +// proposal-default compression and flush interval. +func NewRequestCostMetadataExtractor(ds datalayer.Datastore) *RequestCostMetadataExtractor { + return &RequestCostMetadataExtractor{ + typedName: plugin.TypedName{Type: PluginType, Name: PluginType}, + ds: ds, + state: make(map[string]*modelCostAccumulator), + compression: defaultCompression, + flushInterval: defaultFlushIntervalDuration, + } +} + +func (e *RequestCostMetadataExtractor) TypedName() plugin.TypedName { return e.typedName } + +// WithName sets the instance name, used by the factory when the plugin is +// configured by name. +func (e *RequestCostMetadataExtractor) WithName(name string) *RequestCostMetadataExtractor { + e.typedName.Name = name + return e +} + +// WithCompression overrides the t-digest compression used for newly created +// per-model digests. +func (e *RequestCostMetadataExtractor) WithCompression(c float64) *RequestCostMetadataExtractor { + e.compression = c + return e +} + +// WithFlushInterval overrides the publish interval. Pass 0 to flush after +// every event (used in unit tests). +func (e *RequestCostMetadataExtractor) WithFlushInterval(d time.Duration) *RequestCostMetadataExtractor { + e.flushInterval = d + return e +} + +// Extract processes a batch of events. RequestEventType events are ignored; +// each ResponseEventType produces (at most) one cost sample, which is added +// to that model's running t-digest. 
Per-model snapshots are published to the +// AttributeMap when the flush interval has elapsed since the last publish. +func (e *RequestCostMetadataExtractor) Extract(ctx context.Context, events []dlsrc.Event) error { + debugLogger := log.FromContext(ctx).V(logutil.DEBUG) + debugLogger.Info("request-cost-metadata extractor invoked", "num_events", len(events)) + + now := time.Now() + updated := map[string]bool{} + // Cache token prices per-model to avoid repeated lookups within this batch + tokenPricesCache := make(map[string]*pricing.TokenPrices) + + for _, ev := range events { + if ev.Type != dlsrc.ResponseEventType { + continue + } + p, ok := ev.Payload.(dlsrc.ResponsePayload) + if !ok { + continue + } + // Distinguish "model field absent" from "model field present but + // not a string" so a malformed upstream is visible in debug logs + // rather than indistinguishable from a request with no model. + rawModel, hasModel := p.Request.Body["model"] + if !hasModel { + continue + } + model, isString := rawModel.(string) + if !isString { + debugLogger.Info("response request body has non-string model field, skipping", "model_type", fmt.Sprintf("%T", rawModel)) + continue + } + if model == "" { + continue + } + + promptTokens, completionTokens, ok := extractTokenCounts(p) + if !ok { + debugLogger.Info("response missing usable usage fields, skipping", "model", model) + continue + } + + // Check cache first; only lookup if not already cached + tp, ok := tokenPricesCache[model] + if !ok { + found := false + tp, found = lookupTokenPrices(e.ds, model) + if !found { + debugLogger.Info("model has no TokenPrices attribute, skipping cost sample", "model", model) + continue + } + tokenPricesCache[model] = tp + } + + cost := promptTokens*tp.InputTokenPrice + completionTokens*tp.OutputTokenPrice + + acc, err := e.getOrCreateAccumulator(model, now) + if err != nil { + debugLogger.Info("failed to create tdigest accumulator, skipping sample", "model", model, "err", err) + continue + } 
+ if err := acc.digest.Add(cost); err != nil { + debugLogger.Info("tdigest.Add returned an unexpected error, skipping sample", "model", model, "err", err) + continue + } + updated[model] = true + } + + // After extracting all valid samples, pre-fetch accumulators for all models + // to avoid repeated lookups during flush. + modelToAcc := make(map[string]*modelCostAccumulator) + for model := range updated { + acc, err := e.getOrCreateAccumulator(model, now) + if err != nil { + debugLogger.Info("failed to create tdigest accumulator", "model", model, "err", err) + delete(updated, model) // mark as failed + continue + } + modelToAcc[model] = acc + } + + // updated contains exactly the models that received a fresh sample in + // this batch, so the flushInterval gate below only consults + // tdigest accumulators whose digest actually changed since the last publish. + for model := range updated { + acc := modelToAcc[model] + // flushInterval == 0 means publish on every event + if e.flushInterval > 0 && now.Sub(acc.lastFlush) < e.flushInterval { + continue + } + acc.lastFlush = now + snapshot := acc.digest.Clone() + // assumes that all models are configured with the pricing attributes and validated + // in the modelconfig collector + e.ds.GetOrCreateModel(model).GetAttributes().Put( + pricing.CostDigestAttributeKey, + &pricing.CostDigest{Digest: snapshot}, + ) + debugLogger.Info("request-cost-metadata published cost digest snapshot", + "model", model, + "count", snapshot.Count(), + ) + } + return nil +} + +// extractTokenCounts pulls prompt_tokens and completion_tokens from the +// response's usage block. Both must be present and positive; any failure +// returns ok=false so the sample is skipped. 
+func extractTokenCounts(p dlsrc.ResponsePayload) (prompt, completion float64, ok bool) { + usage, ok := p.Response.Body["usage"].(map[string]any) + if !ok { + return 0, 0, false + } + prompt, ok = usage["prompt_tokens"].(float64) + if !ok || prompt <= 0 { + return 0, 0, false + } + completion, ok = usage["completion_tokens"].(float64) + if !ok || completion <= 0 { + return 0, 0, false + } + return prompt, completion, true +} + +// lookupTokenPrices fetches the *pricing.TokenPrices stored on the model's +// AttributeMap. Returns ok=false if the attribute is absent or of the wrong +// type, in which case the caller skips the cost sample (a model with no +// pricing has no defined cost to record). +func lookupTokenPrices(ds datalayer.Datastore, model string) (*pricing.TokenPrices, bool) { + v, ok := ds.GetOrCreateModel(model).GetAttributes().Get(pricing.TokenPricesAttributeKey) + if !ok { + return nil, false + } + tp, ok := v.(*pricing.TokenPrices) + if !ok { + return nil, false + } + return tp, true +} + +// getOrCreateAccumulator returns the per-model accumulator, creating a fresh +// t-digest on first use. lastFlush is initialized to now so the first publish +// happens after one full flushInterval, matching the requestmetadata pattern. +func (e *RequestCostMetadataExtractor) getOrCreateAccumulator(model string, now time.Time) (*modelCostAccumulator, error) { + if acc, ok := e.state[model]; ok { + return acc, nil + } + d, err := tdigest.New(tdigest.Compression(e.compression)) + if err != nil { + return nil, err + } + acc := &modelCostAccumulator{digest: d, lastFlush: now} + e.state[model] = acc + return acc, nil +} diff --git a/pkg/framework/plugins/datalayer/requestcostmetadata/plugin_test.go b/pkg/framework/plugins/datalayer/requestcostmetadata/plugin_test.go new file mode 100644 index 0000000..68d063f --- /dev/null +++ b/pkg/framework/plugins/datalayer/requestcostmetadata/plugin_test.go @@ -0,0 +1,469 @@ +/* +Copyright 2026 The llm-d Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcostmetadata + +import ( + "context" + "encoding/json" + "testing" + "time" + + ctrlbuilder "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/llm-d/llm-d-inference-payload-processor/pkg/datastore" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer" + dlsrc "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer/datasource" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/datalayer/pricing" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin" + "github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/requesthandling" +) + +// fakeHandle implements plugin.Handle for unit tests. 
+type fakeHandle struct{ ds datalayer.Datastore } + +func (f *fakeHandle) Context() context.Context { return context.Background() } +func (f *fakeHandle) Client() client.Client { return nil } +func (f *fakeHandle) ReconcilerBuilder() *ctrlbuilder.Builder { return nil } +func (f *fakeHandle) Datastore() datalayer.Datastore { return f.ds } +func (f *fakeHandle) EventNotifier() datalayer.EventNotifier { return nil } +func (f *fakeHandle) Plugin(string) plugin.Plugin { return nil } +func (f *fakeHandle) AddPlugin(string, plugin.Plugin) {} +func (f *fakeHandle) GetAllPlugins() []plugin.Plugin { return nil } +func (f *fakeHandle) GetAllPluginsWithNames() map[string]plugin.Plugin { return nil } + +// makeResponseEvent builds a ResponseEventType event for the named model whose +// usage block reports promptTokens and completionTokens. Pass <= 0 to omit a +// field; pass omitUsage=true to omit the entire usage block. +func makeResponseEvent(model string, promptTokens, completionTokens float64, omitUsage bool) dlsrc.Event { + req := requesthandling.NewInferenceRequest() + req.Body["model"] = model + resp := requesthandling.NewInferenceResponse() + if !omitUsage { + usage := map[string]any{} + if promptTokens > 0 { + usage["prompt_tokens"] = promptTokens + } + if completionTokens > 0 { + usage["completion_tokens"] = completionTokens + } + resp.Body["usage"] = usage + } + return dlsrc.Event{ + Type: dlsrc.ResponseEventType, + Payload: dlsrc.ResponsePayload{Request: req, Response: resp}, + } +} + +// setTokenPrices attaches a TokenPrices attribute to the named model in ds. +func setTokenPrices(ds datalayer.Datastore, model string, in, out float64) { + ds.GetOrCreateModel(model).GetAttributes().Put( + pricing.TokenPricesAttributeKey, + &pricing.TokenPrices{InputTokenPrice: in, OutputTokenPrice: out}, + ) +} + +// readDigest fetches the *pricing.CostDigest for model from ds, returning +// (digest, true) if present and well-typed, or (nil, false) otherwise. 
+func readDigest(ds datalayer.Datastore, model string) (*pricing.CostDigest, bool) { + v, ok := ds.GetOrCreateModel(model).GetAttributes().Get(pricing.CostDigestAttributeKey) + if !ok { + return nil, false + } + cd, ok := v.(*pricing.CostDigest) + return cd, ok +} + +// newTestExtractor builds an extractor with flushInterval=0 so every event +// flushes immediately, mirroring the requestmetadata test pattern. Tests +// exercising non-zero flush intervals build their extractor inline. +func newTestExtractor(t *testing.T) (*RequestCostMetadataExtractor, datalayer.Datastore) { + t.Helper() + ds := datastore.NewFakeDataStore() + ext := NewRequestCostMetadataExtractor(ds).WithFlushInterval(0) + return ext, ds +} + +// --- Factory tests --- + +func TestExtractorFactory_Defaults(t *testing.T) { + ds := datastore.NewFakeDataStore() + p, err := ExtractorFactory("x", nil, &fakeHandle{ds: ds}) + if err != nil { + t.Fatalf("ExtractorFactory: %v", err) + } + ext := p.(*RequestCostMetadataExtractor) + if ext.compression != defaultCompression { + t.Errorf("compression = %v, want %v", ext.compression, defaultCompression) + } + if ext.flushInterval != defaultFlushIntervalDuration { + t.Errorf("flushInterval = %v, want %v", ext.flushInterval, defaultFlushIntervalDuration) + } +} + +func TestExtractorFactory_HonorsConfig(t *testing.T) { + ds := datastore.NewFakeDataStore() + raw := json.RawMessage(`{"compression":50,"flushIntervalDuration":"1m"}`) + p, err := ExtractorFactory("x", raw, &fakeHandle{ds: ds}) + if err != nil { + t.Fatalf("ExtractorFactory: %v", err) + } + ext := p.(*RequestCostMetadataExtractor) + if ext.compression != 50 { + t.Errorf("compression = %v, want 50", ext.compression) + } + if ext.flushInterval != time.Minute { + t.Errorf("flushInterval = %v, want 1m", ext.flushInterval) + } +} + +func TestExtractorFactory_RejectsInvalidJSON(t *testing.T) { + ds := datastore.NewFakeDataStore() + if _, err := ExtractorFactory("x", json.RawMessage(`{broken`), &fakeHandle{ds: 
ds}); err == nil { + t.Error("expected error for invalid JSON, got nil") + } +} + +func TestExtractorFactory_RejectsNonPositiveCompression(t *testing.T) { + ds := datastore.NewFakeDataStore() + raw := json.RawMessage(`{"compression":0,"flushIntervalDuration":"5s"}`) + if _, err := ExtractorFactory("x", raw, &fakeHandle{ds: ds}); err == nil { + t.Error("expected error for compression=0, got nil") + } +} + +func TestExtractorFactory_RejectsInvalidFlushInterval(t *testing.T) { + ds := datastore.NewFakeDataStore() + raw := json.RawMessage(`{"compression":200,"flushIntervalDuration":"not-a-duration"}`) + if _, err := ExtractorFactory("x", raw, &fakeHandle{ds: ds}); err == nil { + t.Error("expected error for invalid flushIntervalDuration, got nil") + } +} + +func TestExtractorFactory_RejectsNegativeFlushInterval(t *testing.T) { + ds := datastore.NewFakeDataStore() + raw := json.RawMessage(`{"compression":200,"flushIntervalDuration":"-1s"}`) + if _, err := ExtractorFactory("x", raw, &fakeHandle{ds: ds}); err == nil { + t.Error("expected error for negative flushIntervalDuration, got nil") + } +} + +// --- Extract tests --- + +// TestExtract_PublishesCostDigest verifies the happy path: a response event +// for a model with TokenPrices produces a digest snapshot on the AttributeMap +// whose count includes the new sample. +func TestExtract_PublishesCostDigest(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 4e-6) // input $1/M, output $4/M (per token) + + ev := makeResponseEvent("m1", 100, 50, false) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + + cd, ok := readDigest(ds, "m1") + if !ok { + t.Fatal("expected CostDigest attribute to be present") + } + if cd.Digest.Count() != 1 { + t.Errorf("digest count = %d, want 1", cd.Digest.Count()) + } + // Cost = 100 * 1e-6 + 50 * 4e-6 = 1e-4 + 2e-4 = 3e-4. With one sample, + // the digest's median should equal the inserted value. 
+ wantCost := 100.0*1e-6 + 50.0*4e-6 + if got := cd.Digest.Quantile(0.5); got != wantCost { + t.Errorf("Quantile(0.5) = %v, want %v", got, wantCost) + } +} + +// TestExtract_AccumulatesMultipleSamples verifies that successive responses +// add samples to the same model's digest. With flushInterval=0 every event +// publishes, so the final attribute reflects all samples. +func TestExtract_AccumulatesMultipleSamples(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + for i := range 5 { + ev := makeResponseEvent("m1", 100, 100, false) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract iter %d: %v", i, err) + } + } + + cd, ok := readDigest(ds, "m1") + if !ok { + t.Fatal("expected CostDigest attribute to be present") + } + if cd.Digest.Count() != 5 { + t.Errorf("digest count = %d, want 5", cd.Digest.Count()) + } +} + +// TestExtract_SkipsRequestEvents verifies that RequestEventType events do not +// produce cost samples. (Cost is observable only on the response.) +func TestExtract_SkipsRequestEvents(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + req := requesthandling.NewInferenceRequest() + req.Body["model"] = "m1" + ev := dlsrc.Event{Type: dlsrc.RequestEventType, Payload: dlsrc.RequestPayload{Request: req}} + + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest attribute after request-only batch") + } +} + +// TestExtract_SkipsMissingUsage verifies that a response with no usage block +// is skipped without panicking and without publishing a digest. 
+func TestExtract_SkipsMissingUsage(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + ev := makeResponseEvent("m1", 0, 0, true) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest attribute when usage is missing") + } +} + +// TestExtract_SkipsMissingPromptTokens verifies that a usage block missing +// prompt_tokens is skipped (we do not impute zero — the sample is unusable). +func TestExtract_SkipsMissingPromptTokens(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + ev := makeResponseEvent("m1", 0, 50, false) // promptTokens omitted + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest attribute when prompt_tokens is missing") + } +} + +// TestExtract_SkipsMissingCompletionTokens mirrors the above for completion_tokens. +func TestExtract_SkipsMissingCompletionTokens(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + ev := makeResponseEvent("m1", 100, 0, false) // completionTokens omitted + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest attribute when completion_tokens is missing") + } +} + +// TestExtract_SkipsModelWithoutTokenPrices verifies that a model that has +// never been seen by the modelconfigcollector (and therefore has no +// TokenPrices attribute) is silently skipped — there is no defined cost. +func TestExtract_SkipsModelWithoutTokenPrices(t *testing.T) { + ext, ds := newTestExtractor(t) + // Note: setTokenPrices NOT called. 
+ + ev := makeResponseEvent("m1", 100, 50, false) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest attribute when TokenPrices is absent") + } +} + +// TestExtract_FreeModelRecordsZeroSample verifies the locked-in decision: a +// model with TokenPrices{0,0} still produces a sample (cost=0) so CostGuard's +// arm-pull bookkeeping is not skewed by free models. +func TestExtract_FreeModelRecordsZeroSample(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "free", 0, 0) + + ev := makeResponseEvent("free", 100, 50, false) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + cd, ok := readDigest(ds, "free") + if !ok { + t.Fatal("expected CostDigest attribute even for free model") + } + if cd.Digest.Count() != 1 { + t.Errorf("digest count = %d, want 1", cd.Digest.Count()) + } + if got := cd.Digest.Quantile(0.5); got != 0 { + t.Errorf("Quantile(0.5) = %v, want 0", got) + } +} + +// TestExtract_PerModelIsolation verifies that two models accumulate +// independent digests; a sample for model A does not appear in model B's +// digest. 
+func TestExtract_PerModelIsolation(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "a", 1e-6, 1e-6) + setTokenPrices(ds, "b", 2e-6, 2e-6) + + batch := []dlsrc.Event{ + makeResponseEvent("a", 100, 100, false), + makeResponseEvent("b", 100, 100, false), + makeResponseEvent("a", 200, 200, false), + } + if err := ext.Extract(context.Background(), batch); err != nil { + t.Fatalf("Extract: %v", err) + } + + a, ok := readDigest(ds, "a") + if !ok { + t.Fatal("expected CostDigest for a") + } + if a.Digest.Count() != 2 { + t.Errorf("a digest count = %d, want 2", a.Digest.Count()) + } + b, ok := readDigest(ds, "b") + if !ok { + t.Fatal("expected CostDigest for b") + } + if b.Digest.Count() != 1 { + t.Errorf("b digest count = %d, want 1", b.Digest.Count()) + } +} + +// TestExtract_RejectsNonFloatTokenFields verifies that usage fields of the +// wrong Go type (e.g. int) are treated as malformed and the sample skipped. +// In production encoding/json decodes JSON numbers as float64 so this never +// happens, but a misconfigured upstream that constructs the map directly in +// Go could trip this — we want the failure mode to be "skip", not "panic". +func TestExtract_RejectsNonFloatTokenFields(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + req := requesthandling.NewInferenceRequest() + req.Body["model"] = "m1" + resp := requesthandling.NewInferenceResponse() + // int values, not float64 — the type assertion in extractTokenCounts must reject. 
+ resp.Body["usage"] = map[string]any{"prompt_tokens": 100, "completion_tokens": 50} + ev := dlsrc.Event{ + Type: dlsrc.ResponseEventType, + Payload: dlsrc.ResponsePayload{Request: req, Response: resp}, + } + + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest when usage fields are non-float64") + } +} + +// TestExtract_MultipleEventsSameModelInBatch verifies that a single Extract +// call carrying several response events for the same model adds each cost +// sample to the digest and publishes the snapshot exactly once. The +// end-of-batch publish loop iterates a map keyed by model, so a regression +// that double-published or only retained the last sample would be visible +// here. +func TestExtract_MultipleEventsSameModelInBatch(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + batch := []dlsrc.Event{ + makeResponseEvent("m1", 100, 100, false), + makeResponseEvent("m1", 200, 200, false), + makeResponseEvent("m1", 300, 300, false), + } + if err := ext.Extract(context.Background(), batch); err != nil { + t.Fatalf("Extract: %v", err) + } + + cd, ok := readDigest(ds, "m1") + if !ok { + t.Fatal("expected CostDigest for m1") + } + if cd.Digest.Count() != 3 { + t.Errorf("digest count = %d, want 3 (one per response event in the batch)", cd.Digest.Count()) + } +} + +// TestExtract_WrongPayloadTypeForResponseEvent verifies that an event tagged +// as ResponseEventType but carrying a RequestPayload (a programming error +// upstream) is silently skipped via the type-assertion guard rather than +// panicking. Locks the defensive `if !ok { continue }` branch. 
+func TestExtract_WrongPayloadTypeForResponseEvent(t *testing.T) { + ext, ds := newTestExtractor(t) + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + req := requesthandling.NewInferenceRequest() + req.Body["model"] = "m1" + ev := dlsrc.Event{ + Type: dlsrc.ResponseEventType, + Payload: dlsrc.RequestPayload{Request: req}, // wrong payload type for ResponseEventType + } + + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no CostDigest when payload type does not match event type") + } +} + +// TestExtract_FlushIntervalDelaysPublish verifies that with a non-zero +// flushInterval, samples accumulate in memory but the AttributeMap snapshot +// is only refreshed once the interval has elapsed since the last publish. +// +// Uses a small real duration: a 50 ms flush interval and a single sleep of +// twice that, so the test takes roughly 100 ms. The initial "no publish yet" +// assertion is timing-independent (a new accumulator's lastFlush equals the +// batch timestamp); the 2x sleep gives the second Extract margin on slow CI. +func TestExtract_FlushIntervalDelaysPublish(t *testing.T) { + ds := datastore.NewFakeDataStore() + setTokenPrices(ds, "m1", 1e-6, 1e-6) + + const flushInterval = 50 * time.Millisecond + ext := NewRequestCostMetadataExtractor(ds).WithFlushInterval(flushInterval) + + // First event creates the accumulator with lastFlush=now and adds a + // sample. No publish: zero time has elapsed since lastFlush. + ev := makeResponseEvent("m1", 100, 100, false) + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract iter 0: %v", err) + } + if _, ok := readDigest(ds, "m1"); ok { + t.Error("expected no published CostDigest before flush interval elapses") + } + + // Wait past the flush interval, then send another event. With elapsed + // >= flushInterval, the publish gate opens and the snapshot lands on + // the AttributeMap with both samples. 
+ time.Sleep(2 * flushInterval) + + if err := ext.Extract(context.Background(), []dlsrc.Event{ev}); err != nil { + t.Fatalf("Extract post-interval: %v", err) + } + cd, ok := readDigest(ds, "m1") + if !ok { + t.Fatal("expected CostDigest after flush interval elapses") + } + if cd.Digest.Count() != 2 { + t.Errorf("digest count = %d, want 2", cd.Digest.Count()) + } +}