Skip to content

Commit 6ee448c

Browse files
authored
feat: implement model-specific suffix resolution for M3 mtp as "EAGLE" (#469)
1 parent b57f844 commit 6ee448c

21 files changed

Lines changed: 296 additions & 46 deletions

packages/app/cypress/component/scatter-graph.cy.tsx

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,89 @@ describe('ScatterGraph', () => {
246246
.find('text')
247247
.should('contain.text', 'feature-branch');
248248
});
249+
250+
it('renders M3 mtp rooflines with the EAGLE label (official + overlay)', () => {
251+
const interactivityChartDef = createMockChartDefinition({
252+
chartType: 'interactivity',
253+
y_tpPerGpu_roofline: 'upper_left',
254+
});
255+
const officialData = [
256+
createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 8, y: 240, precision: Precision.FP4 }),
257+
createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 16, y: 200, precision: Precision.FP4 }),
258+
createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 32, y: 150, precision: Precision.FP4 }),
259+
];
260+
// Overlay roofline with no run metadata, so its line label falls back to the
261+
// hw label — exercising the overlay path's model-aware suffix resolution.
262+
const runUrl = 'https://github.com/x/y/actions/runs/999';
263+
const overlayData = {
264+
data: [
265+
createMockInferenceData({
266+
hwKey: 'b200_vllm_mtp',
267+
x: 8,
268+
y: 320,
269+
precision: Precision.FP4,
270+
run_url: runUrl,
271+
}),
272+
createMockInferenceData({
273+
hwKey: 'b200_vllm_mtp',
274+
x: 16,
275+
y: 280,
276+
precision: Precision.FP4,
277+
run_url: runUrl,
278+
}),
279+
createMockInferenceData({
280+
hwKey: 'b200_vllm_mtp',
281+
x: 32,
282+
y: 220,
283+
precision: Precision.FP4,
284+
run_url: runUrl,
285+
}),
286+
],
287+
hardwareConfig: hwConfig,
288+
label: '',
289+
runUrl,
290+
};
291+
292+
mountWithProviders(
293+
<div style={{ width: 800, height: 600 }}>
294+
<ScatterGraph
295+
chartId="test-scatter-m3-eagle"
296+
modelLabel="MiniMax-M3"
297+
data={officialData}
298+
xLabel="Concurrency"
299+
yLabel="Throughput / GPU (tok/s)"
300+
chartDefinition={interactivityChartDef}
301+
overlayData={overlayData}
302+
/>
303+
</div>,
304+
{
305+
inference: {
306+
hardwareConfig: hwConfig,
307+
activeHwTypes: new Set(['h100_vllm_mtp']),
308+
hwTypesWithData: new Set(['h100_vllm_mtp']),
309+
selectedPrecisions: [Precision.FP4],
310+
showLineLabels: true,
311+
},
312+
unofficial: {
313+
activeOverlayHwTypes: new Set(['b200_vllm_mtp']),
314+
allOverlayHwTypes: new Set(['b200_vllm_mtp']),
315+
runIndexByUrl: { [runUrl]: 0, '999': 0 },
316+
// Intentionally empty so the overlay label falls back to the hw label.
317+
unofficialRunInfos: [],
318+
},
319+
},
320+
);
321+
322+
// Official roofline label reads "EAGLE", not the generic "MTP".
323+
cy.get('#test-scatter-m3-eagle svg .line-label')
324+
.filter('[data-line-key]:not([data-line-key^="overlay-"])')
325+
.find('text')
326+
.should('contain.text', 'EAGLE');
327+
// Overlay roofline (no run metadata → hw-label fallback) also reads "EAGLE".
328+
cy.get('#test-scatter-m3-eagle svg .line-label[data-line-key^="overlay-"]')
329+
.find('text')
330+
.should('contain.text', 'EAGLE');
331+
// No label should show the generic MTP token for M3.
332+
cy.get('#test-scatter-m3-eagle svg .line-label text').should('not.contain.text', 'MTP');
333+
});
249334
});

packages/app/cypress/component/submissions-table.cy.tsx

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,36 @@ describe('SubmissionsTable — Spec Method column', () => {
3232
cy.contains('th', 'Spec Method').should('be.visible');
3333
});
3434

35-
it('renders spec_method values uppercased and shows an em-dash for "none"', () => {
35+
it('renders canonical spec_method labels and shows an em-dash for "none"', () => {
3636
cy.mount(<SubmissionsTable data={rows} />);
37-
// CSS uppercases the value; the DOM text remains lowercase.
38-
cy.contains('td', 'mtp').should('be.visible').and('have.class', 'uppercase');
39-
cy.contains('td', 'eagle').should('be.visible').and('have.class', 'uppercase');
37+
// The cell renders the canonical spec-method label (MTP/EAGLE) for the model.
38+
cy.contains('td', 'MTP').should('be.visible').and('have.class', 'uppercase');
39+
cy.contains('td', 'EAGLE').should('be.visible').and('have.class', 'uppercase');
4040
// The "none" row renders an em-dash placeholder instead of literal "none".
4141
// Hardware text is rendered uppercase via .toUpperCase().
4242
cy.contains('tbody tr', 'MI355X').within(() => {
4343
cy.contains('—').should('be.visible');
4444
});
4545
});
4646

47+
it('renders M3 mtp as EAGLE, not MTP', () => {
48+
cy.mount(
49+
<SubmissionsTable
50+
data={[
51+
{
52+
...baseRow,
53+
model: 'minimaxm3',
54+
hardware: 'b200',
55+
spec_method: 'mtp',
56+
date: '2026-05-10',
57+
},
58+
]}
59+
/>,
60+
);
61+
cy.contains('td', 'EAGLE').should('be.visible');
62+
cy.contains('td', 'MTP').should('not.exist');
63+
});
64+
4765
it('sorts by spec_method when the header is clicked', () => {
4866
cy.mount(<SubmissionsTable data={rows} />);
4967
// Desc alphabetical: 'none' (mi355x) → 'mtp' (h200) → 'eagle' (b300).

packages/app/src/components/calculator/useThroughputData.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export function useThroughputData(
7878

7979
const entry = rowToAggDataEntry(row);
8080
const hwKey = getHardwareKey(entry);
81-
const hwConfig = getHardwareConfig(hwKey);
81+
const hwConfig = getHardwareConfig(hwKey, entry.model);
8282
if (!hwConfig) continue;
8383

8484
if (!hwConfigMap[hwKey]) hwConfigMap[hwKey] = { ...hwConfig, name: hwKey };

packages/app/src/components/evaluation/chart-data.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
import { DISPLAY_MODEL_TO_DB } from '@semianalysisai/inferencex-constants';
1+
import {
2+
DISPLAY_MODEL_TO_DB,
3+
resolveFrameworkPartLabel,
4+
} from '@semianalysisai/inferencex-constants';
25

36
import type { EvalChangelogEntry, EvaluationChartData } from '@/components/evaluation/types';
47
import type { EvalRow } from '@/lib/api';
@@ -54,10 +57,13 @@ function buildConfigLabel(
5457
conc: number | null,
5558
params: EvalLabelParams,
5659
showPrecision: boolean,
60+
model?: string,
5761
): string {
5862
const headerSuffixes: string[] = [];
5963
if (framework && framework !== '1k8k') headerSuffixes.push(getFrameworkLabel(framework));
60-
if (specMethod && specMethod !== 'none') headerSuffixes.push(getFrameworkLabel(specMethod));
64+
// M3's `mtp` spec token renders as "EAGLE"; every other model keeps "MTP".
65+
if (specMethod && specMethod !== 'none')
66+
headerSuffixes.push(resolveFrameworkPartLabel(model, specMethod));
6167

6268
const detailSuffixes: string[] = [];
6369
if (precision && showPrecision) detailSuffixes.push(precision.toUpperCase());
@@ -128,7 +134,7 @@ export function buildEvaluationChartRows(
128134
return null;
129135
}
130136

131-
const hwConfig = getHardwareConfig(hwKey);
137+
const hwConfig = getHardwareConfig(hwKey, selectedModel);
132138
const hwLabel = hwConfig.label;
133139

134140
return {
@@ -154,6 +160,7 @@ export function buildEvaluationChartRows(
154160
prefillNw: item.prefill_num_workers,
155161
},
156162
showPrecision,
163+
selectedModel,
157164
),
158165
score,
159166
scoreError: item.metrics.em_strict_se ?? item.metrics.score_se ?? 0,
@@ -290,7 +297,7 @@ export function buildEvalChangelogEntries(
290297
})
291298
.map((item) => {
292299
const hwKey = normalizeEvalHardwareKey(item.hardware, item.framework, item.spec_method);
293-
const hwConfig = getHardwareConfig(hwKey);
300+
const hwConfig = getHardwareConfig(hwKey, selectedModel);
294301
const hwLabel = hwConfig.label;
295302
// Changelog labels historically omit TP/EP; keep that behavior while
296303
// still surfacing the disagg marker.
@@ -308,6 +315,7 @@ export function buildEvalChangelogEntries(
308315
decodeDpa: item.decode_dp_attention,
309316
},
310317
showPrecision,
318+
selectedModel,
311319
),
312320
};
313321
});

packages/app/src/components/inference/InferenceContext.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,9 @@ export function InferenceProvider({
311311
.toSorted((a, b) => getModelSortIndex(a) - getModelSortIndex(b) || a.localeCompare(b))
312312
.map((hw) => ({
313313
value: hw,
314-
label: getDisplayLabel(getHardwareConfig(hw)),
314+
label: getDisplayLabel(getHardwareConfig(hw, selectedModel)),
315315
}));
316-
}, [availabilityRows, dbModelKeys, effectiveSequence, effectivePrecisions]);
316+
}, [availabilityRows, dbModelKeys, effectiveSequence, effectivePrecisions, selectedModel]);
317317

318318
// --- Tracked config functions ---
319319
const buildTrackedConfigId = useCallback((point: InferenceData): string => {

packages/app/src/components/inference/ui/ChartDisplay.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,10 @@ export default function ChartDisplay() {
466466
...visibleData,
467467
...visibleOverlayRows,
468468
]).map((issue) =>
469-
knownIssueCsvNote(issue, getDisplayLabel(getHardwareConfig(issue.hwKey))),
469+
knownIssueCsvNote(
470+
issue,
471+
getDisplayLabel(getHardwareConfig(issue.hwKey, graph.model)),
472+
),
470473
);
471474
exportToCsv(
472475
`InferenceX_${selectedModel}_${graph.chartDefinition.chartType}`,

packages/app/src/components/inference/ui/ComparisonChangelog.tsx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import { useMemo, useState } from 'react';
66
import { track } from '@/lib/analytics';
77
import { ExternalLinkIcon } from '@/components/ui/external-link-icon';
88

9+
import { DB_MODEL_TO_DISPLAY } from '@semianalysisai/inferencex-constants';
10+
911
import type { ComparisonChangelog as ComparisonChangelogType } from '@/hooks/api/use-comparison-changelogs';
1012
import {
1113
configKeyMatchesHwKey,
@@ -208,14 +210,18 @@ export default function ComparisonChangelog({
208210
track('inference_comparison_changelog_toggled', { expanded: newState });
209211
};
210212

213+
// All modelDbKeys for a comparison map to one display model, so [0] suffices
214+
// for per-model suffix overrides (e.g. M3 MTP → EAGLE).
215+
const displayModel = DB_MODEL_TO_DISPLAY[modelDbKeys[0]] ?? modelDbKeys[0];
216+
211217
/** Display labels of the selected GPUs that a set of changelog entries touches. */
212218
const gpuLabelsFor = (entries: { config_keys: string[] }[]): string => {
213219
if (selectedGPUs.length <= 1) return '';
214220
return selectedGPUs
215221
.filter((gpu) =>
216222
entries.some((e) => e.config_keys.some((k) => configKeyMatchesHwKey(k, gpu))),
217223
)
218-
.map((gpu) => getDisplayLabel(getHardwareConfig(gpu)))
224+
.map((gpu) => getDisplayLabel(getHardwareConfig(gpu, displayModel)))
219225
.join(', ');
220226
};
221227

packages/app/src/components/inference/ui/GPUGraph.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ const CHART_MARGIN = { top: 24, right: 10, bottom: 60, left: 60 };
6363
// lookup misses (legacy data).
6464
function labelTextFor(pts: InferenceData[], numbering: Map<string, number>): string {
6565
const hwKey = String(pts[0].hwKey);
66-
const cfg = getHardwareConfig(hwKey);
66+
const cfg = getHardwareConfig(hwKey, pts[0].model);
6767
const hwLabel = cfg ? getDisplayLabel(cfg) : hwKey;
6868
return `${hwLabel}${comparisonEntryLabel(String(pts[0].date), numbering)}`;
6969
}
@@ -266,7 +266,7 @@ const GPUGraph = React.memo(
266266
const knownIssueAnnotations = useMemo(
267267
(): KnownIssueAnnotation[] =>
268268
matchKnownConfigIssues(modelLabel, filteredData).map((issue) => {
269-
const cfg = getHardwareConfig(issue.hwKey);
269+
const cfg = getHardwareConfig(issue.hwKey, modelLabel);
270270
const colorEntry = allGraphs.find(
271271
(entry) => entry.hwKey === issue.hwKey && activeDates.has(entry.id),
272272
);
@@ -835,7 +835,7 @@ const GPUGraph = React.memo(
835835
hw: id,
836836
label: comparisonEntryLabel(date, runNumbering),
837837
color,
838-
title: getDisplayLabel(getHardwareConfig(hwKey)),
838+
title: getDisplayLabel(getHardwareConfig(hwKey, modelLabel)),
839839
isActive: activeDates.has(id),
840840
onClick: () => {
841841
toggleActiveDate(id);

packages/app/src/components/inference/ui/InferenceTable.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ export default function InferenceTable({
5151
() => [
5252
{
5353
header: 'GPU',
54-
cell: (row) => getDisplayLabel(getHardwareConfig(row.hwKey)),
55-
sortValue: (row) => getDisplayLabel(getHardwareConfig(row.hwKey)),
54+
cell: (row) => getDisplayLabel(getHardwareConfig(row.hwKey, row.model)),
55+
sortValue: (row) => getDisplayLabel(getHardwareConfig(row.hwKey, row.model)),
5656
className: 'font-medium whitespace-nowrap',
5757
},
5858
{

packages/app/src/components/inference/ui/ScatterGraph.tsx

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,23 @@ const formatChangelogDescription = (desc: string | string[]): React.JSX.Element
102102

103103
const CHART_MARGIN = { top: 24, right: 10, bottom: 60, left: 60 };
104104

105-
// Derive a readable label from a hwKey using the HARDWARE_CONFIG source of truth
106-
const parseHwKeyToLabel = (hwKey: string): { name: string; label: string } => {
107-
const config = getHardwareConfig(hwKey);
105+
// Derive a readable label from a hwKey using the HARDWARE_CONFIG source of truth.
106+
// `model` (display name) enables per-model suffix overrides (e.g. M3 MTP → EAGLE).
107+
const parseHwKeyToLabel = (hwKey: string, model?: string): { name: string; label: string } => {
108+
const config = getHardwareConfig(hwKey, model);
108109
return { name: config.label, label: getDisplayLabel(config) };
109110
};
110111

111112
// Line-label text for a curve. When more than one precision is shown, each curve
112113
// is its own line, so append the precision (e.g. "B200 (vLLM) FP8") to keep the
113114
// FP4 and FP8 curves of the same hardware distinguishable.
114-
const lineLabelText = (hwKey: string, precision: string, includePrecision: boolean): string => {
115-
const base = parseHwKeyToLabel(hwKey).label;
115+
const lineLabelText = (
116+
hwKey: string,
117+
precision: string,
118+
includePrecision: boolean,
119+
model?: string,
120+
): string => {
121+
const base = parseHwKeyToLabel(hwKey, model).label;
116122
return includePrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base;
117123
};
118124

@@ -366,7 +372,7 @@ const ScatterGraph = React.memo(
366372
const visiblePoints = [...filteredData, ...visibleOverlayPoints];
367373
return matchKnownConfigIssues(modelLabel, visiblePoints).map((issue) => ({
368374
issue,
369-
label: parseHwKeyToLabel(issue.hwKey).label,
375+
label: parseHwKeyToLabel(issue.hwKey, modelLabel).label,
370376
color: getCssColor(resolveColor(issue.hwKey)),
371377
points: visiblePoints
372378
.filter((p) => pointMatchesIssue(issue, p))
@@ -1065,7 +1071,7 @@ const ScatterGraph = React.memo(
10651071
placeLabel(
10661072
entry.key,
10671073
entry.hw,
1068-
lineLabelText(entry.hw, entry.precision, multiPrecision),
1074+
lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel),
10691075
getCssColor(resolveColor(entry.hw)),
10701076
entry.points,
10711077
);
@@ -1079,7 +1085,7 @@ const ScatterGraph = React.memo(
10791085
lineLabels.push({
10801086
key: entry.key,
10811087
hw: entry.hw,
1082-
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
1088+
label: lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel),
10831089
color: getCssColor(resolveColor(entry.hw)),
10841090
x: xScale(entry.points[0].x),
10851091
y: yScale(entry.points[0].y),
@@ -1101,7 +1107,7 @@ const ScatterGraph = React.memo(
11011107
const info = unofficialRunInfos[runIndex];
11021108
const base = info
11031109
? `✕ ${info.branch || `run ${info.id}`}`
1104-
: parseHwKeyToLabel(hwKey).label;
1110+
: parseHwKeyToLabel(hwKey, modelLabel).label;
11051111
return multiPrecision
11061112
? `${base} ${getPrecisionLabel(precision as Precision)}`
11071113
: base;
@@ -1144,7 +1150,7 @@ const ScatterGraph = React.memo(
11441150
lineLabels.push({
11451151
key: entry.key,
11461152
hw: entry.hw,
1147-
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
1153+
label: lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel),
11481154
color: getCssColor(resolveColor(entry.hw)),
11491155
x: xScale(pt.x),
11501156
y: yScale(pt.y),
@@ -1158,7 +1164,7 @@ const ScatterGraph = React.memo(
11581164
const info = unofficialRunInfos[group.runIndex];
11591165
const branchOrHw = info
11601166
? `✕ ${info.branch || `run ${info.id}`}`
1161-
: parseHwKeyToLabel(group.hwKey).label;
1167+
: parseHwKeyToLabel(group.hwKey, modelLabel).label;
11621168
const labelText = multiPrecision
11631169
? `${branchOrHw} ${getPrecisionLabel((group.points[0]?.precision ?? '') as Precision)}`
11641170
: branchOrHw;

0 commit comments

Comments
 (0)