Commit 89a0056

Add unit test for fused_mla_lora_proj output equivalence
Verifies that MLA with fused_mla_lora_proj=True produces numerically identical outputs to fused_mla_lora_proj=False when the fused weight (wq_kv_a) is set to the concatenation of the unfused weights (wq_a, wkv_a).
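The identity the test relies on is plain linear algebra: applying two down-projection kernels separately and concatenating the results is the same as applying the concatenation of the kernels once and splitting the output. A minimal standalone sketch of that identity in JAX follows; the shapes and variable names are illustrative only and are not taken from the MaxText test config.

import jax
import jax.numpy as jnp

# Illustrative dimensions: emb_dim -> q_lora_rank and emb_dim -> kv_lora_rank.
emb_dim, q_rank, kv_rank = 16, 8, 4
key_q, key_kv, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
wq_a = jax.random.normal(key_q, (emb_dim, q_rank))     # unfused query down-projection
wkv_a = jax.random.normal(key_kv, (emb_dim, kv_rank))  # unfused key/value down-projection
x = jax.random.normal(key_x, (2, 6, emb_dim))          # (batch, seq, emb)

# Unfused path: two separate projections.
q_low, kv_low = x @ wq_a, x @ wkv_a

# Fused path: one projection with the concatenated kernel, then split the output.
wq_kv_a = jnp.concatenate([wq_a, wkv_a], axis=-1)
fused = x @ wq_kv_a
q_low_fused, kv_low_fused = fused[..., :q_rank], fused[..., q_rank:]

assert jnp.allclose(q_low, q_low_fused, atol=1e-5)
assert jnp.allclose(kv_low, kv_low_fused, atol=1e-5)

The test below does the same thing at the module level: it copies the concatenated unfused kernels into the fused wq_kv_a kernel and compares the full MLA outputs.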
1 parent be3169b commit 89a0056

1 file changed

tests/unit/attention_test.py (86 additions & 0 deletions)
@@ -1390,6 +1390,92 @@ def test_projection_initialization(self):
     self.assertTrue(hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection.")
     self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.")
 
+  def test_fused_mla_lora_proj_output_equivalence(self):
+    """Tests that fused_mla_lora_proj=True produces identical outputs to fused_mla_lora_proj=False."""
+    extra_args = get_decoupled_parallelism_overrides()
+
+    # Initialize the unfused model.
+    unfused_args = {**self.config_arguments, "fused_mla_lora_proj": False, **extra_args}
+    cfg_unfused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **unfused_args)
+    devices_array = maxtext_utils.create_device_mesh(cfg_unfused)
+    mesh = Mesh(devices_array, cfg_unfused.mesh_axes)
+    dummy_q = jnp.ones((cfg_unfused.global_batch_size_to_train_on, cfg_unfused.max_target_length, cfg_unfused.base_emb_dim))
+    mla_unfused = MLA(
+        config=cfg_unfused,
+        num_query_heads=cfg_unfused.num_query_heads,
+        num_kv_heads=cfg_unfused.num_kv_heads,
+        head_dim=cfg_unfused.head_dim,
+        inputs_q_shape=dummy_q.shape,
+        inputs_kv_shape=dummy_q.shape,
+        max_target_length=cfg_unfused.max_target_length,
+        max_prefill_predict_length=cfg_unfused.max_prefill_predict_length,
+        mesh=mesh,
+        attention_kernel="dot_product",
+        dtype=cfg_unfused.dtype,
+        dropout_rate=cfg_unfused.dropout_rate,
+        attention_type=cfg_unfused.attention_type,
+        q_lora_rank=cfg_unfused.q_lora_rank,
+        kv_lora_rank=cfg_unfused.kv_lora_rank,
+        qk_nope_head_dim=cfg_unfused.qk_nope_head_dim,
+        qk_rope_head_dim=cfg_unfused.qk_rope_head_dim,
+        v_head_dim=cfg_unfused.v_head_dim,
+        model_mode=MODEL_MODE_TRAIN,
+        rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)),
+    )
+
+    # Initialize the fused model.
+    fused_args = {**self.config_arguments, "fused_mla_lora_proj": True, **extra_args}
+    cfg_fused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **fused_args)
+    mla_fused = MLA(
+        config=cfg_fused,
+        num_query_heads=cfg_fused.num_query_heads,
+        num_kv_heads=cfg_fused.num_kv_heads,
+        head_dim=cfg_fused.head_dim,
+        inputs_q_shape=dummy_q.shape,
+        inputs_kv_shape=dummy_q.shape,
+        max_target_length=cfg_fused.max_target_length,
+        max_prefill_predict_length=cfg_fused.max_prefill_predict_length,
+        mesh=mesh,
+        attention_kernel="dot_product",
+        dtype=cfg_fused.dtype,
+        dropout_rate=cfg_fused.dropout_rate,
+        attention_type=cfg_fused.attention_type,
+        q_lora_rank=cfg_fused.q_lora_rank,
+        kv_lora_rank=cfg_fused.kv_lora_rank,
+        qk_nope_head_dim=cfg_fused.qk_nope_head_dim,
+        qk_rope_head_dim=cfg_fused.qk_rope_head_dim,
+        v_head_dim=cfg_fused.v_head_dim,
+        model_mode=MODEL_MODE_TRAIN,
+        rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)),
+    )
+
+    # Make both models mathematically equivalent:
+    # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the output axis.
+    mla_fused.wq_kv_a.kernel.value = jnp.concatenate(
+        [mla_unfused.wq_a.kernel.value, mla_unfused.wkv_a.kernel.value], axis=-1
+    )
+    mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value
+    mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value
+    mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value
+    mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value
+    mla_fused.out.kernel.value = mla_unfused.out.kernel.value
+
+    # Run both models on the same inputs and verify outputs are identical.
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg_unfused, cfg_unfused.dtype)
+    common_kwargs = dict(
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+    output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs)
+    output_fused, _ = mla_fused(lnx, lnx, **common_kwargs)
+
+    self.assertTrue(
+        jax.numpy.allclose(output_unfused, output_fused, rtol=1e-05, atol=1e-05, equal_nan=False),
+        "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.",
+    )
+
   @parameterized.named_parameters(
       {
           "testcase_name": "cp_no_load_balance",
