@@ -454,92 +454,118 @@ shard_mode: "auto" # can be either auto or explicit
454454custom_mesh_and_rule : " " # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
455455mesh_axes : ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
456456logical_axis_rules : [
457- ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458- ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
457+ # ==========================================
458+ # Vocabulary Embedding
459+ # ==========================================
460+ # Vocab Activations
459461 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
460462 ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
461- ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
462- ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
463- ['activation_length', ['sequence', 'context']],
464- ['activation_length', ['context']],
463+ ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
464+ ['activation_vocab', ['tensor', 'tensor_transpose']],
465+ ['activation_vocab', 'tensor_sequence'],
466+ ['activation_vocab', ['sequence', 'context']],
467+ # Vocab Weights
468+ ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
469+ ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
470+ # ==========================================
471+ # Attention
472+ # ==========================================
473+ # Attention Activations
474+ ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
475+ ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
465476 ['activation_attn_length', ['sequence', 'context']],
466477 ['activation_attn_length', ['context']],
467- ['activation_length_moe', ['sequence', 'context']],
468- ['activation_length_moe', ['context']],
469- ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
470- ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
471478 ['activation_q_length', ['context']],
472- ['prefill_activation_length', ['sequence', 'context']],
473- ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
474479 ['activation_kv_length', []],
475480 ['activation_attn_embed', ['tensor', 'tensor_transpose']],
476- ['activation_embed', ['tensor', 'tensor_transpose']],
477- ['activation_embed_moe', ['tensor', 'tensor_transpose']],
478- ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479- ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
480481 ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
481- ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
482482 ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
483483 ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
484- ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
485- ['activation_vocab', ['tensor', 'tensor_transpose']],
486- ['activation_vocab', 'tensor_sequence'],
487- ['activation_vocab', ['sequence','context']],
488- ['activation_stage', 'stage'],
489- ['activation_exp', ['expert']],
490- ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491- ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
492- ['decode_length', ['sequence']],
493- ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
494- ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
495- ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
496- ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
484+ # Attention Weights
497485 ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
498486 ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
499487 ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
500- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
501- ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
502- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
503- ['embed', ['fsdp', 'sequence', 'context', 'expert']],
504- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
505- ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
506- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
507- ['embed_moe', ['fsdp', 'sequence', 'context']],
508- ['embed_tensor_transpose', ['tensor_transpose']],
488+ ['qkv', []],
489+ ['kv', []],
490+ ['kv_head_dim', []],
509491 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
510492 ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
511493 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
512494 ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
513- ["q_lora_up_proj",[]],
495+ ["q_lora_up_proj", []],
514496 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
515497 ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
516498 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
517499 ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
518- ["kv_lora_up_proj",[]],
500+ ["kv_lora_up_proj", []],
501+ # ==========================================
502+ # Mixture of Experts (MoE)
503+ # ==========================================
504+ # MoE Activations
505+ ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
506+ ['activation_length_moe', ['sequence', 'context']],
507+ ['activation_length_moe', ['context']],
508+ ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
509+ ['activation_embed_moe', ['tensor', 'tensor_transpose']],
510+ ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
511+ ['activation_exp', ['expert']],
512+ # MoE Weights
513+ ['exp', 'expert'],
514+ ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
515+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
516+ ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
517+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
518+ ['embed_moe', ['fsdp', 'sequence', 'context']],
519+ # ==========================================
520+ # Standard MLP / Dense Layers / Model Structure
521+ # ==========================================
522+ # Dense Activations
523+ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
524+ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
525+ ['activation_length', ['sequence', 'context']],
526+ ['activation_length', ['context']],
527+ ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
528+ ['activation_embed', ['tensor', 'tensor_transpose']],
529+ ['activation_stage', 'stage'],
530+ # General Weights
531+ ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
532+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
533+ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
534+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
535+ ['embed', ['fsdp', 'sequence', 'context', 'expert']],
519536 ['norm', ['tensor', 'tensor_transpose']],
520537 ['layers', 'stage'],
521- ['qkv', []],
522- ['kv', []],
523- ['kv_head_dim', []],
538+ ['diloco', 'diloco'],
539+ ['engram_dim', ['tensor']],
540+ ['dense_layers', []],
541+ ['moe_layers', []],
542+ ['mhc', []],
543+ # ==========================================
544+ # Inference (Prefill, Decode, Cache)
545+ # ==========================================
546+ ['prefill_activation_length', ['sequence', 'context']],
547+ ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
548+ ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
549+ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
550+ ['decode_length', ['sequence']],
551+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
552+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
553+ ['paged_kv_heads', ['tensor']],
524554 ['cache_batch_prefill', []],
525555 ['cache_batch', []],
526556 ['cache_heads_none', []],
527- ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
528- ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
529557 ['cache_kv', []],
530558 ['cache_sequence', []],
531- ['exp', 'expert'],
532- ['exp_with_fsdp', 'fsdp'],
533- ['paged_kv_heads', ['tensor']],
534559 ['num_pages', []],
535560 ['tokens_per_page', []],
536561 ['paged_kv_head_dim_size', []],
537- ['dense_layers', []],
538- ['moe_layers', []],
539- ['engram_dim', ['tensor']],
540- ['mhc', []],
541- ['diloco', 'diloco'],
542- ]
562+ # ==========================================
563+ # Deprecated / Scheduled for Removal
564+ # ==========================================
565+ ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
566+ ['embed_tensor_transpose', ['tensor_transpose']],
567+ ['exp_with_fsdp', 'fsdp'],
568+ ]
543569# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
544570data_sharding : [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
545571input_data_sharding_logical_axes : ['activation_embed_and_logits_batch', 'activation_norm_length']