Skip to content

Commit 8eda98d

Browse files
fix(inference): per-precision line labels when multiple precisions selected (#426)
When more than one precision is shown, each precision is its own curve, but line labels were deduplicated by hardware key — so only one of the two curves for a given hardware got a label, and that label omitted the precision. Now, when >1 precision is selected, every curve gets its own line label and the text includes the precision (e.g. "B200 (vLLM) FP8" vs "B200 (vLLM) FP4"). With a single precision selected, behavior is unchanged (one label per hardware, no precision suffix). Applies to both the interactivity (greedy placement) and TTFT/E2EL (endpoint) chart types, in the static render and the zoom re-placement paths. Overlay (unofficial run) line labels also gain the precision suffix so the two precision curves of an overlay stay distinguishable. Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 2a1cfe2 commit 8eda98d

2 files changed

Lines changed: 97 additions & 34 deletions

File tree

packages/app/cypress/e2e/line-labels.cy.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,31 @@ describe('Line Labels Toggle', () => {
7878
// Labels should be rendered
7979
cy.get('[data-testid="scatter-graph"] svg g.line-label').should('have.length.greaterThan', 0);
8080
});
81+
82+
it('appends the precision to each line label when multiple precisions are selected', () => {
83+
cy.visit('/inference?i_linelabel=1&i_prec=fp4,fp8', {
84+
onBeforeLoad(win) {
85+
win.localStorage.setItem('inferencex-star-modal-dismissed', String(Date.now()));
86+
},
87+
});
88+
cy.get('[data-testid="scatter-graph"]').should('be.visible');
89+
90+
// With both FP4 and FP8 shown, each curve is its own line and the label
91+
// must carry the precision so the two curves of the same hardware are
92+
// distinguishable (e.g. "B200 (vLLM) FP8" vs "B200 (vLLM) FP4").
93+
cy.get('[data-testid="scatter-graph"] svg g.line-label .ll-text')
94+
.should('have.length.greaterThan', 0)
95+
.then(($texts) => {
96+
const labels = $texts.toArray().map((el) => el.textContent ?? '');
97+
// At least one label for each selected precision.
98+
expect(
99+
labels.some((t) => /\bFP8\b/u.test(t)),
100+
'an FP8 line label exists',
101+
).to.equal(true);
102+
expect(
103+
labels.some((t) => /\bFP4\b/u.test(t)),
104+
'an FP4 line label exists',
105+
).to.equal(true);
106+
});
107+
});
81108
});

packages/app/src/components/inference/ui/ScatterGraph.tsx

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import ChartLegend from '@/components/ui/chart-legend';
1010
import { useUnofficialRun } from '@/components/unofficial-run-provider';
1111
import { computeToggle } from '@/hooks/useTogglableSet';
1212
import { getHardwareConfig, getModelSortIndex } from '@/lib/constants';
13-
import { getChartWatermark } from '@/lib/data-mappings';
13+
import { getChartWatermark, getPrecisionLabel, type Precision } from '@/lib/data-mappings';
1414
import { formatNumber, getDisplayLabel, updateRepoUrl } from '@/lib/utils';
1515
import { D3Chart } from '@/lib/d3-chart/D3Chart';
1616
import type {
@@ -101,6 +101,14 @@ const parseHwKeyToLabel = (hwKey: string): { name: string; label: string } => {
101101
return { name: config.label, label: getDisplayLabel(config) };
102102
};
103103

104+
// Line-label text for a curve. When more than one precision is shown, each curve
105+
// is its own line, so append the precision (e.g. "B200 (vLLM) FP8") to keep the
106+
// FP4 and FP8 curves of the same hardware distinguishable.
107+
const lineLabelText = (hwKey: string, precision: string, includePrecision: boolean): string => {
108+
const base = parseHwKeyToLabel(hwKey).label;
109+
return includePrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base;
110+
};
111+
104112
const ScatterGraph = React.memo(
105113
({
106114
chartId,
@@ -914,6 +922,9 @@ const ScatterGraph = React.memo(
914922

915923
if (showLineLabels) {
916924
const isInteractivity = chartDefinition.chartType === 'interactivity';
925+
// With >1 precision selected each precision is its own curve, so label
926+
// every curve and include the precision in the text.
927+
const multiPrecision = selectedPrecisions.length > 1;
917928
const LABEL_H = 18;
918929
const LABEL_W = 120; // approximate label width for overlap check
919930

@@ -924,16 +935,19 @@ const ScatterGraph = React.memo(
924935
const collides = (cx: number, cy: number) =>
925936
placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W);
926937

927-
// Deduplicate by hw key — pick the roofline with most points per hw
928-
const bestByHw = new Map<string, (typeof entries)[0]>();
938+
// Deduplicate by group key — one label per curve. With a single
939+
// precision that's one per hw; with multiple it's one per (hw,
940+
// precision) so each precision curve keeps its own label.
941+
const bestByGroup = new Map<string, (typeof entries)[0]>();
929942
for (const e of entries) {
930943
if (!e.visible || e.points.length < 2) continue;
931-
const prev = bestByHw.get(e.hw);
932-
if (!prev || e.points.length > prev.points.length) bestByHw.set(e.hw, e);
944+
const groupKey = multiPrecision ? e.key : e.hw;
945+
const prev = bestByGroup.get(groupKey);
946+
if (!prev || e.points.length > prev.points.length) bestByGroup.set(groupKey, e);
933947
}
934948

935949
// Sort entries by highest y-value first (top of chart) for priority
936-
const sorted = [...bestByHw.values()].toSorted((a, b) => {
950+
const sorted = [...bestByGroup.values()].toSorted((a, b) => {
937951
const ay = yScale(a.points[0].y);
938952
const by = yScale(b.points[0].y);
939953
return ay - by; // smaller pixel y = higher on chart
@@ -948,7 +962,7 @@ const ScatterGraph = React.memo(
948962
pts.at(-1)!, // endpoint
949963
];
950964

951-
const { label } = parseHwKeyToLabel(entry.hw);
965+
const label = lineLabelText(entry.hw, entry.precision, multiPrecision);
952966
let foundPlacement = false;
953967
for (const pt of candidates) {
954968
const px = xScale(pt.x);
@@ -983,33 +997,40 @@ const ScatterGraph = React.memo(
983997
}
984998
}
985999

986-
// Also add hidden entries for non-visible hw (so D3 data-join is clean)
987-
const labeledHw = new Set(lineLabels.map((l) => l.hw));
1000+
// Also add hidden entries for any curve that wasn't placed (so the
1001+
// D3 data-join, keyed by series key, is clean).
1002+
const labeledKeys = new Set(lineLabels.map((l) => l.key));
9881003
for (const entry of entries) {
989-
if (entry.points.length >= 2 && !labeledHw.has(entry.hw)) {
990-
const { label } = parseHwKeyToLabel(entry.hw);
1004+
if (entry.points.length >= 2 && !labeledKeys.has(entry.key)) {
9911005
lineLabels.push({
9921006
key: entry.key,
9931007
hw: entry.hw,
994-
label,
1008+
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
9951009
color: getCssColor(resolveColor(entry.hw)),
9961010
x: xScale(entry.points[0].x),
9971011
y: yScale(entry.points[0].y),
9981012
visible: false,
9991013
});
1000-
labeledHw.add(entry.hw);
1014+
labeledKeys.add(entry.key);
10011015
}
10021016
}
10031017

10041018
// Overlay (unofficial run) rooflines also get line labels using the
10051019
// run-palette color so they match the legend swatches. The label
10061020
// text mirrors the overlay legend ("✕ <branch>" — falls back to the
10071021
// hw label if run metadata isn't available, e.g. legacy callers).
1008-
const overlayLabelText = (runIndex: number, hwKey: string): string => {
1022+
const overlayLabelText = (
1023+
runIndex: number,
1024+
hwKey: string,
1025+
precision: string,
1026+
): string => {
10091027
const info = unofficialRunInfos[runIndex];
1010-
if (!info) return parseHwKeyToLabel(hwKey).label;
1011-
const branch = info.branch || `run ${info.id}`;
1012-
return `✕ ${branch}`;
1028+
const base = info
1029+
? `✕ ${info.branch || `run ${info.id}`}`
1030+
: parseHwKeyToLabel(hwKey).label;
1031+
return multiPrecision
1032+
? `${base} ${getPrecisionLabel(precision as Precision)}`
1033+
: base;
10131034
};
10141035
const sortedOverlay = Object.entries(overlayRooflines)
10151036
.filter(
@@ -1026,7 +1047,11 @@ const ScatterGraph = React.memo(
10261047
pts[Math.max(0, Math.floor((pts.length * 2) / 3))],
10271048
pts.at(-1)!,
10281049
];
1029-
const label = overlayLabelText(group.runIndex, group.hwKey);
1050+
const label = overlayLabelText(
1051+
group.runIndex,
1052+
group.hwKey,
1053+
group.points[0]?.precision ?? '',
1054+
);
10301055
let placedOverlay = false;
10311056
for (const pt of candidates) {
10321057
const px = xScale(pt.x);
@@ -1060,31 +1085,36 @@ const ScatterGraph = React.memo(
10601085
}
10611086
}
10621087
} else {
1063-
// TTFT / E2EL: endpoint labels, one per hw key
1064-
const seenHw = new Set<string>();
1088+
// TTFT / E2EL: endpoint labels, one per curve (per hw, or per
1089+
// (hw, precision) when multiple precisions are shown).
1090+
const seen = new Set<string>();
10651091
for (const entry of entries) {
1066-
if (entry.points.length < 2 || seenHw.has(entry.hw)) continue;
1067-
seenHw.add(entry.hw);
1092+
if (entry.points.length < 2 || !entry.visible) continue;
1093+
const groupKey = multiPrecision ? entry.key : entry.hw;
1094+
if (seen.has(groupKey)) continue;
1095+
seen.add(groupKey);
10681096
const pt = entry.points.at(-1)!;
1069-
const { label } = parseHwKeyToLabel(entry.hw);
10701097
lineLabels.push({
10711098
key: entry.key,
10721099
hw: entry.hw,
1073-
label,
1100+
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
10741101
color: getCssColor(resolveColor(entry.hw)),
10751102
x: xScale(pt.x),
10761103
y: yScale(pt.y),
1077-
visible: entry.visible,
1104+
visible: true,
10781105
});
10791106
}
10801107
// Endpoint labels for overlay rooflines too (one per (hw, runIndex)),
10811108
// labeled with the run's branch name to mirror the overlay legend.
10821109
for (const [ovKey, group] of Object.entries(overlayRooflines)) {
10831110
if (group.points.length < 2 || !activeOverlayHwTypes.has(group.hwKey)) continue;
10841111
const info = unofficialRunInfos[group.runIndex];
1085-
const labelText = info
1112+
const branchOrHw = info
10861113
? `✕ ${info.branch || `run ${info.id}`}`
10871114
: parseHwKeyToLabel(group.hwKey).label;
1115+
const labelText = multiPrecision
1116+
? `${branchOrHw} ${getPrecisionLabel((group.points[0]?.precision ?? '') as Precision)}`
1117+
: branchOrHw;
10881118
const labelKey = `overlay-${ovKey}`;
10891119
const pt = group.points.at(-1)!;
10901120
lineLabels.push({
@@ -1236,6 +1266,7 @@ const ScatterGraph = React.memo(
12361266
// Update line label positions on zoom
12371267
if (showLineLabels) {
12381268
const isInteractivity = chartDefinition.chartType === 'interactivity';
1269+
const multiPrecision = selectedPrecisions.length > 1;
12391270
const LABEL_H = 18;
12401271
const LABEL_W = 120;
12411272

@@ -1245,17 +1276,19 @@ const ScatterGraph = React.memo(
12451276
const collides = (cx: number, cy: number) =>
12461277
placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W);
12471278

1248-
// Deduplicate by hw key — pick roofline with most points per hw
1249-
const bestByHw = new Map<string, [string, InferenceData[]]>();
1279+
// Deduplicate by group key — one curve per hw, or per (hw, precision)
1280+
// when multiple precisions are shown (mirrors the static render).
1281+
const bestByGroup = new Map<string, [string, InferenceData[]]>();
12501282
for (const [key, pts] of Object.entries(rooflines)) {
12511283
if (pts.length < 2) continue;
12521284
const hw = key.split('_').slice(0, -1).join('_');
12531285
const prec = key.split('_').pop()!;
12541286
if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) continue;
1255-
const prev = bestByHw.get(hw);
1256-
if (!prev || pts.length > prev[1].length) bestByHw.set(hw, [key, pts]);
1287+
const groupKey = multiPrecision ? key : hw;
1288+
const prev = bestByGroup.get(groupKey);
1289+
if (!prev || pts.length > prev[1].length) bestByGroup.set(groupKey, [key, pts]);
12571290
}
1258-
const visibleEntries = [...bestByHw.values()].toSorted(
1291+
const visibleEntries = [...bestByGroup.values()].toSorted(
12591292
([, a], [, b]) => newYScale(a[0].y) - newYScale(b[0].y),
12601293
);
12611294

@@ -1343,12 +1376,15 @@ const ScatterGraph = React.memo(
13431376
y: number;
13441377
}
13451378
const zoomLabels: ZoomLabel[] = [];
1346-
const seenHw = new Set<string>();
1379+
const seen = new Set<string>();
13471380
Object.entries(rooflines).forEach(([key, pts]) => {
13481381
if (pts.length < 2) return;
13491382
const hw = key.split('_').slice(0, -1).join('_');
1350-
if (seenHw.has(hw)) return;
1351-
seenHw.add(hw);
1383+
const prec = key.split('_').pop()!;
1384+
if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) return;
1385+
const groupKey = multiPrecision ? key : hw;
1386+
if (seen.has(groupKey)) return;
1387+
seen.add(groupKey);
13521388
const pt = pts.at(-1)!;
13531389
zoomLabels.push({ key, x: newXScale(pt.x), y: newYScale(pt.y) });
13541390
});

0 commit comments

Comments
 (0)