Skip to content

Commit a506ec5

Browse files
vcherepanov-nvKshitijLakhani
authored andcommitted
Make NS coefficients parameter 2D in Python API (#2904)
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
1 parent 95edbd4 commit a506ec5

2 files changed

Lines changed: 22 additions & 13 deletions

File tree

tests/pytorch/distributed/run_newton_schulz.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
)
2222

2323

24-
def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor:
24+
def newton_schulz_reference(
25+
in_x: torch.Tensor, coefficients: list[tuple[float, float, float]]
26+
) -> torch.Tensor:
2527
"""Local Newton-Schulz reference mirroring the provided Octave update."""
2628
x = in_x.clone()
27-
for i in range(len(coefficients) // 3):
28-
a, b, c = coefficients[3 * i : 3 * (i + 1)]
29+
for a, b, c in coefficients:
2930
xxt = x @ x.mT
3031
x = a * x + b * xxt @ x + c * xxt @ xxt @ x
3132
return x

transformer_engine/pytorch/newton_schulz.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""Distributed Newton-Schulz matrix orthogonalization via cuSolverMp."""
66

77
from itertools import chain, cycle, islice, repeat
8-
from typing import Iterator, List, Literal, Optional, Sequence
8+
from typing import Iterator, Literal, Optional, Sequence
99

1010
import torch
1111
import torch.distributed as dist
@@ -63,13 +63,14 @@
6363
NSCoeffT = Literal[_COEFFICIENT_SETS.keys()]
6464

6565
CoeffIterMode = Literal["cycle", "repeat_last"]
66+
CoeffT = tuple[float, float, float]
6667

6768

6869
def get_coefficient_iterator(
6970
steps: int,
70-
coefficient_sets: Sequence[tuple[float, float, float]],
71+
coefficient_sets: Sequence[CoeffT],
7172
mode: CoeffIterMode = "cycle",
72-
) -> Iterator[tuple[float, float, float]]:
73+
) -> Iterator[CoeffT]:
7374
"""Iterate through coefficient sets with configurable end behavior using itertools.
7475
7576
Args:
@@ -89,7 +90,7 @@ def get_coefficient_iterator(
8990
if not coefficient_sets:
9091
raise ValueError("coefficient_sets must be non-empty.")
9192

92-
base: Iterator[tuple[float, float, float]]
93+
base: Iterator[CoeffT]
9394
if mode == "cycle":
9495
base = cycle(coefficient_sets)
9596
elif mode == "repeat_last":
@@ -101,7 +102,7 @@ def get_coefficient_iterator(
101102
return islice(base, steps)
102103

103104

104-
def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List[float]:
105+
def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> list[CoeffT]:
105106
"""Return the coefficient schedule for Newton-Schulz.
106107
107108
Parameter ``coefficient_type`` can be one of the following
@@ -119,7 +120,7 @@ def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List
119120
coeff_iter = get_coefficient_iterator(
120121
steps, _COEFFICIENT_SETS[coefficient_type], mode=iter_mode
121122
)
122-
return list(chain.from_iterable(coeff_iter))
123+
return list(coeff_iter)
123124

124125

125126
class CusolverMpCtx:
@@ -159,7 +160,7 @@ def newton_schulz(
159160
x: torch.Tensor,
160161
ctx: CusolverMpCtx,
161162
num_iterations: int = 5,
162-
coefficients: Optional[List[float]] = None,
163+
coefficients: Optional[Sequence[CoeffT]] = None,
163164
) -> None:
164165
"""Compute Newton-Schulz matrix orthogonalization in-place on a distributed matrix.
165166
@@ -173,16 +174,23 @@ def newton_schulz(
173174
cuSolverMp context created by :func:`cusolvermp_ctx_create`.
174175
num_iterations : int, optional
175176
Number of Newton-Schulz iterations. Default: 5.
176-
coefficients : list of float, optional
177+
coefficients : sequence of tuple[float, float, float], optional
177178
Polynomial coefficients for the Newton-Schulz iteration.
178179
"""
179180
if coefficients is None:
180181
coefficients = get_coefficients(num_iterations)
181-
if len(coefficients) != num_iterations * 3:
182+
if len(coefficients) != num_iterations:
182183
raise ValueError(
183184
f"Unexpected number of coefficients: {len(coefficients)} for"
184185
f" {num_iterations} iterations"
185186
)
187+
flat_coefficients: list[float] = []
188+
for i, coeff in enumerate(coefficients):
189+
if len(coeff) != 3:
190+
raise ValueError(
191+
f"Expected coefficient tuple of length 3 at iteration {i}, got {len(coeff)}"
192+
)
193+
flat_coefficients.extend(coeff)
186194

187195
if x.dim() != 2:
188196
raise ValueError(f"Expected 2D tensor, got {x.dim()}D")
@@ -197,4 +205,4 @@ def newton_schulz(
197205
m = x.size(0)
198206
n = x.size(1) * ctx.nranks
199207

200-
tex.newton_schulz(ctx._ptr, m, n, x, num_iterations, coefficients)
208+
tex.newton_schulz(ctx._ptr, m, n, x, num_iterations, flat_coefficients)

0 commit comments

Comments
 (0)