Skip to content

Commit 6ec4342

Browse files
committed
[bridge] fix: Update unit test for permute_dims on all ranks
Update test_transpose_non_rank_zero_hf_to_megatron to expect permuted tensors on non-rank-0, matching the behavior change where permute_dims is applied on all ranks for correct ReplicatedMapping shape handling.

Signed-off-by: kebo01 <kebo01@baidu.com>
1 parent 4947203 commit 6ec4342

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

tests/unit_tests/models/test_param_mapping.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -785,17 +785,18 @@ def test_transpose_non_rank_zero_hf_to_megatron(self, mock_distributed_env, tran
785785
mapping = AutoMapping("transpose.weight", "hf.weight", permute_dims=(1, 0))
786786

787787
hf_weight = torch.randn(4, 8)
788-
megatron_module = MockModule(transformer_config, weight_shape=(4, 4))
788+
megatron_module = MockModule(transformer_config, weight_shape=(8, 4))
789789

790790
with patch.object(mapping, "_mapping") as mock_delegate:
791-
mock_delegate.hf_to_megatron.return_value = torch.randn(4, 4)
791+
mock_delegate.hf_to_megatron.return_value = torch.randn(8, 4)
792792
with patch.object(mapping, "_detect_parallelism_type", return_value="column"):
793793
mapping.hf_to_megatron(hf_weight, megatron_module)
794794

795-
# On non-rank-0, permutation is skipped, original tensor passed to delegate
795+
# Permutation is applied on ALL ranks so delegate mappings
796+
# (e.g. ReplicatedMapping) always receive the correct shape.
796797
mock_delegate.hf_to_megatron.assert_called_once()
797798
passed_tensor = mock_delegate.hf_to_megatron.call_args[0][0]
798-
assert torch.equal(passed_tensor, hf_weight)
799+
assert torch.equal(passed_tensor, hf_weight.permute(1, 0).contiguous())
799800

800801
def test_transpose_identity_permutation(self, mock_distributed_env, transformer_config):
801802
"""Test AutoMapping with identity permutation."""

0 commit comments

Comments (0)