Skip to content

Commit 2dcc55e

Browse files
mishig25 and claude committed
[safetensors] Compute MoE active parameter count
Extend parseSafetensorsMetadata to return a `moe` breakdown for Mixture-of-Experts models, computed from tensor headers + config.json (already fetched by the parser for quantization config). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 20ff03d commit 2dcc55e

2 files changed

Lines changed: 179 additions & 2 deletions

File tree

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,63 @@ describe("parseSafetensorsMetadata", () => {
151151
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
152152
});
153153

154+
it("computes MoE active-params for Mixtral-style per-expert layout", async () => {
155+
const parse = await parseSafetensorsMetadata({
156+
repo: "mistralai/Mixtral-8x7B-v0.1",
157+
computeParametersCount: true,
158+
revision: "fc7ac94680e38d7348cfa806e51218e6273104b0",
159+
});
160+
161+
assert(parse.sharded);
162+
assert(parse.moe, "expected `moe` field on MoE repo");
163+
assert.strictEqual(parse.moe.numExperts, 8);
164+
assert.strictEqual(parse.moe.topK, 2);
165+
assert.strictEqual(parse.moe.hasSharedExpert, false);
166+
// Published: ~12.9B active on 46.7B total. Tolerate small bucket-rounding.
167+
assert.ok(Math.abs(parse.moe.active - 12.88e9) < 0.05e9, `active=${parse.moe.active}`);
168+
assert.ok(parse.moe.alwaysActive > 1e9 && parse.moe.alwaysActive < 2e9);
169+
});
170+
171+
it("computes MoE active-params for stacked-3D layout (Qwen3-30B-A3B)", async () => {
172+
const parse = await parseSafetensorsMetadata({
173+
repo: "Qwen/Qwen3-30B-A3B",
174+
computeParametersCount: true,
175+
revision: "ad44e777bcd18fa416d9da3bd8f70d33ebb85d39",
176+
});
177+
178+
assert(parse.sharded);
179+
assert(parse.moe, "expected `moe` field on MoE repo");
180+
assert.strictEqual(parse.moe.numExperts, 128);
181+
assert.strictEqual(parse.moe.topK, 8);
182+
// Published as "A3B" (~3B active); the header-derived count asserted below is ≈3.35B.
183+
assert.ok(Math.abs(parse.moe.active - 3.35e9) < 0.05e9, `active=${parse.moe.active}`);
184+
});
185+
186+
it("detects shared experts (DeepSeek-V2-Lite)", async () => {
187+
const parse = await parseSafetensorsMetadata({
188+
repo: "deepseek-ai/DeepSeek-V2-Lite",
189+
computeParametersCount: true,
190+
revision: "604d5664dddd88a0433dbae533b7fe9472482de0",
191+
});
192+
193+
assert(parse.sharded);
194+
assert(parse.moe, "expected `moe` field on MoE repo");
195+
assert.strictEqual(parse.moe.numExperts, 64);
196+
assert.strictEqual(parse.moe.topK, 6);
197+
assert.strictEqual(parse.moe.hasSharedExpert, true);
198+
});
199+
200+
it("omits `moe` for dense models", async () => {
201+
const parse = await parseSafetensorsMetadata({
202+
repo: "google-bert/bert-base-uncased",
203+
computeParametersCount: true,
204+
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
205+
});
206+
207+
assert(!parse.sharded);
208+
assert.strictEqual(parse.moe, undefined);
209+
});
210+
154211
it("should detect sharded safetensors filename", async () => {
155212
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
156213
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ export type SafetensorsParseFromRepo =
9393
header: SafetensorsFileHeader;
9494
parameterCount?: Partial<Record<Dtype, number>>;
9595
parameterTotal?: number;
96+
/**
97+
* For Mixture-of-Experts models: breakdown of routed vs. always-active params,
98+
* computed when `computeParametersCount: true` and the repo's `config.json`
99+
* exposes MoE fields. Undefined for dense models.
100+
*/
101+
moe?: MoeInfo;
96102
filepaths: string[];
97103
}
98104
| {
@@ -101,6 +107,12 @@ export type SafetensorsParseFromRepo =
101107
headers: SafetensorsShardedHeaders;
102108
parameterCount?: Partial<Record<Dtype, number>>;
103109
parameterTotal?: number;
110+
/**
111+
* For Mixture-of-Experts models: breakdown of routed vs. always-active params,
112+
* computed when `computeParametersCount: true` and the repo's `config.json`
113+
* exposes MoE fields. Undefined for dense models.
114+
*/
115+
moe?: MoeInfo;
104116
filepaths: string[];
105117
};
106118

@@ -323,6 +335,7 @@ export async function parseSafetensorsMetadata(
323335
parameterCount: computeNumOfParamsByDtypeSingleFile(header, quantConfig),
324336
/// shortcut: get param count directly from metadata
325337
parameterTotal: parseTotalParameters(header.__metadata__?.total_parameters),
338+
moe: computeMoeInfoFromHeaders([header], modelConfig),
326339
}
327340
: undefined;
328341
return {
@@ -345,6 +358,7 @@ export async function parseSafetensorsMetadata(
345358
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap, quantConfig),
346359
/// shortcut: get param count directly from metadata
347360
parameterTotal: parseTotalParameters(index.metadata?.total_parameters),
361+
moe: computeMoeInfoFromHeaders(Object.values(shardedMap), modelConfig),
348362
}
349363
: undefined;
350364
return {
@@ -370,9 +384,45 @@ export interface QuantizationConfig {
370384
config_groups?: Record<string, { weights?: { num_bits?: number } }>;
371385
}
372386

373-
export interface ModelConfig {
387+
/**
 * Mixture-of-Experts fields that may appear in a model's `config.json`,
 * either at the top level or nested under `text_config`.
 * Different model families spell the same quantity differently, hence the aliases.
 */
interface MoeConfigFields {
388+
/** Experts selected per token (resolved as `topK`). Common across Mixtral, Qwen2/3-MoE, Llama4, GPT-OSS, … */
389+
num_experts_per_tok?: number;
390+
/** Alternative spelling (some checkpoints); consulted only when `num_experts_per_tok` is absent. */
391+
num_experts_per_token?: number;
392+
/** Routed-expert count; first choice when resolving `numExperts`. */
num_local_experts?: number;
393+
/** Routed-expert count; second choice when resolving `numExperts`. */
num_experts?: number;
394+
/** DeepSeek family: routed-expert count; last choice when resolving `numExperts`. */
395+
n_routed_experts?: number;
396+
/** DeepSeek family — presumably the number of shared experts (TODO confirm against the model's config docs). */
n_shared_experts?: number;
397+
/** Multi-modal Ernie 4.5 — presumably the number of shared experts (TODO confirm). */
398+
moe_num_shared_experts?: number;
399+
}
400+
401+
export interface ModelConfig extends MoeConfigFields {
374402
quantization_config?: QuantizationConfig;
375-
text_config?: { quantization_config?: QuantizationConfig };
403+
text_config?: { quantization_config?: QuantizationConfig } & MoeConfigFields;
404+
}
405+
406+
/**
407+
* Active-parameter breakdown for Mixture-of-Experts models.
408+
*
409+
* For MoE models, only `topK` of `numExperts` routed experts run per token, so the
410+
* usable ("active") parameter count is much smaller than the total stored on disk.
411+
* `active = alwaysActive + topK * perExpert`. Returned by `parseSafetensorsMetadata`
412+
* when the model's `config.json` exposes MoE fields and tensor names indicate a
413+
* supported expert layout.
414+
*/
415+
export interface MoeInfo {
416+
/** Number of routed experts, as resolved from the model config. */
numExperts: number;
417+
/** Number of routed experts selected per token, as resolved from the model config. */
topK: number;
418+
/**
 * Average parameter count per routed expert (= sum-of-routed / numExperts).
 * This is a plain division, so it may be fractional when the routed total is not an exact multiple of `numExperts`.
 */
419+
perExpert: number;
420+
/** Everything that runs on every token: embeddings, attention, norms, lm_head, router, shared experts, … */
421+
alwaysActive: number;
422+
/** alwaysActive + topK * perExpert */
423+
active: number;
424+
/** True when the model has a dense shared-expert MLP alongside routed experts (Deepseek, Qwen-MoE, Command-A, …). */
425+
hasSharedExpert: boolean;
376426
}
377427

378428
/**
@@ -473,6 +523,76 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
473523
}
474524
}
475525

526+
/**
 * Resolve the routing hyper-parameters of a Mixture-of-Experts model from its config.
 *
 * Top-level fields are consulted first, then `text_config`. Within a source the first
 * non-nullish spelling wins:
 * - `topK`: `num_experts_per_tok`, then `num_experts_per_token`
 * - `numExperts`: `num_local_experts`, then `num_experts`, then `n_routed_experts`
 *
 * The config originates from JSON, so the declared `number` types are not guaranteed at
 * runtime. Both values must therefore be positive integers; anything else (strings,
 * arrays, fractions, zero, negatives, NaN) is treated as "not an MoE config".
 *
 * @param config Parsed model config, or `null` when none is available.
 * @returns `{ topK, numExperts }`, or `undefined` when the config is missing, a field is
 *          absent or malformed, or `topK` exceeds `numExperts`.
 */
function getMoeConfig(config: ModelConfig | null): { topK: number; numExperts: number } | undefined {
	if (!config) return undefined;

	// Runtime guard: narrows to `number` only for positive integers.
	const isPositiveInteger = (value: unknown): value is number =>
		typeof value === "number" && Number.isInteger(value) && value > 0;

	const sources: MoeConfigFields[] = [config, config.text_config ?? {}];
	let topK: number | undefined;
	let numExperts: number | undefined;
	for (const src of sources) {
		topK = topK ?? src.num_experts_per_tok ?? src.num_experts_per_token;
		numExperts = numExperts ?? src.num_local_experts ?? src.num_experts ?? src.n_routed_experts;
	}

	if (!isPositiveInteger(topK) || !isPositiveInteger(numExperts)) return undefined;
	// Selecting more experts than exist is nonsensical; refuse rather than report inflated numbers.
	if (topK > numExperts) return undefined;
	return { topK, numExperts };
}
538+
539+
/**
540+
* Decide whether a tensor belongs to a *routed* expert (one that is gated per token).
541+
* Shared/dense experts never match.
542+
*
543+
* Recognized layouts:
544+
* - per-expert legacy: `…experts.{int}.…` (Mixtral, Phi-MoE, OlMoE, Qwen-MoE, …)
545+
* - per-expert with prefix: `…experts.expert_{int}.…` (Switch Transformers)
546+
* - stacked 3D: `…experts.<name>` where shape[0] === numExperts
547+
* (GPT-OSS, modern Mixtral/Qwen/Deepseek in-memory format, GraniteMoE, JetMoE)
548+
*/
549+
function isRoutedExpertTensor(name: string, info: TensorInfo, numExperts: number): boolean {
550+
if (name.includes("shared_expert")) return false;
551+
if (/\.experts\.(?:expert_)?\d+\./.test(name)) return true;
552+
if (/\.experts\.[A-Za-z_][\w]*(?:\.(?:weight|bias))?$/.test(name) && info.shape[0] === numExperts) return true;
553+
return false;
554+
}
555+
556+
/**
 * Derive the MoE active-parameter breakdown from safetensors headers plus the model config.
 *
 * Every non-scalar tensor contributes the product of its shape dimensions to the total.
 * Tensors recognised by `isRoutedExpertTensor` form the routed pool; everything else is
 * counted as always-active, so `active = alwaysActive + topK * (routed / numExperts)`.
 *
 * `hasSharedExpert` is set when a non-routed tensor name contains `shared_expert`, or when
 * the config (top level or `text_config`) declares a positive `n_shared_experts` /
 * `moe_num_shared_experts`. The config check covers checkpoints whose shared-expert
 * tensors use a different naming scheme.
 *
 * Element counts come straight from tensor shapes; no quantization multiplier is applied here.
 * NOTE(review): callers compute `parameterCount` with `quantConfig`, so the two figures may
 * diverge for packed quantized checkpoints — confirm whether that is intended.
 *
 * @param headers One header per safetensors file (a single-element iterable for unsharded repos).
 * @param config Parsed model config, or `null` when unavailable.
 * @returns The breakdown, or `undefined` when the config exposes no usable MoE fields or no
 *          tensor matches a routed-expert layout.
 */
function computeMoeInfoFromHeaders(
	headers: Iterable<SafetensorsFileHeader>,
	config: ModelConfig | null,
): MoeInfo | undefined {
	const moeCfg = getMoeConfig(config);
	if (!moeCfg) return undefined;

	// Config values come from JSON, so check the runtime type before comparing.
	const declaresSharedExperts = (src: MoeConfigFields | null | undefined): boolean =>
		[src?.n_shared_experts, src?.moe_num_shared_experts].some((count) => typeof count === "number" && count > 0);

	let total = 0;
	let routedExpert = 0;
	let hasSharedExpert = declaresSharedExperts(config) || declaresSharedExperts(config?.text_config);

	for (const header of headers) {
		for (const [name, value] of Object.entries(header)) {
			if (name === "__metadata__") continue;
			const info = value as TensorInfo;
			// Scalars are skipped entirely: they count towards neither the total nor the routed pool.
			if (info.shape.length === 0) continue;
			const n = info.shape.reduce((a, b) => a * b, 1);
			if (!Number.isFinite(n)) continue;
			total += n;
			if (isRoutedExpertTensor(name, info, moeCfg.numExperts)) routedExpert += n;
			else if (name.includes("shared_expert")) hasSharedExpert = true;
		}
	}

	if (routedExpert === 0) return undefined; // config says MoE but tensors don't look like one — bail safely

	const perExpert = routedExpert / moeCfg.numExperts;
	const alwaysActive = total - routedExpert;
	const active = alwaysActive + moeCfg.topK * perExpert;

	return {
		numExperts: moeCfg.numExperts,
		topK: moeCfg.topK,
		perExpert,
		alwaysActive,
		active,
		hasSharedExpert,
	};
}
595+
476596
function computeNumOfParamsByDtypeSingleFile(
477597
header: SafetensorsFileHeader,
478598
quantConfig?: QuantizationConfig,

0 commit comments

Comments
 (0)