@@ -445,29 +445,27 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   # ==========================================
   # Vocabulary Embedding
   # ==========================================
   # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
   ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context']],
   # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
   # ==========================================
   # Attention
   # ==========================================
   # Attention Activations
   ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_length_attn', ['sequence', 'context']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_length_attn', ['context']],
   ['activation_q_length', ['context']],
   ['activation_kv_length', []],
@@ -482,52 +480,50 @@ logical_axis_rules: [
   ['qkv', []],
   ['kv', []],
   ['kv_head_dim', []],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'expert']],
   ["q_lora_up_proj", []],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'expert']],
   ["kv_lora_up_proj", []],
   # ==========================================
   # Mixture of Experts (MoE)
   # ==========================================
   # MoE Activations
   ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
   ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context']],
   ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_exp', ['expert']],
   # MoE Weights
   ['exp', 'expert'],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'context']],
   # ==========================================
   # Standard MLP / Dense Layers / Model Structure
   # ==========================================
   # Dense Activations
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   # Note activation batch and length also get used in vocab
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_length', ['sequence', 'context']],
   ['activation_length', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
   ['activation_stage', 'stage'],
   # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['diloco', 'diloco'],
@@ -538,11 +534,11 @@ logical_axis_rules: [
   # ==========================================
   # Inference(Prefill, Decode, Cache)
   # ==========================================
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['prefill_activation_length', ['context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
+  ['decode_length', []],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['paged_kv_heads', ['tensor']],
@@ -562,7 +558,7 @@ logical_axis_rules: [
   ['exp_with_fsdp', 'fsdp'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
 # Determines which physical axis plays the role of context parallelism for input data processing and load balancing
 # only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp)
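
Note on how these rules are consumed: MaxText hands `mesh_axes` and `logical_axis_rules` to Flax's logical-partitioning machinery, which walks the rules in order and applies the first rule whose mesh axes are not already claimed by another dimension of the same array. That is why the list repeats the same logical name with progressively fewer mesh axes (e.g. the four `['embed', ...]` entries): they form a fallback chain. The sketch below illustrates that resolution with Flax's `logical_to_mesh_sharding`; the 8-device mesh and the trimmed rule set are hypothetical stand-ins for the config above, not MaxText code.

```python
# Minimal sketch of logical-axis-rule resolution (hypothetical mesh and rules).
# XLA_FLAGS must be set before importing jax so the host exposes 8 fake devices.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

# Toy mesh over three of the axis names from mesh_axes above (2*2*2 = 8 devices).
devices = np.array(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("fsdp", "context", "tensor"))

# Cut-down rule set. As in the config, activation rules come first, and
# 'embed' carries a fallback entry with fewer mesh axes.
rules = (
    ("activation_length", "context"),
    ("embed", ("fsdp", "context")),  # preferred weight sharding
    ("embed", "fsdp"),               # fallback when 'context' is already taken
    ("mlp", "tensor"),
)

# Weight with logical axes (embed, mlp): nothing else claims 'context',
# so 'embed' resolves to the first matching rule.
w = nn.logical_to_mesh_sharding(PartitionSpec("embed", "mlp"), mesh, rules)
print(w.spec)  # PartitionSpec(('fsdp', 'context'), 'tensor')

# Activation with logical axes (activation_length, embed): 'context' is
# claimed by the length dimension, so 'embed' falls back to plain 'fsdp'.
a = nn.logical_to_mesh_sharding(PartitionSpec("activation_length", "embed"), mesh, rules)
print(a.spec)  # PartitionSpec('context', 'fsdp')
```

Under this resolution scheme, removing `'sequence'` from `mesh_axes` would leave any rule that still names it pointing at a nonexistent mesh axis, which is presumably why the same change strips `'sequence'` from every logical rule and from `data_sharding`.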