diff --git a/packages/app/cypress/component/scatter-graph.cy.tsx b/packages/app/cypress/component/scatter-graph.cy.tsx index 91f3d14c..0d935a3b 100644 --- a/packages/app/cypress/component/scatter-graph.cy.tsx +++ b/packages/app/cypress/component/scatter-graph.cy.tsx @@ -246,4 +246,89 @@ describe('ScatterGraph', () => { .find('text') .should('contain.text', 'feature-branch'); }); + + it('renders M3 mtp rooflines with the EAGLE label (official + overlay)', () => { + const interactivityChartDef = createMockChartDefinition({ + chartType: 'interactivity', + y_tpPerGpu_roofline: 'upper_left', + }); + const officialData = [ + createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 8, y: 240, precision: Precision.FP4 }), + createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 16, y: 200, precision: Precision.FP4 }), + createMockInferenceData({ hwKey: 'h100_vllm_mtp', x: 32, y: 150, precision: Precision.FP4 }), + ]; + // Overlay roofline with no run metadata, so its line label falls back to the + // hw label — exercising the overlay path's model-aware suffix resolution. + const runUrl = 'https://github.com/x/y/actions/runs/999'; + const overlayData = { + data: [ + createMockInferenceData({ + hwKey: 'b200_vllm_mtp', + x: 8, + y: 320, + precision: Precision.FP4, + run_url: runUrl, + }), + createMockInferenceData({ + hwKey: 'b200_vllm_mtp', + x: 16, + y: 280, + precision: Precision.FP4, + run_url: runUrl, + }), + createMockInferenceData({ + hwKey: 'b200_vllm_mtp', + x: 32, + y: 220, + precision: Precision.FP4, + run_url: runUrl, + }), + ], + hardwareConfig: hwConfig, + label: '', + runUrl, + }; + + mountWithProviders( +
+ +
, + { + inference: { + hardwareConfig: hwConfig, + activeHwTypes: new Set(['h100_vllm_mtp']), + hwTypesWithData: new Set(['h100_vllm_mtp']), + selectedPrecisions: [Precision.FP4], + showLineLabels: true, + }, + unofficial: { + activeOverlayHwTypes: new Set(['b200_vllm_mtp']), + allOverlayHwTypes: new Set(['b200_vllm_mtp']), + runIndexByUrl: { [runUrl]: 0, '999': 0 }, + // Intentionally empty so the overlay label falls back to the hw label. + unofficialRunInfos: [], + }, + }, + ); + + // Official roofline label reads "EAGLE", not the generic "MTP". + cy.get('#test-scatter-m3-eagle svg .line-label') + .filter('[data-line-key]:not([data-line-key^="overlay-"])') + .find('text') + .should('contain.text', 'EAGLE'); + // Overlay roofline (no run metadata → hw-label fallback) also reads "EAGLE". + cy.get('#test-scatter-m3-eagle svg .line-label[data-line-key^="overlay-"]') + .find('text') + .should('contain.text', 'EAGLE'); + // No label should show the generic MTP token for M3. + cy.get('#test-scatter-m3-eagle svg .line-label text').should('not.contain.text', 'MTP'); + }); }); diff --git a/packages/app/cypress/component/submissions-table.cy.tsx b/packages/app/cypress/component/submissions-table.cy.tsx index 950b1041..6aa0f0c5 100644 --- a/packages/app/cypress/component/submissions-table.cy.tsx +++ b/packages/app/cypress/component/submissions-table.cy.tsx @@ -32,11 +32,11 @@ describe('SubmissionsTable — Spec Method column', () => { cy.contains('th', 'Spec Method').should('be.visible'); }); - it('renders spec_method values uppercased and shows an em-dash for "none"', () => { + it('renders canonical spec_method labels and shows an em-dash for "none"', () => { cy.mount(); - // CSS uppercases the value; the DOM text remains lowercase. - cy.contains('td', 'mtp').should('be.visible').and('have.class', 'uppercase'); - cy.contains('td', 'eagle').should('be.visible').and('have.class', 'uppercase'); + // The cell renders the canonical spec-method label (MTP/EAGLE) for the model. + cy.contains('td', 'MTP').should('be.visible').and('have.class', 'uppercase'); + cy.contains('td', 'EAGLE').should('be.visible').and('have.class', 'uppercase'); // The "none" row renders an em-dash placeholder instead of literal "none". // Hardware text is rendered uppercase via .toUpperCase(). cy.contains('tbody tr', 'MI355X').within(() => { @@ -44,6 +44,24 @@ describe('SubmissionsTable — Spec Method column', () => { }); }); + it('renders M3 mtp as EAGLE, not MTP', () => { + cy.mount( + , + ); + cy.contains('td', 'EAGLE').should('be.visible'); + cy.contains('td', 'MTP').should('not.exist'); + }); + it('sorts by spec_method when the header is clicked', () => { cy.mount(); // Desc alphabetical: 'none' (mi355x) → 'mtp' (h200) → 'eagle' (b300). diff --git a/packages/app/src/components/calculator/useThroughputData.ts b/packages/app/src/components/calculator/useThroughputData.ts index 5c324178..b1cd62f5 100644 --- a/packages/app/src/components/calculator/useThroughputData.ts +++ b/packages/app/src/components/calculator/useThroughputData.ts @@ -78,7 +78,7 @@ export function useThroughputData( const entry = rowToAggDataEntry(row); const hwKey = getHardwareKey(entry); - const hwConfig = getHardwareConfig(hwKey); + const hwConfig = getHardwareConfig(hwKey, entry.model); if (!hwConfig) continue; if (!hwConfigMap[hwKey]) hwConfigMap[hwKey] = { ...hwConfig, name: hwKey }; diff --git a/packages/app/src/components/evaluation/chart-data.ts b/packages/app/src/components/evaluation/chart-data.ts index 2fa18473..ff3bfe45 100644 --- a/packages/app/src/components/evaluation/chart-data.ts +++ b/packages/app/src/components/evaluation/chart-data.ts @@ -1,4 +1,7 @@ -import { DISPLAY_MODEL_TO_DB } from '@semianalysisai/inferencex-constants'; +import { + DISPLAY_MODEL_TO_DB, + resolveFrameworkPartLabel, +} from '@semianalysisai/inferencex-constants'; import type { EvalChangelogEntry, EvaluationChartData } from '@/components/evaluation/types'; import type { EvalRow } from '@/lib/api'; @@ -54,10 +57,13 @@ function buildConfigLabel( conc: number | null, params: EvalLabelParams, showPrecision: boolean, + model?: string, ): string { const headerSuffixes: string[] = []; if (framework && framework !== '1k8k') headerSuffixes.push(getFrameworkLabel(framework)); - if (specMethod && specMethod !== 'none') headerSuffixes.push(getFrameworkLabel(specMethod)); + // M3's `mtp` spec token renders as "EAGLE"; every other model keeps "MTP". + if (specMethod && specMethod !== 'none') + headerSuffixes.push(resolveFrameworkPartLabel(model, specMethod)); const detailSuffixes: string[] = []; if (precision && showPrecision) detailSuffixes.push(precision.toUpperCase()); @@ -128,7 +134,7 @@ export function buildEvaluationChartRows( return null; } - const hwConfig = getHardwareConfig(hwKey); + const hwConfig = getHardwareConfig(hwKey, selectedModel); const hwLabel = hwConfig.label; return { @@ -154,6 +160,7 @@ export function buildEvaluationChartRows( prefillNw: item.prefill_num_workers, }, showPrecision, + selectedModel, ), score, scoreError: item.metrics.em_strict_se ?? item.metrics.score_se ?? 0, @@ -290,7 +297,7 @@ export function buildEvalChangelogEntries( }) .map((item) => { const hwKey = normalizeEvalHardwareKey(item.hardware, item.framework, item.spec_method); - const hwConfig = getHardwareConfig(hwKey); + const hwConfig = getHardwareConfig(hwKey, selectedModel); const hwLabel = hwConfig.label; // Changelog labels historically omit TP/EP; keep that behavior while // still surfacing the disagg marker. @@ -308,6 +315,7 @@ export function buildEvalChangelogEntries( decodeDpa: item.decode_dp_attention, }, showPrecision, + selectedModel, ), }; }); diff --git a/packages/app/src/components/inference/InferenceContext.tsx b/packages/app/src/components/inference/InferenceContext.tsx index f8e9f647..2b42650f 100644 --- a/packages/app/src/components/inference/InferenceContext.tsx +++ b/packages/app/src/components/inference/InferenceContext.tsx @@ -311,9 +311,9 @@ export function InferenceProvider({ .toSorted((a, b) => getModelSortIndex(a) - getModelSortIndex(b) || a.localeCompare(b)) .map((hw) => ({ value: hw, - label: getDisplayLabel(getHardwareConfig(hw)), + label: getDisplayLabel(getHardwareConfig(hw, selectedModel)), })); - }, [availabilityRows, dbModelKeys, effectiveSequence, effectivePrecisions]); + }, [availabilityRows, dbModelKeys, effectiveSequence, effectivePrecisions, selectedModel]); // --- Tracked config functions --- const buildTrackedConfigId = useCallback((point: InferenceData): string => { diff --git a/packages/app/src/components/inference/ui/ChartDisplay.tsx b/packages/app/src/components/inference/ui/ChartDisplay.tsx index 3c44a433..882b6f93 100644 --- a/packages/app/src/components/inference/ui/ChartDisplay.tsx +++ b/packages/app/src/components/inference/ui/ChartDisplay.tsx @@ -466,7 +466,10 @@ export default function ChartDisplay() { ...visibleData, ...visibleOverlayRows, ]).map((issue) => - knownIssueCsvNote(issue, getDisplayLabel(getHardwareConfig(issue.hwKey))), + knownIssueCsvNote( + issue, + getDisplayLabel(getHardwareConfig(issue.hwKey, graph.model)), + ), ); exportToCsv( `InferenceX_${selectedModel}_${graph.chartDefinition.chartType}`, diff --git a/packages/app/src/components/inference/ui/ComparisonChangelog.tsx b/packages/app/src/components/inference/ui/ComparisonChangelog.tsx index d4386d98..82c8fb86 100644 --- a/packages/app/src/components/inference/ui/ComparisonChangelog.tsx +++ b/packages/app/src/components/inference/ui/ComparisonChangelog.tsx @@ -6,6 +6,8 @@ import { useMemo, useState } from 'react'; import { track } from '@/lib/analytics'; import { ExternalLinkIcon } from '@/components/ui/external-link-icon'; +import { DB_MODEL_TO_DISPLAY } from '@semianalysisai/inferencex-constants'; + import type { ComparisonChangelog as ComparisonChangelogType } from '@/hooks/api/use-comparison-changelogs'; import { configKeyMatchesHwKey, @@ -208,6 +210,10 @@ export default function ComparisonChangelog({ track('inference_comparison_changelog_toggled', { expanded: newState }); }; + // All modelDbKeys for a comparison map to one display model, so [0] suffices + // for per-model suffix overrides (e.g. M3 MTP → EAGLE). + const displayModel = DB_MODEL_TO_DISPLAY[modelDbKeys[0]] ?? modelDbKeys[0]; + /** Display labels of the selected GPUs that a set of changelog entries touches. */ const gpuLabelsFor = (entries: { config_keys: string[] }[]): string => { if (selectedGPUs.length <= 1) return ''; @@ -215,7 +221,7 @@ export default function ComparisonChangelog({ .filter((gpu) => entries.some((e) => e.config_keys.some((k) => configKeyMatchesHwKey(k, gpu))), ) - .map((gpu) => getDisplayLabel(getHardwareConfig(gpu))) + .map((gpu) => getDisplayLabel(getHardwareConfig(gpu, displayModel))) .join(', '); }; diff --git a/packages/app/src/components/inference/ui/GPUGraph.tsx b/packages/app/src/components/inference/ui/GPUGraph.tsx index b61bc1f0..3a3f2c86 100644 --- a/packages/app/src/components/inference/ui/GPUGraph.tsx +++ b/packages/app/src/components/inference/ui/GPUGraph.tsx @@ -63,7 +63,7 @@ const CHART_MARGIN = { top: 24, right: 10, bottom: 60, left: 60 }; // lookup misses (legacy data). function labelTextFor(pts: InferenceData[], numbering: Map): string { const hwKey = String(pts[0].hwKey); - const cfg = getHardwareConfig(hwKey); + const cfg = getHardwareConfig(hwKey, pts[0].model); const hwLabel = cfg ? getDisplayLabel(cfg) : hwKey; return `${hwLabel} • ${comparisonEntryLabel(String(pts[0].date), numbering)}`; } @@ -266,7 +266,7 @@ const GPUGraph = React.memo( const knownIssueAnnotations = useMemo( (): KnownIssueAnnotation[] => matchKnownConfigIssues(modelLabel, filteredData).map((issue) => { - const cfg = getHardwareConfig(issue.hwKey); + const cfg = getHardwareConfig(issue.hwKey, modelLabel); const colorEntry = allGraphs.find( (entry) => entry.hwKey === issue.hwKey && activeDates.has(entry.id), ); @@ -835,7 +835,7 @@ const GPUGraph = React.memo( hw: id, label: comparisonEntryLabel(date, runNumbering), color, - title: getDisplayLabel(getHardwareConfig(hwKey)), + title: getDisplayLabel(getHardwareConfig(hwKey, modelLabel)), isActive: activeDates.has(id), onClick: () => { toggleActiveDate(id); diff --git a/packages/app/src/components/inference/ui/InferenceTable.tsx b/packages/app/src/components/inference/ui/InferenceTable.tsx index c300e60d..5bdb713c 100644 --- a/packages/app/src/components/inference/ui/InferenceTable.tsx +++ b/packages/app/src/components/inference/ui/InferenceTable.tsx @@ -51,8 +51,8 @@ export default function InferenceTable({ () => [ { header: 'GPU', - cell: (row) => getDisplayLabel(getHardwareConfig(row.hwKey)), - sortValue: (row) => getDisplayLabel(getHardwareConfig(row.hwKey)), + cell: (row) => getDisplayLabel(getHardwareConfig(row.hwKey, row.model)), + sortValue: (row) => getDisplayLabel(getHardwareConfig(row.hwKey, row.model)), className: 'font-medium whitespace-nowrap', }, { diff --git a/packages/app/src/components/inference/ui/ScatterGraph.tsx b/packages/app/src/components/inference/ui/ScatterGraph.tsx index fec5be80..68e9cc14 100644 --- a/packages/app/src/components/inference/ui/ScatterGraph.tsx +++ b/packages/app/src/components/inference/ui/ScatterGraph.tsx @@ -102,17 +102,23 @@ const formatChangelogDescription = (desc: string | string[]): React.JSX.Element const CHART_MARGIN = { top: 24, right: 10, bottom: 60, left: 60 }; -// Derive a readable label from a hwKey using the HARDWARE_CONFIG source of truth -const parseHwKeyToLabel = (hwKey: string): { name: string; label: string } => { - const config = getHardwareConfig(hwKey); +// Derive a readable label from a hwKey using the HARDWARE_CONFIG source of truth. +// `model` (display name) enables per-model suffix overrides (e.g. M3 MTP → EAGLE). +const parseHwKeyToLabel = (hwKey: string, model?: string): { name: string; label: string } => { + const config = getHardwareConfig(hwKey, model); return { name: config.label, label: getDisplayLabel(config) }; }; // Line-label text for a curve. When more than one precision is shown, each curve // is its own line, so append the precision (e.g. "B200 (vLLM) FP8") to keep the // FP4 and FP8 curves of the same hardware distinguishable. -const lineLabelText = (hwKey: string, precision: string, includePrecision: boolean): string => { - const base = parseHwKeyToLabel(hwKey).label; +const lineLabelText = ( + hwKey: string, + precision: string, + includePrecision: boolean, + model?: string, +): string => { + const base = parseHwKeyToLabel(hwKey, model).label; return includePrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base; }; @@ -366,7 +372,7 @@ const ScatterGraph = React.memo( const visiblePoints = [...filteredData, ...visibleOverlayPoints]; return matchKnownConfigIssues(modelLabel, visiblePoints).map((issue) => ({ issue, - label: parseHwKeyToLabel(issue.hwKey).label, + label: parseHwKeyToLabel(issue.hwKey, modelLabel).label, color: getCssColor(resolveColor(issue.hwKey)), points: visiblePoints .filter((p) => pointMatchesIssue(issue, p)) @@ -1065,7 +1071,7 @@ const ScatterGraph = React.memo( placeLabel( entry.key, entry.hw, - lineLabelText(entry.hw, entry.precision, multiPrecision), + lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel), getCssColor(resolveColor(entry.hw)), entry.points, ); @@ -1079,7 +1085,7 @@ const ScatterGraph = React.memo( lineLabels.push({ key: entry.key, hw: entry.hw, - label: lineLabelText(entry.hw, entry.precision, multiPrecision), + label: lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel), color: getCssColor(resolveColor(entry.hw)), x: xScale(entry.points[0].x), y: yScale(entry.points[0].y), @@ -1101,7 +1107,7 @@ const ScatterGraph = React.memo( const info = unofficialRunInfos[runIndex]; const base = info ? `✕ ${info.branch || `run ${info.id}`}` - : parseHwKeyToLabel(hwKey).label; + : parseHwKeyToLabel(hwKey, modelLabel).label; return multiPrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base; @@ -1144,7 +1150,7 @@ const ScatterGraph = React.memo( lineLabels.push({ key: entry.key, hw: entry.hw, - label: lineLabelText(entry.hw, entry.precision, multiPrecision), + label: lineLabelText(entry.hw, entry.precision, multiPrecision, modelLabel), color: getCssColor(resolveColor(entry.hw)), x: xScale(pt.x), y: yScale(pt.y), @@ -1158,7 +1164,7 @@ const ScatterGraph = React.memo( const info = unofficialRunInfos[group.runIndex]; const branchOrHw = info ? `✕ ${info.branch || `run ${info.id}`}` - : parseHwKeyToLabel(group.hwKey).label; + : parseHwKeyToLabel(group.hwKey, modelLabel).label; const labelText = multiPrecision ? `${branchOrHw} ${getPrecisionLabel((group.points[0]?.precision ?? '') as Precision)}` : branchOrHw; diff --git a/packages/app/src/components/inference/ui/UnofficialChartDisplay.tsx b/packages/app/src/components/inference/ui/UnofficialChartDisplay.tsx index f9b1b3c8..799854d7 100644 --- a/packages/app/src/components/inference/ui/UnofficialChartDisplay.tsx +++ b/packages/app/src/components/inference/ui/UnofficialChartDisplay.tsx @@ -106,7 +106,7 @@ export function UnofficialChartDisplay() { hardwareConfig: Object.fromEntries( Object.entries(dataForChart.gpus || {}).map(([k, v]) => [ k, - { ...getHardwareConfig(k), ...v }, + { ...getHardwareConfig(k, selectedModel), ...v }, ]), ), }; diff --git a/packages/app/src/components/inference/utils/changelogFormatters.test.ts b/packages/app/src/components/inference/utils/changelogFormatters.test.ts index 79910096..905e8c89 100644 --- a/packages/app/src/components/inference/utils/changelogFormatters.test.ts +++ b/packages/app/src/components/inference/utils/changelogFormatters.test.ts @@ -17,6 +17,13 @@ describe('formatConfigKeys', () => { expect(result).toContain('FP8'); }); + it('renders M3 mtp as EAGLE (not MTP)', () => { + const result = formatConfigKeys('minimaxm3-fp8-h100-vllm-mtp'); + expect(result).toContain('H100'); + expect(result).toContain('EAGLE'); + expect(result).not.toContain('MTP'); + }); + it('formats compound framework names', () => { const result = formatConfigKeys('gptoss-fp4-b200-dynamo-sglang'); expect(result).toContain('B200'); diff --git a/packages/app/src/components/inference/utils/changelogFormatters.tsx b/packages/app/src/components/inference/utils/changelogFormatters.tsx index 7574f252..aae20a0b 100644 --- a/packages/app/src/components/inference/utils/changelogFormatters.tsx +++ b/packages/app/src/components/inference/utils/changelogFormatters.tsx @@ -1,4 +1,7 @@ -import { resolveFrameworkAliasesInString } from '@semianalysisai/inferencex-constants'; +import { + resolveFrameworkAliasesInString, + resolveFrameworkPartLabel, +} from '@semianalysisai/inferencex-constants'; import { type Precision, MODEL_PREFIX_MAPPING, getPrecisionLabel } from '@/lib/data-mappings'; import { getFrameworkLabel } from '@/lib/utils'; @@ -45,6 +48,8 @@ export function formatConfigKeys(key: string) { const isMtp = framework.endsWith('-mtp'); const baseFramework = isMtp ? framework.slice(0, -4) : framework; const baseLabel = getFrameworkLabel(baseFramework); - const frameworkLabel = isMtp ? `${baseLabel}, MTP` : baseLabel; + // M3's `mtp` spec token renders as "EAGLE"; every other model keeps "MTP". + const mtpLabel = resolveFrameworkPartLabel(MODEL_PREFIX_MAPPING[model], 'mtp'); + const frameworkLabel = isMtp ? `${baseLabel}, ${mtpLabel}` : baseLabel; return `${gpu.toUpperCase()} (${frameworkLabel}) ${MODEL_PREFIX_MAPPING[model]} ${getPrecisionLabel(precision as Precision)}`; } diff --git a/packages/app/src/components/submissions/SubmissionsTable.tsx b/packages/app/src/components/submissions/SubmissionsTable.tsx index d7845e30..5dd56f47 100644 --- a/packages/app/src/components/submissions/SubmissionsTable.tsx +++ b/packages/app/src/components/submissions/SubmissionsTable.tsx @@ -3,6 +3,11 @@ import { ChevronDown, ChevronRight, GitCompare, Info } from 'lucide-react'; import { useCallback, useEffect, useMemo, useState } from 'react'; +import { + DB_MODEL_TO_DISPLAY, + resolveFrameworkPartLabel, +} from '@semianalysisai/inferencex-constants'; + import { track } from '@/lib/analytics'; import { MODEL_PREFIX_MAPPING, getModelLabel } from '@/lib/data-mappings'; import type { SubmissionSummaryRow } from '@/lib/submissions-types'; @@ -302,7 +307,7 @@ function SubmissionRow({ {row.precision} {row.spec_method && row.spec_method !== 'none' ? ( - row.spec_method + resolveFrameworkPartLabel(DB_MODEL_TO_DISPLAY[row.model], row.spec_method) ) : ( )} @@ -368,7 +373,9 @@ function SubmissionRow({ label="Spec Method:" tip="Speculative decoding method (e.g. MTP, Eagle)" > - {row.spec_method || 'none'} + {row.spec_method && row.spec_method !== 'none' + ? resolveFrameworkPartLabel(DB_MODEL_TO_DISPLAY[row.model], row.spec_method) + : 'none'} { if (v <= 0 || !isFinite(mins[i]) || maxs[i] === mins[i]) return null; const norm = (v - mins[i]) / (maxs[i] - mins[i]); diff --git a/packages/app/src/lib/benchmark-transform.test.ts b/packages/app/src/lib/benchmark-transform.test.ts index b49fae39..8f27cc8f 100644 --- a/packages/app/src/lib/benchmark-transform.test.ts +++ b/packages/app/src/lib/benchmark-transform.test.ts @@ -302,6 +302,26 @@ describe('transformBenchmarkRows', () => { const hwKeys = Object.keys(hardwareConfig); expect(hwKeys.length).toBeGreaterThanOrEqual(2); }); + + it('labels M3 mtp configs with the "M3 EAGLE" suffix', () => { + const rows = [ + makeRow({ model: 'minimaxm3', hardware: 'h100', framework: 'vllm', spec_method: 'mtp' }), + ]; + const { hardwareConfig } = transformBenchmarkRows(rows); + const entry = hardwareConfig['h100_vllm_mtp']; + expect(entry).toBeDefined(); + expect(entry.suffix).toBe('(vLLM, M3 EAGLE)'); + }); + + it('keeps the generic MTP suffix for non-M3 mtp configs', () => { + const rows = [ + makeRow({ model: 'dsr1', hardware: 'h200', framework: 'sglang', spec_method: 'mtp' }), + ]; + const { hardwareConfig } = transformBenchmarkRows(rows); + const entry = hardwareConfig['h200_sglang_mtp']; + expect(entry).toBeDefined(); + expect(entry.suffix).toBe('(SGLang, MTP)'); + }); }); // --------------------------------------------------------------------------- diff --git a/packages/app/src/lib/benchmark-transform.ts b/packages/app/src/lib/benchmark-transform.ts index 87b48558..ac806b79 100644 --- a/packages/app/src/lib/benchmark-transform.ts +++ b/packages/app/src/lib/benchmark-transform.ts @@ -122,7 +122,7 @@ export function transformBenchmarkRows(rows: BenchmarkRow[]): { entry.hwKey = hwKey; if (!hwConfigCache.has(hwKey)) { - const hwConfig = getHardwareConfig(hwKey); + const hwConfig = getHardwareConfig(hwKey, entry.model); hwConfigCache.set(hwKey, hwConfig); if (hwConfig) gpuConfig[hwKey] = { ...hwConfig, name: hwKey }; } diff --git a/packages/app/src/lib/constants.test.ts b/packages/app/src/lib/constants.test.ts index 24fb371c..dd849319 100644 --- a/packages/app/src/lib/constants.test.ts +++ b/packages/app/src/lib/constants.test.ts @@ -154,6 +154,27 @@ describe('getHardwareConfig', () => { warnSpy.mockRestore(); }); + it('renders M3 mtp configs with the "M3 EAGLE" suffix when model is passed', () => { + const config = getHardwareConfig('h100_vllm_mtp', 'MiniMax-M3'); + expect(config.suffix).toBe('(vLLM, M3 EAGLE)'); + expect(config.gpu).toContain('M3 EAGLE'); + }); + + it('keeps the generic MTP suffix for non-M3 models', () => { + expect(getHardwareConfig('h100_sglang_mtp', 'DeepSeek-R1-0528').suffix).toBe('(SGLang, MTP)'); + }); + + it('keeps the generic MTP suffix when no model is passed (backward compatible)', () => { + expect(getHardwareConfig('h100_vllm_mtp').suffix).toBe('(vLLM, MTP)'); + }); + + it('does not let the model-scoped cache leak into the model-less lookup', () => { + // Prime the cache with the M3 (EAGLE) variant first, then ensure the + // model-less call still returns the generic label. + getHardwareConfig('b200_vllm_mtp', 'MiniMax-M3'); + expect(getHardwareConfig('b200_vllm_mtp').suffix).toBe('(vLLM, MTP)'); + }); + it('HW_REGISTRY has non-zero power for all entries', () => { for (const entry of Object.values(HW_REGISTRY)) { expect(entry.power).toBeGreaterThan(0); diff --git a/packages/app/src/lib/constants.ts b/packages/app/src/lib/constants.ts index a3987de3..a720077d 100644 --- a/packages/app/src/lib/constants.ts +++ b/packages/app/src/lib/constants.ts @@ -1,4 +1,4 @@ -import { FRAMEWORK_LABELS, HW_REGISTRY } from '@semianalysisai/inferencex-constants'; +import { HW_REGISTRY, resolveFrameworkPartLabel } from '@semianalysisai/inferencex-constants'; /** d3.schemeTableau10 — 10-color categorical palette for tracked configs. */ export const TABLEAU_10 = [ @@ -59,8 +59,11 @@ const UNKNOWN_HARDWARE: HardwareEntry = { /** * Build a hardware config entry from a key like "h100_dynamo-trt_mtp". * Derives all display fields from GPU_KEYS/GPU_VENDORS + FRAMEWORK_LABELS. + * + * `model` (frontend display name) enables per-model suffix overrides — e.g. M3's + * `mtp` token renders as "EAGLE" rather than the generic "MTP". */ -function buildHardwareEntry(hwKey: string): HardwareEntry | null { +function buildHardwareEntry(hwKey: string, model?: string): HardwareEntry | null { const base = hwKey.split('_')[0]; const reg = HW_REGISTRY[base]; if (!reg) return null; @@ -68,7 +71,7 @@ function buildHardwareEntry(hwKey: string): HardwareEntry | null { const parts = hwKey.split('_').slice(1); const label = reg.label; const gpuName = base.toUpperCase(); // always raw uppercase for gpu string - const partLabels = parts.map((p) => FRAMEWORK_LABELS[p] ?? p.toUpperCase()); + const partLabels = parts.map((p) => resolveFrameworkPartLabel(model, p)); return { name: hwKey.replaceAll('_', '-'), @@ -132,14 +135,18 @@ const hwCache = new Map(); /** * Get hardware config for a GPU key, building it dynamically from shared GPU constants + FRAMEWORK_LABELS. * Returns UNKNOWN_HARDWARE for unrecognized base GPUs. + * + * Pass `model` (frontend display name) to apply per-model suffix overrides + * (e.g. M3 `mtp` → "EAGLE"). Omitting it preserves the generic labels. */ -export function getHardwareConfig(hwKey: string): HardwareEntry { - const cached = hwCache.get(hwKey); +export function getHardwareConfig(hwKey: string, model?: string): HardwareEntry { + const cacheKey = model ? `${model}|${hwKey}` : hwKey; + const cached = hwCache.get(cacheKey); if (cached) return cached; - const entry = buildHardwareEntry(hwKey); + const entry = buildHardwareEntry(hwKey, model); if (entry) { - hwCache.set(hwKey, entry); + hwCache.set(cacheKey, entry); return entry; } diff --git a/packages/constants/src/framework-aliases.test.ts b/packages/constants/src/framework-aliases.test.ts index 5f819312..a28bea34 100644 --- a/packages/constants/src/framework-aliases.test.ts +++ b/packages/constants/src/framework-aliases.test.ts @@ -3,8 +3,10 @@ import { describe, expect, it } from 'vitest'; import { FRAMEWORK_ALIASES, FRAMEWORK_LABELS, + MODEL_SPEC_METHOD_LABELS, resolveFrameworkAlias, resolveFrameworkAliasesInString, + resolveFrameworkPartLabel, } from './framework-aliases'; describe('FRAMEWORK_LABELS', () => { @@ -17,6 +19,34 @@ describe('FRAMEWORK_LABELS', () => { }); }); +describe('MODEL_SPEC_METHOD_LABELS', () => { + it('maps MiniMax-M3 mtp to "M3 EAGLE"', () => { + expect(MODEL_SPEC_METHOD_LABELS['MiniMax-M3']?.mtp).toBe('M3 EAGLE'); + }); +}); + +describe('resolveFrameworkPartLabel', () => { + it('renders M3 mtp as "M3 EAGLE"', () => { + expect(resolveFrameworkPartLabel('MiniMax-M3', 'mtp')).toBe('M3 EAGLE'); + }); + + it('keeps the generic MTP label for other models', () => { + expect(resolveFrameworkPartLabel('DeepSeek-R1-0528', 'mtp')).toBe('MTP'); + }); + + it('keeps the generic MTP label when no model is provided', () => { + expect(resolveFrameworkPartLabel(undefined, 'mtp')).toBe('MTP'); + }); + + it('falls back to FRAMEWORK_LABELS for non-overridden parts even for M3', () => { + expect(resolveFrameworkPartLabel('MiniMax-M3', 'vllm')).toBe('vLLM'); + }); + + it('uppercases unknown tokens', () => { + expect(resolveFrameworkPartLabel('MiniMax-M3', 'foo')).toBe('FOO'); + }); +}); + describe('FRAMEWORK_ALIASES', () => { it('maps sglang-disagg to mori-sglang with disagg=true', () => { expect(FRAMEWORK_ALIASES['sglang-disagg']).toEqual({ canonical: 'mori-sglang', disagg: true }); diff --git a/packages/constants/src/framework-aliases.ts b/packages/constants/src/framework-aliases.ts index 6794069a..208fbd57 100644 --- a/packages/constants/src/framework-aliases.ts +++ b/packages/constants/src/framework-aliases.ts @@ -48,6 +48,31 @@ export const FRAMEWORK_LABELS: Record = { mtp: 'MTP', }; +/** + * Per-model display overrides for hwKey suffix parts (framework / spec_method + * tokens), keyed by frontend display model name → token → label. + * + * M3's speculative-decoding runs are ingested under the generic `mtp` token but + * actually use EAGLE, so for MiniMax-M3 the suffix reads "EAGLE" while every + * other model keeps the generic "MTP" label. + */ +export const MODEL_SPEC_METHOD_LABELS: Record> = { + 'MiniMax-M3': { mtp: 'M3 EAGLE' }, +}; + +/** + * Resolve a single hwKey suffix part (framework or spec_method token) to its + * display label, applying any per-model override before the generic + * FRAMEWORK_LABELS map. `model` is the frontend display model name. + */ +export function resolveFrameworkPartLabel(model: string | undefined, part: string): string { + return ( + (model ? MODEL_SPEC_METHOD_LABELS[model]?.[part] : undefined) ?? + FRAMEWORK_LABELS[part] ?? + part.toUpperCase() + ); +} + /** * Resolve a framework name to its canonical form. * Returns the input lowercased if no alias exists.