Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,63 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
});

it("computes MoE active-params for Mixtral-style per-expert layout", async () => {
const parse = await parseSafetensorsMetadata({
repo: "mistralai/Mixtral-8x7B-v0.1",
computeParametersCount: true,
revision: "fc7ac94680e38d7348cfa806e51218e6273104b0",
});

assert(parse.sharded);
assert(parse.moe, "expected `moe` field on MoE repo");
assert.strictEqual(parse.moe.numExperts, 8);
assert.strictEqual(parse.moe.topK, 2);
assert.strictEqual(parse.moe.hasSharedExpert, false);
// Published: ~12.9B active on 46.7B total. Tolerate small bucket-rounding.
assert.ok(Math.abs(parse.moe.active - 12.88e9) < 0.05e9, `active=${parse.moe.active}`);
assert.ok(parse.moe.alwaysActive > 1e9 && parse.moe.alwaysActive < 2e9);
});

it("computes MoE active-params for stacked-3D layout (Qwen3-30B-A3B)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "Qwen/Qwen3-30B-A3B",
computeParametersCount: true,
revision: "ad44e777bcd18fa416d9da3bd8f70d33ebb85d39",
});

assert(parse.sharded);
assert(parse.moe, "expected `moe` field on MoE repo");
assert.strictEqual(parse.moe.numExperts, 128);
assert.strictEqual(parse.moe.topK, 8);
// Published: A3B (3B active).
assert.ok(Math.abs(parse.moe.active - 3.35e9) < 0.05e9, `active=${parse.moe.active}`);
});

it("detects shared experts (DeepSeek-V2-Lite)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "deepseek-ai/DeepSeek-V2-Lite",
computeParametersCount: true,
revision: "604d5664dddd88a0433dbae533b7fe9472482de0",
});

assert(parse.sharded);
assert(parse.moe, "expected `moe` field on MoE repo");
assert.strictEqual(parse.moe.numExperts, 64);
assert.strictEqual(parse.moe.topK, 6);
assert.strictEqual(parse.moe.hasSharedExpert, true);
});

it("omits `moe` for dense models", async () => {
const parse = await parseSafetensorsMetadata({
repo: "google-bert/bert-base-uncased",
computeParametersCount: true,
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
});

assert(!parse.sharded);
assert.strictEqual(parse.moe, undefined);
});

it("should detect sharded safetensors filename", async () => {
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);
Expand Down
124 changes: 122 additions & 2 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ export type SafetensorsParseFromRepo =
header: SafetensorsFileHeader;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
/**
* For Mixture-of-Experts models: breakdown of routed vs. always-active params,
* computed when `computeParametersCount: true` and the repo's `config.json`
* exposes MoE fields. Undefined for dense models.
*/
moe?: MoeInfo;
filepaths: string[];
}
| {
Expand All @@ -101,6 +107,12 @@ export type SafetensorsParseFromRepo =
headers: SafetensorsShardedHeaders;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
/**
* For Mixture-of-Experts models: breakdown of routed vs. always-active params,
* computed when `computeParametersCount: true` and the repo's `config.json`
* exposes MoE fields. Undefined for dense models.
*/
moe?: MoeInfo;
filepaths: string[];
};

Expand Down Expand Up @@ -323,6 +335,7 @@ export async function parseSafetensorsMetadata(
parameterCount: computeNumOfParamsByDtypeSingleFile(header, quantConfig),
/// shortcut: get param count directly from metadata
parameterTotal: parseTotalParameters(header.__metadata__?.total_parameters),
moe: computeMoeInfoFromHeaders([header], modelConfig),
}
: undefined;
return {
Expand All @@ -345,6 +358,7 @@ export async function parseSafetensorsMetadata(
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap, quantConfig),
/// shortcut: get param count directly from metadata
parameterTotal: parseTotalParameters(index.metadata?.total_parameters),
moe: computeMoeInfoFromHeaders(Object.values(shardedMap), modelConfig),
}
: undefined;
return {
Expand All @@ -370,9 +384,45 @@ export interface QuantizationConfig {
config_groups?: Record<string, { weights?: { num_bits?: number } }>;
}

export interface ModelConfig {
interface MoeConfigFields {
/** Common across Mixtral, Qwen2/3-MoE, Llama4, GPT-OSS, … */
num_experts_per_tok?: number;
/** Alternative spelling (some checkpoints) */
num_experts_per_token?: number;
num_local_experts?: number;
num_experts?: number;
/** DeepSeek family */
n_routed_experts?: number;
n_shared_experts?: number;
/** Multi-modal Ernie 4.5 */
moe_num_shared_experts?: number;
}

export interface ModelConfig extends MoeConfigFields {
quantization_config?: QuantizationConfig;
text_config?: { quantization_config?: QuantizationConfig };
text_config?: { quantization_config?: QuantizationConfig } & MoeConfigFields;
}

/**
* Active-parameter breakdown for Mixture-of-Experts models.
*
* For MoE models, only `topK` of `numExperts` routed experts run per token, so the
* usable ("active") parameter count is much smaller than the total stored on disk.
* `active = alwaysActive + topK * perExpert`. Returned by `parseSafetensorsMetadata`
* when the model's `config.json` exposes MoE fields and tensor names indicate a
* supported expert layout.
*/
export interface MoeInfo {
numExperts: number;
topK: number;
/** Average parameter count per routed expert (= sum-of-routed / numExperts). */
perExpert: number;
/** Everything that runs on every token: embeddings, attention, norms, lm_head, router, shared experts, … */
alwaysActive: number;
/** alwaysActive + topK * perExpert */
active: number;
/** True when the model has a dense shared-expert MLP alongside routed experts (Deepseek, Qwen-MoE, Command-A, …). */
hasSharedExpert: boolean;
}
Comment on lines +387 to 426
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define those earlier in the file (with other types) please


/**
Expand Down Expand Up @@ -473,6 +523,76 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
}
}

function getMoeConfig(config: ModelConfig | null): Pick<MoeInfo, "topK" | "numExperts"> | undefined {
if (!config) return undefined;
const sources: MoeConfigFields[] = [config, config.text_config ?? {}];
let topK: number | undefined;
let numExperts: number | undefined;
for (const src of sources) {
topK = topK ?? src.num_experts_per_tok ?? src.num_experts_per_token;
numExperts = numExperts ?? src.num_local_experts ?? src.num_experts ?? src.n_routed_experts;
}
if (!topK || !numExperts || topK <= 0 || numExperts <= 0 || topK > numExperts) return undefined;
return { topK, numExperts };
}

/**
* Decide whether a tensor belongs to a *routed* expert (one that is gated per token).
* Shared/dense experts never match.
*
* Recognized layouts:
* - per-expert legacy: `…experts.{int}.…` (Mixtral, Phi-MoE, OlMoE, Qwen-MoE, …)
* - per-expert with prefix: `…experts.expert_{int}.…` (Switch Transformers)
* - stacked 3D: `…experts.<name>` where shape[0] === numExperts
* (GPT-OSS, modern Mixtral/Qwen/Deepseek in-memory format, GraniteMoE, JetMoE)
*/
function isRoutedExpertTensor(name: string, info: TensorInfo, numExperts: number): boolean {
if (name.includes("shared_expert")) return false;
if (/\.experts\.(?:expert_)?\d+\./.test(name)) return true;
if (/\.experts\.[A-Za-z_][\w]*(?:\.(?:weight|bias))?$/.test(name) && info.shape[0] === numExperts) return true;
return false;
}

function computeMoeInfoFromHeaders(
headers: Iterable<SafetensorsFileHeader>,
config: ModelConfig | null,
): MoeInfo | undefined {
const moeCfg = getMoeConfig(config);
if (!moeCfg) return undefined;

let total = 0;
let routedExpert = 0;
let hasSharedExpert = false;

for (const header of headers) {
for (const [name, value] of Object.entries(header)) {
if (name === "__metadata__") continue;
const info = value as TensorInfo;
if (info.shape.length === 0) continue;
const n = info.shape.reduce((a, b) => a * b, 1);
if (!Number.isFinite(n)) continue;
total += n;
if (isRoutedExpertTensor(name, info, moeCfg.numExperts)) routedExpert += n;
else if (name.includes("shared_expert")) hasSharedExpert = true;
}
}

if (routedExpert === 0) return undefined; // config says MoE but tensors don't look like one — bail safely

const perExpert = routedExpert / moeCfg.numExperts;
const alwaysActive = total - routedExpert;
const active = alwaysActive + moeCfg.topK * perExpert;

return {
numExperts: moeCfg.numExperts,
topK: moeCfg.topK,
perExpert,
alwaysActive,
active,
hasSharedExpert,
};
}

function computeNumOfParamsByDtypeSingleFile(
header: SafetensorsFileHeader,
quantConfig?: QuantizationConfig,
Expand Down
Loading