@@ -445,28 +445,27 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   # ==========================================
   # Vocabulary Embedding
   # ==========================================
   # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
   ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context']],
   # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
   # ==========================================
   # Attention
   # ==========================================
   # Attention Activations
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_attn_length', ['sequence', 'context']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_attn_length', ['context']],
   ['activation_attn_length', ['context']],
   ['activation_q_length', ['context']],
   ['activation_kv_length', []],
@@ -481,52 +480,52 @@ logical_axis_rules: [
   ['qkv', []],
   ['kv', []],
   ['kv_head_dim', []],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'expert']],
   ["q_lora_up_proj", []],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'expert']],
   ["kv_lora_up_proj", []],
   # ==========================================
   # Mixture of Experts (MoE)
   # ==========================================
   # MoE Activations
   ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
   ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_length_moe', ['context']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context']],
   ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_exp', ['expert']],
   # MoE Weights
   ['exp', 'expert'],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'context']],
   # ==========================================
   # Standard MLP / Dense Layers / Model Structure
   # ==========================================
   # Dense Activations
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   # Note activation batch and length also get used in attention and vocab
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_length', ['sequence', 'context']],
   ['activation_length', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_length', ['context']],
+  ['activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
   ['activation_stage', 'stage'],
   # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['diloco', 'diloco'],
@@ -537,11 +536,11 @@ logical_axis_rules: [
   # ==========================================
   # Inference(Prefill, Decode, Cache)
   # ==========================================
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['prefill_activation_length', ['context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
+  ['decode_length', []],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['paged_kv_heads', ['tensor']],
@@ -561,7 +560,7 @@ logical_axis_rules: [
   ['exp_with_fsdp', 'fsdp'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
 # Determines which physical axis plays the role of context parallelism for input data processing and load balancing
 # only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp)
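
Note on reading the rule table: `logical_axis_rules` is a priority-ordered lookup. For each named logical axis of a tensor, the first rule whose mesh axes are all still unused for that tensor wins, which is why the table keeps progressively weaker fallbacks under the same name (e.g. the four `embed` entries, which differ only in whether `fsdp_transpose` and `tensor_transpose` are claimed). Below is a minimal sketch of that first-match resolution, assuming flax-style logical-partitioning semantics; the helper and the rule subset are hypothetical, not MaxText's actual code.

```python
# Hypothetical sketch of first-match rule resolution, assuming
# flax-style logical partitioning semantics; not MaxText's actual code.

RULES = [
    ('activation_attn_length', ('context',)),
    # Fallbacks for the same logical name come later in the table:
    ('embed', ('fsdp', 'fsdp_transpose', 'context', 'expert')),
    ('embed', ('fsdp', 'context', 'expert')),
]

def logical_to_mesh(logical_axes, rules=RULES):
    """Resolve the logical axes of one tensor to mesh axes.

    The first rule whose mesh axes are all still unused for this
    tensor wins; a mesh axis may shard at most one tensor dimension.
    """
    used, resolved = set(), []
    for name in logical_axes:
        match = None
        for rule_name, mesh_axes in rules:
            if rule_name == name and not used.intersection(mesh_axes):
                match = mesh_axes
                break
        resolved.append(match)
        used.update(match or ())
    return resolved

# 'embed' claims ('fsdp', 'fsdp_transpose', 'context', 'expert'); the
# attention-length axis then finds 'context' taken and stays unsharded.
print(logical_to_mesh(['embed', 'activation_attn_length']))
# [('fsdp', 'fsdp_transpose', 'context', 'expert'), None]
```

Under these assumed semantics, removing the `sequence` mesh axis (as this diff does) collapses several rule/fallback pairs into identical entries (e.g. the two `['activation_attn_length', ['context']]` lines); the duplicates are harmless, since only the first match is ever taken.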