|
9 | 9 | correctly handles this mismatch. |
10 | 10 | """ |
11 | 11 |
|
| 12 | +import math |
| 13 | +from unittest.mock import MagicMock |
| 14 | + |
12 | 15 | import pytest |
13 | 16 | import torch |
14 | 17 |
|
@@ -290,3 +293,182 @@ def test_bnb_shape_mismatch_raises(self): |
290 | 293 | shard_id="w2", |
291 | 294 | expert_id=0, |
292 | 295 | ) |
| 296 | + |
| 297 | + |
| 298 | +def _make_fused_moe_mock(*, is_act_and_mul: bool = True): |
| 299 | + """Build a FusedMoE mock for weight loading tests.""" |
| 300 | + moe_module = MagicMock(spec=FusedMoE) |
| 301 | + moe_module.moe_config = MagicMock() |
| 302 | + moe_module.moe_config.is_act_and_mul = is_act_and_mul |
| 303 | + |
| 304 | + moe_module._get_hidden_dim = FusedMoE._get_hidden_dim |
| 305 | + moe_module._narrow_expert_data_for_padding = ( |
| 306 | + FusedMoE._narrow_expert_data_for_padding |
| 307 | + ) |
| 308 | + return moe_module |
| 309 | + |
| 310 | + |
| 311 | +class TestBlockQuantPaddedHiddenAndIntermediateSize: |
| 312 | + """Tests weight loading with padded hidden_size and intermediate_size |
| 313 | + across TP ranks. |
| 314 | +
|
| 315 | + hidden_size: 192 -> 256 (DeepEP-style round-up) |
| 316 | + intermediate_size_per_partition: 448 -> 512 (block_n=128 alignment) |
| 317 | + """ |
| 318 | + |
| 319 | + BLOCK_N = 128 |
| 320 | + HIDDEN_UNPADDED = 192 |
| 321 | + HIDDEN_PADDED = math.ceil(HIDDEN_UNPADDED / BLOCK_N) * BLOCK_N |
| 322 | + INTERMEDIATE_UNPADDED = 448 |
| 323 | + INTERMEDIATE_PADDED = math.ceil(INTERMEDIATE_UNPADDED / BLOCK_N) * BLOCK_N |
| 324 | + TP_SIZE = 4 |
| 325 | + GLOBAL_INTER = INTERMEDIATE_UNPADDED * TP_SIZE |
| 326 | + |
| 327 | + def _make_fused_moe(self): |
| 328 | + return _make_fused_moe_mock() |
| 329 | + |
| 330 | + def test_load_w1_weight_all_tp_ranks(self): |
| 331 | + """Each TP rank loads block-aligned rows into the w1 half. |
| 332 | + The last rank gets fewer rows; the rest is padding.""" |
| 333 | + moe_module = self._make_fused_moe() |
| 334 | + checkpoint = torch.randn(self.GLOBAL_INTER, self.HIDDEN_UNPADDED) |
| 335 | + |
| 336 | + for tp_rank in range(self.TP_SIZE): |
| 337 | + expert_data = torch.zeros(2 * self.INTERMEDIATE_PADDED, self.HIDDEN_PADDED) |
| 338 | + FusedMoE._load_w13( |
| 339 | + moe_module, |
| 340 | + expert_data=expert_data, |
| 341 | + shard_dim=0, |
| 342 | + shard_id="w1", |
| 343 | + loaded_weight=checkpoint.clone(), |
| 344 | + tp_rank=tp_rank, |
| 345 | + ) |
| 346 | + w1 = expert_data[: self.INTERMEDIATE_PADDED] |
| 347 | + start = tp_rank * self.INTERMEDIATE_PADDED |
| 348 | + n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start) |
| 349 | + expected = checkpoint[start : start + n_available] |
| 350 | + |
| 351 | + assert torch.equal(w1[:n_available, : self.HIDDEN_UNPADDED], expected) |
| 352 | + assert torch.all(w1[n_available:] == 0) |
| 353 | + assert torch.all(w1[:n_available, self.HIDDEN_UNPADDED :] == 0) |
| 354 | + assert torch.all(expert_data[self.INTERMEDIATE_PADDED :] == 0) |
| 355 | + |
| 356 | + def test_load_w3_weight_into_second_half(self): |
| 357 | + """w3 weight is written into the second half of the w13 allocation.""" |
| 358 | + moe_module = self._make_fused_moe() |
| 359 | + checkpoint = torch.randn(self.GLOBAL_INTER, self.HIDDEN_UNPADDED) |
| 360 | + tp_rank = 2 |
| 361 | + |
| 362 | + expert_data = torch.zeros(2 * self.INTERMEDIATE_PADDED, self.HIDDEN_PADDED) |
| 363 | + FusedMoE._load_w13( |
| 364 | + moe_module, |
| 365 | + expert_data=expert_data, |
| 366 | + shard_dim=0, |
| 367 | + shard_id="w3", |
| 368 | + loaded_weight=checkpoint.clone(), |
| 369 | + tp_rank=tp_rank, |
| 370 | + ) |
| 371 | + assert torch.all(expert_data[: self.INTERMEDIATE_PADDED] == 0) |
| 372 | + |
| 373 | + w3 = expert_data[self.INTERMEDIATE_PADDED :] |
| 374 | + start = tp_rank * self.INTERMEDIATE_PADDED |
| 375 | + n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start) |
| 376 | + assert torch.equal( |
| 377 | + w3[:n_available, : self.HIDDEN_UNPADDED], |
| 378 | + checkpoint[start : start + n_available], |
| 379 | + ) |
| 380 | + assert torch.all(w3[n_available:] == 0) |
| 381 | + |
| 382 | + def test_load_w2_weight_all_tp_ranks(self): |
| 383 | + """Each TP rank loads block-aligned columns of w2.""" |
| 384 | + moe_module = self._make_fused_moe() |
| 385 | + checkpoint = torch.randn(self.HIDDEN_UNPADDED, self.GLOBAL_INTER) |
| 386 | + |
| 387 | + for tp_rank in range(self.TP_SIZE): |
| 388 | + expert_data = torch.zeros(self.HIDDEN_PADDED, self.INTERMEDIATE_PADDED) |
| 389 | + FusedMoE._load_w2( |
| 390 | + moe_module, |
| 391 | + expert_data=expert_data, |
| 392 | + shard_dim=1, |
| 393 | + loaded_weight=checkpoint.clone(), |
| 394 | + tp_rank=tp_rank, |
| 395 | + ) |
| 396 | + start = tp_rank * self.INTERMEDIATE_PADDED |
| 397 | + n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start) |
| 398 | + expected = checkpoint[:, start : start + n_available] |
| 399 | + assert torch.equal( |
| 400 | + expert_data[: self.HIDDEN_UNPADDED, :n_available], expected |
| 401 | + ) |
| 402 | + assert torch.all(expert_data[:, n_available:] == 0) |
| 403 | + assert torch.all(expert_data[self.HIDDEN_UNPADDED :] == 0) |
| 404 | + |
| 405 | + def test_load_w1_scale_all_tp_ranks(self): |
| 406 | + """Each TP rank loads block-aligned scale rows for w1.""" |
| 407 | + moe_module = self._make_fused_moe() |
| 408 | + n_rows_global = math.ceil(self.GLOBAL_INTER / self.BLOCK_N) |
| 409 | + n_cols_ckpt = math.ceil(self.HIDDEN_UNPADDED / self.BLOCK_N) |
| 410 | + n_rows_local = math.ceil(self.INTERMEDIATE_PADDED / self.BLOCK_N) |
| 411 | + n_cols_alloc = math.ceil(self.HIDDEN_PADDED / self.BLOCK_N) |
| 412 | + |
| 413 | + checkpoint_scale = torch.randn(n_rows_global, n_cols_ckpt) |
| 414 | + |
| 415 | + for tp_rank in range(self.TP_SIZE): |
| 416 | + expert_data = torch.zeros(2 * n_rows_local, n_cols_alloc) |
| 417 | + FusedMoE._load_w13( |
| 418 | + moe_module, |
| 419 | + expert_data=expert_data, |
| 420 | + shard_dim=0, |
| 421 | + shard_id="w1", |
| 422 | + loaded_weight=checkpoint_scale.clone(), |
| 423 | + tp_rank=tp_rank, |
| 424 | + ) |
| 425 | + w1_scale = expert_data[:n_rows_local] |
| 426 | + start = n_rows_local * tp_rank |
| 427 | + loaded = min(n_rows_local, n_rows_global - start) |
| 428 | + expected = checkpoint_scale[start : start + loaded] |
| 429 | + assert torch.equal(w1_scale[:loaded, :n_cols_ckpt], expected) |
| 430 | + |
| 431 | + def test_load_w2_scale_all_tp_ranks(self): |
| 432 | + """Each TP rank loads block-aligned scale columns for w2.""" |
| 433 | + moe_module = self._make_fused_moe() |
| 434 | + n_rows = math.ceil(self.HIDDEN_UNPADDED / self.BLOCK_N) |
| 435 | + n_cols_global = math.ceil(self.GLOBAL_INTER / self.BLOCK_N) |
| 436 | + n_cols_local = math.ceil(self.INTERMEDIATE_PADDED / self.BLOCK_N) |
| 437 | + |
| 438 | + checkpoint_scale = torch.randn(n_rows, n_cols_global) |
| 439 | + |
| 440 | + for tp_rank in range(self.TP_SIZE): |
| 441 | + expert_data = torch.zeros(n_rows, n_cols_local) |
| 442 | + FusedMoE._load_w2( |
| 443 | + moe_module, |
| 444 | + expert_data=expert_data, |
| 445 | + shard_dim=1, |
| 446 | + loaded_weight=checkpoint_scale.clone(), |
| 447 | + tp_rank=tp_rank, |
| 448 | + ) |
| 449 | + start = n_cols_local * tp_rank |
| 450 | + loaded = min(n_cols_local, n_cols_global - start) |
| 451 | + expected = checkpoint_scale[:, start : start + loaded] |
| 452 | + assert torch.equal(expert_data[:, :loaded], expected) |
| 453 | + |
| 454 | + def test_no_padding_matches_simple_shard(self): |
| 455 | + """When sizes are already block-aligned, loading is a simple |
| 456 | + shard_size * tp_rank partition.""" |
| 457 | + intermediate = 512 |
| 458 | + hidden = 256 |
| 459 | + moe_module = _make_fused_moe_mock() |
| 460 | + checkpoint = torch.randn(intermediate * self.TP_SIZE, hidden) |
| 461 | + |
| 462 | + for tp_rank in range(self.TP_SIZE): |
| 463 | + expert_data = torch.zeros(2 * intermediate, hidden) |
| 464 | + FusedMoE._load_w13( |
| 465 | + moe_module, |
| 466 | + expert_data=expert_data, |
| 467 | + shard_dim=0, |
| 468 | + shard_id="w1", |
| 469 | + loaded_weight=checkpoint.clone(), |
| 470 | + tp_rank=tp_rank, |
| 471 | + ) |
| 472 | + w1 = expert_data[:intermediate] |
| 473 | + start = tp_rank * intermediate |
| 474 | + assert torch.equal(w1, checkpoint[start : start + intermediate]) |
0 commit comments