55"""Distributed Newton-Schulz matrix orthogonalization via cuSolverMp."""
66
77from itertools import chain , cycle , islice , repeat
8- from typing import Iterator , List , Literal , Optional , Sequence
8+ from typing import Iterator , Literal , Optional , Sequence
99
1010import torch
1111import torch .distributed as dist
6363NSCoeffT = Literal [_COEFFICIENT_SETS .keys ()]
6464
6565CoeffIterMode = Literal ["cycle" , "repeat_last" ]
66+ CoeffT = tuple [float , float , float ]
6667
6768
6869def get_coefficient_iterator (
6970 steps : int ,
70- coefficient_sets : Sequence [tuple [ float , float , float ] ],
71+ coefficient_sets : Sequence [CoeffT ],
7172 mode : CoeffIterMode = "cycle" ,
72- ) -> Iterator [tuple [ float , float , float ] ]:
73+ ) -> Iterator [CoeffT ]:
7374 """Iterate through coefficient sets with configurable end behavior using itertools.
7475
7576 Args:
@@ -89,7 +90,7 @@ def get_coefficient_iterator(
8990 if not coefficient_sets :
9091 raise ValueError ("coefficient_sets must be non-empty." )
9192
92- base : Iterator [tuple [ float , float , float ] ]
93+ base : Iterator [CoeffT ]
9394 if mode == "cycle" :
9495 base = cycle (coefficient_sets )
9596 elif mode == "repeat_last" :
@@ -101,7 +102,7 @@ def get_coefficient_iterator(
101102 return islice (base , steps )
102103
103104
104- def get_coefficients (steps : int , coefficient_type : NSCoeffT = "quintic" ) -> List [ float ]:
105+ def get_coefficients (steps : int , coefficient_type : NSCoeffT = "quintic" ) -> list [ CoeffT ]:
105106 """Return the coefficient schedule for Newton-Schulz.
106107
107108 Parameter ``coefficient_type`` can be one of the following
@@ -119,7 +120,7 @@ def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List
119120 coeff_iter = get_coefficient_iterator (
120121 steps , _COEFFICIENT_SETS [coefficient_type ], mode = iter_mode
121122 )
122- return list (chain . from_iterable ( coeff_iter ) )
123+ return list (coeff_iter )
123124
124125
125126class CusolverMpCtx :
@@ -159,7 +160,7 @@ def newton_schulz(
159160 x : torch .Tensor ,
160161 ctx : CusolverMpCtx ,
161162 num_iterations : int = 5 ,
162- coefficients : Optional [List [ float ]] = None ,
163+ coefficients : Optional [Sequence [ CoeffT ]] = None ,
163164) -> None :
164165 """Compute Newton-Schulz matrix orthogonalization in-place on a distributed matrix.
165166
@@ -173,16 +174,23 @@ def newton_schulz(
173174 cuSolverMp context created by :func:`cusolvermp_ctx_create`.
174175 num_iterations : int, optional
175176 Number of Newton-Schulz iterations. Default: 5.
176- coefficients : list of float, optional
177+ coefficients : sequence of tuple[ float, float, float] , optional
177178 Polynomial coefficients for the Newton-Schulz iteration.
178179 """
179180 if coefficients is None :
180181 coefficients = get_coefficients (num_iterations )
181- if len (coefficients ) != num_iterations * 3 :
182+ if len (coefficients ) != num_iterations :
182183 raise ValueError (
183184 f"Unexpected number of coefficients: { len (coefficients )} for"
184185 f" { num_iterations } iterations"
185186 )
187+ flat_coefficients : list [float ] = []
188+ for i , coeff in enumerate (coefficients ):
189+ if len (coeff ) != 3 :
190+ raise ValueError (
191+ f"Expected coefficient tuple of length 3 at iteration { i } , got { len (coeff )} "
192+ )
193+ flat_coefficients .extend (coeff )
186194
187195 if x .dim () != 2 :
188196 raise ValueError (f"Expected 2D tensor, got { x .dim ()} D" )
@@ -197,4 +205,4 @@ def newton_schulz(
197205 m = x .size (0 )
198206 n = x .size (1 ) * ctx .nranks
199207
200- tex .newton_schulz (ctx ._ptr , m , n , x , num_iterations , coefficients )
208+ tex .newton_schulz (ctx ._ptr , m , n , x , num_iterations , flat_coefficients )
0 commit comments