@@ -79,21 +79,21 @@ def calculate_subblock_memory(
7979 Given its configuration and runtime dimensions, returns bytes or a detailed dict.
8080
8181 Args:
82- subblock_config (FFNConfig | AttentionConfig) : Subblock configuration dataclass.
83- batch_size (int) : Batch size for memory estimate.
84- prefill_seq_len (int) : Sequence length for prefill phase.
85- generation_seq_len (int) : Sequence length for generation phase (token-by-token).
86- prefill_queue_size (int) : Token queue size for prefill attention memory allocation.
87- n_embd (int) : Embedding (hidden) dimension.
88- n_head (int) : Number of attention heads (used for non-FFN).
89- weights_dtype (torch.dtype) : PyTorch dtype for model weights.
90- kv_cache_dtype (torch.dtype) : PyTorch dtype for KV cache.
91- allocate_prefill_query (bool) : Whether to allocate query cache for prefill tokens.
92- model_config (PretrainedConfig) : HuggingFace-style config instance describing the model.
93- descriptor (type[ModelDescriptor]) : Model descriptor type (for puzzletron model types).
82+ subblock_config: Subblock configuration dataclass.
83+ batch_size: Batch size for memory estimate.
84+ prefill_seq_len: Sequence length for prefill phase.
85+ generation_seq_len: Sequence length for generation phase (token-by-token).
86+ prefill_queue_size: Token queue size for prefill attention memory allocation.
87+ n_embd: Embedding (hidden) dimension.
88+ n_head: Number of attention heads (used for non-FFN).
89+ weights_dtype: PyTorch dtype for model weights.
90+ kv_cache_dtype: PyTorch dtype for KV cache.
91+ allocate_prefill_query: Whether to allocate query cache for prefill tokens.
92+ model_config: HuggingFace-style config instance describing the model.
93+ descriptor: Model descriptor type (for puzzletron model types).
9494
9595 Returns:
96- float | dict[str, float]: Memory usage in bytes (float), or a dictionary by memory type.
96+ Memory usage in bytes (float), or a dictionary by memory type.
9797 """
9898 if subblock_config .no_op :
9999 return 0
@@ -229,7 +229,7 @@ def calc_subblock_active_params(
229229 block_idx: The index of the block/subblock within the network, used to index into the stats.
230230
231231 Returns:
232- int: The expected number of "active" parameters for the given subblock.
232+ The expected number of "active" parameters for the given subblock.
233233 """
234234 if not (isinstance (sublayer_config , FFNConfig ) and sublayer_config .is_moe ):
235235 return calculate_subblock_params (model_config , sublayer_config , descriptor )
@@ -245,12 +245,12 @@ def load_moe_stats(stats_file: str) -> dict:
245245 It returns the normalized probability distributions over experts for each block, as a list of numpy arrays.
246246
247247 Args:
248- stats_file (str) : Path to the JSON file containing expert routing statistics for each block.
248+ stats_file: Path to the JSON file containing expert routing statistics for each block.
249249
250250 Returns:
251- list[np.ndarray]: A list where each element is a numpy array containing the normalized probability
252- distribution over experts for the corresponding block. If a block's expert list is empty,
253- its entry is 0.
251+ A list where each element is a numpy array containing the normalized probability
252+ distribution over experts for the corresponding block. If a block's expert list is empty,
253+ its entry is 0.
254254 """
255255 with open (stats_file ) as f :
256256 stats = json .load (f )
@@ -271,12 +271,12 @@ def estimate_num_active_experts(
271271 expected number of active (i.e., selected at least once) experts is computed.
272272
273273 Args:
274- dist_over_experts (np.ndarray) : A 1D array of probabilities for each expert.
275- batch_size (int) : The number of samples in the batch.
276- num_experts (int) : The maximum number of experts to consider (fewer if `dist_over_experts` is shorter).
274+ dist_over_experts: A 1D array of probabilities for each expert.
275+ batch_size: The number of samples in the batch.
276+ num_experts: The maximum number of experts to consider (fewer if `dist_over_experts` is shorter).
277277
278278 Returns:
279- int: The expected number of experts selected at least once across the batch.
279+ The expected number of experts selected at least once across the batch.
280280 """
281281 # cut the tail and renormalize
282282 dist_over_experts = np .sort (dist_over_experts )[::- 1 ][:num_experts ]
@@ -296,14 +296,14 @@ def estimate_moe_active_params(
296296 """Estimate the expected number of active (used) parameters for a Mixture-of-Experts (MoE) FFN subblock.
297297
298298 Args:
299- subblock_config (FFNConfig) : The FFNConfig for the MoE subblock (with .moe field configured).
300- n_embd (int) : The embedding dimension (input and output size per expert).
301- moe_stats_file (Path | str) : Path to the JSON file containing routing/selection probabilities for experts.
302- batch_size (int) : Batch size to simulate/extrapolate expected expert use.
303- block_idx (int) : The index of the block/layer whose expert routing statistics should be used.
299+ subblock_config: The FFNConfig for the MoE subblock (with .moe field configured).
300+ n_embd: The embedding dimension (input and output size per expert).
301+ moe_stats_file: Path to the JSON file containing routing/selection probabilities for experts.
302+ batch_size: Batch size to simulate/extrapolate expected expert use.
303+ block_idx: The index of the block/layer whose expert routing statistics should be used.
304304
305305 Returns:
306- int: Estimated number of parameters actively used for the current batch and expert selection statistics.
306+ Estimated number of parameters actively used for the current batch and expert selection statistics.
307307 """
308308 assert Path (moe_stats_file ).exists ()
309309 # if not Path(moe_stats_file).exists(): # if path is not provided, should we assume uniform distribution?
@@ -382,16 +382,15 @@ def calculate_mamba_memory(
382382 """Calculate memory usage (MiB) for a Mamba attention subblock.
383383
384384 Args:
385- attention_config (AttentionConfig): Mamba attention configuration,
386- including Mamba-specific settings.
387- model_config (PretrainedConfig): Model configuration.
388- descriptor (type[ModelDescriptor]): Model descriptor class.
389- batch_size (int): Batch size for memory estimate.
390- weights_dtype (torch.dtype): Data type for model weights.
391- kv_cache_dtype (torch.dtype): Data type for state/kv-cache.
385+ attention_config: Mamba attention configuration, including Mamba-specific settings.
386+ model_config: Model configuration.
387+ descriptor: Model descriptor class.
388+ batch_size: Batch size for memory estimate.
389+ weights_dtype: Data type for model weights.
390+ kv_cache_dtype: Data type for state/kv-cache.
392391
393392 Returns:
394- int: Estimated memory usage in mebibytes (MiB) for the Mamba subblock.
393+ Estimated memory usage in mebibytes (MiB) for the Mamba subblock.
395394 """
396395 assert attention_config .mamba is not None
397396 mamba_config = attention_config .mamba
@@ -409,11 +408,11 @@ def calculate_mamba_state_size(
409408 """Calculate the total state size for a Mamba attention subblock.
410409
411410 Args:
412- mamba_config (MambaConfig) : Configuration object containing Mamba subblock parameters.
413- batch_size (int) : Batch size to estimate the memory/state requirements for.
411+ mamba_config: Configuration object containing Mamba subblock parameters.
412+ batch_size: Batch size to estimate the memory/state requirements for.
414413
415414 Returns:
416- int: Total state size (number of elements) required for the Mamba subblock, including convolution and SSM state.
415+ Total state size (number of elements) required for the Mamba subblock, including convolution and SSM state.
417416 """
418417 _ , _ , conv_dim , kernel_size = _calculate_mamba_intermediates (mamba_config )
419418 conv_state_size = math .prod ((batch_size , conv_dim , kernel_size ))
@@ -443,15 +442,14 @@ def calculate_ffn_memory(
443442 """Estimate the memory usage in MiB of a feed-forward network (FFN) subblock.
444443
445444 Args:
446- ffn_config (FFNConfig): FFN configuration for the block.
447- model_config (PretrainedConfig): The parent model configuration.
448- descriptor (type[ModelDescriptor]): Model descriptor class.
449- weights_dtype (torch.dtype | str): Data type for FFN weights.
450- experts_dtype (torch.dtype | str | None, optional): Data type for expert weights
451- (for MoE layers, if present). Defaults to None.
445+ ffn_config: FFN configuration for the block.
446+ model_config: The parent model configuration.
447+ descriptor: Model descriptor class.
448+ weights_dtype: Data type for FFN weights.
449+ experts_dtype: Data type for expert weights (for MoE layers, if present).
452450
453451 Returns:
454- float: Estimated FFN memory usage in mebibytes (MiB).
452+ Estimated FFN memory usage in mebibytes (MiB).
455453 """
456454 # TODO: How to separate between expert weights and the rest for any model (same as puzzletron).
457455 num_params = calculate_subblock_params (model_config , ffn_config , descriptor )
@@ -463,30 +461,13 @@ def calculate_non_block_memory(
463461 vocab_size : int ,
464462 weight_dtype : torch .dtype ,
465463) -> float :
466- """Estimate the memory usage in MiB of non-subblock components (e.g., embeddings, output projection).
467-
468- Args:
469- n_embd (int): Embedding dimension (hidden size).
470- vocab_size (int): Vocabulary size.
471- weight_dtype (torch.dtype): Data type for model weights.
472-
473- Returns:
474- float: Estimated non-subblock memory usage in mebibytes (MiB).
475- """
464+ """Estimate the memory usage in MiB of non-subblock components (e.g., embeddings, output projection)."""
476465 return calculate_non_block_params (n_embd , vocab_size ) * sizeof_dtype (weight_dtype ) / 2 ** 20
477466
478467
479468def calculate_non_block_params (
480469 n_embd : int ,
481470 vocab_size : int ,
482471) -> int :
483- """Calculate the number of parameters for non-subblock components (e.g., embeddings, output projection).
484-
485- Args:
486- n_embd (int): Embedding dimension (hidden size).
487- vocab_size (int): Vocabulary size.
488-
489- Returns:
490- int: Estimated non-subblock parameter count.
491- """
472+ """Calculate the number of parameters for non-subblock components (e.g., embeddings, output projection)."""
492473 return vocab_size * n_embd * 2 + n_embd
0 commit comments