@@ -1390,6 +1390,95 @@ def test_projection_initialization(self):
     self.assertTrue(hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection.")
     self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.")
 
+  def test_fused_mla_lora_proj_output_equivalence(self):
+    """Tests that fused_mla_lora_proj=True produces outputs numerically equivalent to fused_mla_lora_proj=False."""
+    extra_args = get_decoupled_parallelism_overrides()
+
+    # Initialize the unfused model.
+    unfused_args = {**self.config_arguments, "fused_mla_lora_proj": False, **extra_args}
+    cfg_unfused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **unfused_args)
+    devices_array = maxtext_utils.create_device_mesh(cfg_unfused)
+    mesh = Mesh(devices_array, cfg_unfused.mesh_axes)
+    dummy_q = jnp.ones((cfg_unfused.global_batch_size_to_train_on, cfg_unfused.max_target_length, cfg_unfused.base_emb_dim))
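+    # Dummy activations of shape (batch, sequence, embed); used only to derive the projection input shapes below.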
+    mla_unfused = MLA(
+        config=cfg_unfused,
+        num_query_heads=cfg_unfused.num_query_heads,
+        num_kv_heads=cfg_unfused.num_kv_heads,
+        head_dim=cfg_unfused.head_dim,
+        inputs_q_shape=dummy_q.shape,
+        inputs_kv_shape=dummy_q.shape,
+        max_target_length=cfg_unfused.max_target_length,
+        max_prefill_predict_length=cfg_unfused.max_prefill_predict_length,
+        mesh=mesh,
+        attention_kernel="dot_product",
+        dtype=cfg_unfused.dtype,
+        dropout_rate=cfg_unfused.dropout_rate,
+        attention_type=cfg_unfused.attention_type,
+        q_lora_rank=cfg_unfused.q_lora_rank,
+        kv_lora_rank=cfg_unfused.kv_lora_rank,
+        qk_nope_head_dim=cfg_unfused.qk_nope_head_dim,
+        qk_rope_head_dim=cfg_unfused.qk_rope_head_dim,
+        v_head_dim=cfg_unfused.v_head_dim,
+        model_mode=MODEL_MODE_TRAIN,
+        rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)),
+    )
+
+    # Initialize the fused model.
+    fused_args = {**self.config_arguments, "fused_mla_lora_proj": True, **extra_args}
+    cfg_fused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **fused_args)
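+    # The mesh built above is reused: both configs carry the same parallelism overrides, so the device mesh is identical.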
1429+ mla_fused = MLA (
1430+ config = cfg_fused ,
1431+ num_query_heads = cfg_fused .num_query_heads ,
1432+ num_kv_heads = cfg_fused .num_kv_heads ,
1433+ head_dim = cfg_fused .head_dim ,
1434+ inputs_q_shape = dummy_q .shape ,
1435+ inputs_kv_shape = dummy_q .shape ,
1436+ max_target_length = cfg_fused .max_target_length ,
1437+ max_prefill_predict_length = cfg_fused .max_prefill_predict_length ,
1438+ mesh = mesh ,
1439+ attention_kernel = "dot_product" ,
1440+ dtype = cfg_fused .dtype ,
1441+ dropout_rate = cfg_fused .dropout_rate ,
1442+ attention_type = cfg_fused .attention_type ,
1443+ q_lora_rank = cfg_fused .q_lora_rank ,
1444+ kv_lora_rank = cfg_fused .kv_lora_rank ,
1445+ qk_nope_head_dim = cfg_fused .qk_nope_head_dim ,
1446+ qk_rope_head_dim = cfg_fused .qk_rope_head_dim ,
1447+ v_head_dim = cfg_fused .v_head_dim ,
1448+ model_mode = MODEL_MODE_TRAIN ,
1449+ rngs = nnx .Rngs (params = 0 , dropout = jax .random .PRNGKey (42 )),
1450+ )
1451+
1452+ # Make both models mathematically equivalent:
1453+ # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the output axis.
1454+ mla_fused .wq_kv_a .kernel .value = jnp .concatenate (
1455+ [mla_unfused .wq_a .kernel .value , mla_unfused .wkv_a .kernel .value ], axis = - 1
1456+ )
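+    # The remaining projections are unaffected by the fusion, so their weights transfer verbatim.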
+    mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value
+    mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value
+    mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value
+    mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value
+    mla_fused.out.kernel.value = mla_unfused.out.kernel.value
+
+    # Run both models on the same inputs and verify the outputs match within tolerance.
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg_unfused, cfg_unfused.dtype)
+    common_kwargs = dict(
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+    output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs)
+    output_fused, _ = mla_fused(lnx, lnx, **common_kwargs)
+
+    self.assertTrue(
+        jnp.allclose(output_unfused, output_fused, rtol=1e-05, atol=1e-05, equal_nan=False),
+        "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.",
+    )
+
   @parameterized.named_parameters(
       {
           "testcase_name": "cp_no_load_balance",