Skip to content

Commit e7cb13f

Browse files
committed
fixup
1 parent 9a02481 commit e7cb13f

3 files changed

Lines changed: 23 additions & 96 deletions

File tree

deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 12 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,8 @@ def __init__(self) -> None:
408408
(float(a), float(b), float(c)) for a, b, c in POLAR_EXPRESS_COEFFICIENTS
409409
)
410410
self._restart_iteration_set = frozenset((2,))
411-
self._call_impl = self._orthogonalize_impl
412411
self._compiled_call = torch.compile(
413-
self._call_impl,
412+
self._orthogonalize_impl,
414413
fullgraph=True,
415414
dynamic=True,
416415
)
@@ -561,85 +560,6 @@ def _reshape_update_to_matrix_batch(
561560
return update_tensor.reshape(batch_size, rows, cols).contiguous()
562561

563562

564-
def _stack_bucket_updates(
565-
bucket_entries: list[
566-
tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]
567-
],
568-
batch_size: int,
569-
rows: int,
570-
cols: int,
571-
) -> torch.Tensor:
572-
"""
573-
Stack same-shape Muon updates into one tensor for orthogonalization.
574-
575-
Parameters
576-
----------
577-
bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]
578-
Bucket entries as ``(entry, update_tensor, grad, momentum_buffer)``.
579-
batch_size : int
580-
Number of matrix slices per parameter in the bucket.
581-
rows : int
582-
Matrix row count for each slice.
583-
cols : int
584-
Matrix column count for each slice.
585-
586-
Returns
587-
-------
588-
torch.Tensor
589-
Tensor with shape ``(num_entries, batch_size, rows, cols)``.
590-
"""
591-
update_batches = [
592-
_reshape_update_to_matrix_batch(update_tensor, batch_size, rows, cols)
593-
for _, update_tensor, _, _ in bucket_entries
594-
]
595-
return torch.stack(update_batches, dim=0)
596-
597-
598-
def _orthogonalize_standard_stacked(
599-
stacked_updates: torch.Tensor,
600-
rows: int,
601-
cols: int,
602-
use_flash: bool,
603-
flash_buffers: tuple[torch.Tensor, torch.Tensor] | None = None,
604-
) -> torch.Tensor:
605-
"""
606-
Orthogonalize stacked updates with the current standard Newton-Schulz path.
607-
608-
Parameters
609-
----------
610-
stacked_updates : torch.Tensor
611-
Tensor with shape ``(num_entries, batch_size, rows, cols)``.
612-
rows : int
613-
Matrix row count.
614-
cols : int
615-
Matrix column count.
616-
use_flash : bool
617-
Whether to use the Triton single-matrix path when ``batch_size == 1``.
618-
flash_buffers : tuple[torch.Tensor, torch.Tensor] | None, optional
619-
Pre-allocated flash buffers when ``use_flash`` is enabled.
620-
621-
Returns
622-
-------
623-
torch.Tensor
624-
Orthogonalized tensor with the same shape as ``stacked_updates``.
625-
"""
626-
num_entries, batch_size, _, _ = stacked_updates.shape
627-
if use_flash and batch_size == 1:
628-
if flash_buffers is None:
629-
raise ValueError("Flash buffers are required when use_flash=True.")
630-
buf1, buf2 = flash_buffers
631-
orthogonalized = [
632-
_flash_newton_schulz_orth(stacked_updates[idx, 0], buf1, buf2)
633-
for idx in range(num_entries)
634-
]
635-
return torch.stack(orthogonalized, dim=0).unsqueeze(1)
636-
637-
# Flash path only supports one matrix per entry; batched inputs use bmm path.
638-
flat_updates = stacked_updates.reshape(num_entries * batch_size, rows, cols)
639-
orthogonalized = _batched_newton_schulz_orth(flat_updates)
640-
return orthogonalized.reshape(num_entries, batch_size, rows, cols)
641-
642-
643563
def _compute_muon_nesterov_updates(
644564
gradients: list[torch.Tensor],
645565
momentum_buffers: list[torch.Tensor],
@@ -958,23 +878,25 @@ def __init__(
958878
] = {}
959879
self._gram_orthogonalizer: _GramNewtonSchulzOrthogonalizer | None = None
960880

961-
# === Step 5. Foreach acceleration (disabled under FSDP2 / DTensor) ===
881+
# === Step 5. Foreach acceleration ===
882+
# Defaults to True for single-GPU / DDP / ZeRO-1 (plain tensors). Callers
883+
# that train under FSDP2 (``fully_shard``) should pass
884+
# ``use_foreach=False`` explicitly because several ``torch._foreach_*``
885+
# ops lack DTensor sharding propagation on older PyTorch builds.
962886
self._use_foreach = self._resolve_foreach(use_foreach)
963887

964888
@staticmethod
965889
def _resolve_foreach(use_foreach: bool | None) -> bool:
966-
"""Decide whether to use ``torch._foreach_*`` multi-tensor kernels.
890+
"""Resolve the ``use_foreach`` flag for ``torch._foreach_*`` kernels.
967891
968892
Foreach fuses per-parameter loops into single kernel launches,
969-
eliminating Python overhead. Disabled when parameters are FSDP2
970-
``DTensor`` because several foreach ops lack DTensor sharding
971-
propagation rules in older PyTorch builds.
893+
eliminating Python overhead. When ``use_foreach`` is ``None`` the
894+
default is ``True`` because plain ``torch.Tensor`` (single-GPU, DDP,
895+
ZeRO-1) always supports these ops; callers that hit DTensor dispatch
896+
errors under FSDP2 must pass ``use_foreach=False`` explicitly.
972897
"""
973898
if use_foreach is not None:
974899
return bool(use_foreach)
975-
# Conservative default: enable foreach (safe for DDP / ZeRO-1).
976-
# FSDP2 users should pass use_foreach=False explicitly if they
977-
# encounter DTensor dispatch errors.
978900
return True
979901

980902
def _compute_magma_scales_merged(
@@ -1392,7 +1314,6 @@ def _process_merged_gram_buckets(
13921314

13931315
for (rows, cols, dev, dt), bucket_entries in gram_buckets.items():
13941316
min_dim = min(rows, cols)
1395-
max_dim = max(rows, cols)
13961317
transposed = rows > cols
13971318
sb_key = (min_dim, dev, dt)
13981319
if sb_key not in super_buckets:
@@ -1401,7 +1322,7 @@ def _process_merged_gram_buckets(
14011322

14021323
gram_orth = self._get_gram_orthogonalizer()
14031324

1404-
for (min_dim, _dev, _dt), sub_list in super_buckets.items():
1325+
for (_min_dim, _dev, _dt), sub_list in super_buckets.items():
14051326
# Find the maximum large-dimension across all sub-buckets.
14061327
padded_max_dim = max(max(r, c) for r, c, _, _ in sub_list)
14071328

deepmd/pt/train/training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,11 @@ def single_model_finetune(
910910
else bool(self.opt_param.get("enable_gram")),
911911
"flash_muon": bool(self.opt_param.get("flash_muon")),
912912
"magma_muon": bool(self.opt_param.get("magma_muon")),
913+
# FSDP2 shards parameters as DTensor; several torch._foreach_*
914+
# ops lack DTensor sharding propagation on older PyTorch, so
915+
# fall back to the per-tensor path under zero_stage >= 2.
916+
# DDP / ZeRO-1 keep plain tensors and use the default.
917+
"use_foreach": False if self.zero_stage >= 2 else None,
913918
}
914919
else:
915920
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

source/tests/pt/test_hybrid_muon.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,9 +590,10 @@ def test_gram_ns_column_pad_exact_equivalence(self) -> None:
590590
X = torch.randn(1, m, n, dtype=torch.float32, device=self.device)
591591
X_padded = torch.nn.functional.pad(X, (0, pad)) # (1, m, n+pad)
592592

593-
# _orthogonalize_impl bypasses compile for deterministic comparison
594-
out_orig = gram_orth._call_impl(X)
595-
out_padded = gram_orth._call_impl(X_padded)
593+
# Call the uncompiled implementation directly so the two
594+
# invocations share an identical numerical path.
595+
out_orig = gram_orth._orthogonalize_impl(X)
596+
out_padded = gram_orth._orthogonalize_impl(X_padded)
596597

597598
# Truncate padded output to original column count
598599
out_padded_trunc = out_padded[:, :, :n]
@@ -648,12 +649,12 @@ def test_gram_ns_batch_pad_equivalence(self) -> None:
648649
padded_batch = []
649650
for m, n in shapes:
650651
X = torch.randn(1, m, n, dtype=torch.float32, device=self.device)
651-
per_matrix_results.append(gram_orth._call_impl(X))
652+
per_matrix_results.append(gram_orth._orthogonalize_impl(X))
652653
padded_batch.append(torch.nn.functional.pad(X, (0, padded_max - n)))
653654

654655
# Run Gram NS on the batched padded tensor
655656
stacked = torch.cat(padded_batch, dim=0) # (3, 64, 384)
656-
batched_out = gram_orth._call_impl(stacked)
657+
batched_out = gram_orth._orthogonalize_impl(stacked)
657658

658659
for i, (m, n) in enumerate(shapes):
659660
expected = per_matrix_results[i]

0 commit comments

Comments (0)