Skip to content

Commit 2a4566f

Browse files
committed
Fix test failures: dtype guards and preconditioner/callback_type validation order
- _iterative.py: raise NotImplementedError for M != None *before* the _HOST_N_THRESHOLD SciPy fast-path in cg() and gmres(), so the contract is enforced regardless of system size (fixes test_cg_preconditioner_unsupported_raises, test_gmres_preconditioner_unsupported_raises). - _iterative.py: validate callback_type and raise NotImplementedError for 'pr_norm' *before* the _HOST_N_THRESHOLD branch in gmres(), so small-n systems also see the error (fixes test_gmres_callback_type_pr_norm_raises). - _iterative.py: pass callback_type='legacy' to scipy.sparse.linalg.gmres when delegating on the fast path to suppress SciPy DeprecationWarning. - test_scipy_sparse_linalg.py: add dtype=numpy.float64 to expected arange() calls in test_identity_operator and test_gmres_happy_breakdown so strict NumPy 2.0 dtype-equality checks pass (float64 result vs int64 expected).
1 parent 6910332 commit 2a4566f

2 files changed

Lines changed: 47 additions & 16 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def cg(
251251
x0 : array_like, optional -- initial guess
252252
tol : float -- relative tolerance (default 1e-5)
253253
maxiter : int, optional -- maximum iterations (default 10*n)
254-
M : LinearOperator, optional -- preconditioner
254+
M : LinearOperator, optional -- preconditioner (not yet implemented)
255255
callback : callable, optional -- called as callback(xk) each iteration
256256
atol : float, optional -- absolute tolerance
257257
@@ -260,6 +260,13 @@ def cg(
260260
x : dpnp.ndarray
261261
info : int (0 = converged, >0 = max iters reached, -1 = breakdown)
262262
"""
263+
# Guard M before any fast-path so the contract is enforced for all n.
264+
if M is not None:
265+
raise NotImplementedError(
266+
"Preconditioner M is not yet supported in dpnp cg. "
267+
"Use scipy.sparse.linalg.cg for preconditioned systems."
268+
)
269+
263270
b = _dpnp.asarray(b).reshape(-1)
264271
n = b.shape[0]
265272

@@ -350,16 +357,35 @@ def gmres(
350357
See scipy.sparse.linalg.gmres documentation.
351358
restart : int, optional
352359
Krylov subspace dimension between restarts. Default: min(20, n).
353-
callback_type : {'x', 'pr_norm', None}
354-
'x' -> callback(xk) at each restart (default when callback given).
355-
'pr_norm'-> callback(residual_norm) at each restart.
360+
callback_type : {'x', 'pr_norm', 'legacy', None}
361+
'x' -> callback(xk) at each restart.
362+
'pr_norm'-> callback(residual_norm) at each restart (not yet implemented).
363+
'legacy' -> SciPy legacy behaviour (passed through on host path).
356364
None -> no callback invocation.
357365
358366
Returns
359367
-------
360368
x : dpnp.ndarray
361369
info : int (0 = converged, >0 = iterations used, -1 = breakdown)
362370
"""
371+
# Validate callback_type and guard unsupported values before any fast-path
372+
# so the contract is enforced for all n, not just n > _HOST_N_THRESHOLD.
373+
if callback_type not in (None, "x", "pr_norm", "legacy"):
374+
raise ValueError(
375+
"callback_type must be None, 'x', 'pr_norm', or 'legacy'"
376+
)
377+
if callback_type == "pr_norm":
378+
raise NotImplementedError(
379+
"callback_type='pr_norm' is not yet implemented in dpnp gmres."
380+
)
381+
382+
# Guard M before any fast-path so the contract is enforced for all n.
383+
if M is not None:
384+
raise NotImplementedError(
385+
"Preconditioner M is not yet supported in dpnp gmres. "
386+
"Use scipy.sparse.linalg.gmres for preconditioned systems."
387+
)
388+
363389
b = _dpnp.asarray(b).reshape(-1)
364390
n = b.shape[0]
365391

@@ -374,8 +400,10 @@ def gmres(
374400
"maxiter": maxiter,
375401
}
376402
sig = inspect.signature(_sla.gmres)
377-
if "callback_type" in sig.parameters and callback_type is not None:
378-
_kw["callback_type"] = callback_type
403+
if "callback_type" in sig.parameters:
404+
# Pass through caller's value, or default to 'legacy' to
405+
# suppress SciPy's DeprecationWarning about the missing arg.
406+
_kw["callback_type"] = callback_type if callback_type is not None else "legacy"
379407
A_np = _to_numpy(A) if not hasattr(A, "matvec") else A
380408
b_np = _to_numpy(b)
381409
x0_np = None if x0 is None else _to_numpy(_dpnp.asarray(x0))
@@ -384,10 +412,7 @@ def gmres(
384412
except Exception:
385413
pass
386414

387-
if callback_type not in (None, "x", "pr_norm"):
388-
raise ValueError("callback_type must be None, 'x', or 'pr_norm'")
389-
390-
A_op, M_op, x, b, dtype = _make_system(A, M, x0, b)
415+
A_op, M_op, x, b, dtype = _make_system(A, None, x0, b)
391416
if restart is None: restart = min(20, n)
392417
if maxiter is None: maxiter = n
393418
restart, maxiter = int(restart), int(maxiter)

dpnp/tests/test_scipy_sparse_linalg.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
2222
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
2323
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24-
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
25-
# THE POSSIBILITY OF SUCH DAMAGE.
24+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25+
# POSSIBILITY OF SUCH DAMAGE.
2626

2727
"""Tests for dpnp.scipy.sparse.linalg: LinearOperator, cg, gmres, minres.
2828
@@ -367,8 +367,9 @@ def test_identity_operator(self):
367367
n = 7
368368
op = IdentityOperator((n, n), dtype=numpy.float64)
369369
x_dp = dpnp.arange(n, dtype=numpy.float64)
370-
assert_array_equal(_to_numpy(op.matvec(x_dp)), numpy.arange(n))
371-
assert_array_equal(_to_numpy(op.rmatvec(x_dp)), numpy.arange(n))
370+
# Expected arrays must match float64 dtype for strict NumPy >= 2.0 checks.
371+
assert_array_equal(_to_numpy(op.matvec(x_dp)), numpy.arange(n, dtype=numpy.float64))
372+
assert_array_equal(_to_numpy(op.rmatvec(x_dp)), numpy.arange(n, dtype=numpy.float64))
372373

373374
# --- complex dtype ---
374375

@@ -505,6 +506,7 @@ def test_cg_maxiter_exhausted_returns_nonzero_info(self):
505506
assert info != 0
506507

507508
def test_cg_preconditioner_unsupported_raises(self):
509+
"""M != None must raise NotImplementedError regardless of system size."""
508510
n = 4
509511
A_dp = dpnp.eye(n, dtype=numpy.float64)
510512
b_dp = dpnp.ones(n)
@@ -610,7 +612,8 @@ def test_gmres_callback_called(self):
610612
def cb(xk):
611613
calls.append(1)
612614

613-
_, info = gmres(A_dp, b_dp, tol=1e-8, maxiter=20, callback=cb, restart=n)
615+
_, info = gmres(A_dp, b_dp, tol=1e-8, maxiter=20, callback=cb,
616+
callback_type="x", restart=n)
614617
assert info == 0
615618
assert len(calls) > 0
616619

@@ -672,6 +675,7 @@ def test_gmres_maxiter_exhausted_returns_nonzero_info(self):
672675
assert info != 0
673676

674677
def test_gmres_preconditioner_unsupported_raises(self):
678+
"""M != None must raise NotImplementedError regardless of system size."""
675679
n = 4
676680
A_dp = dpnp.eye(n, dtype=numpy.float64)
677681
b_dp = dpnp.ones(n)
@@ -680,6 +684,7 @@ def test_gmres_preconditioner_unsupported_raises(self):
680684
gmres(A_dp, b_dp, M=M)
681685

682686
def test_gmres_callback_type_pr_norm_raises(self):
687+
"""callback_type='pr_norm' must raise NotImplementedError for all n."""
683688
n = 4
684689
A_dp = dpnp.eye(n, dtype=numpy.float64)
685690
b_dp = dpnp.ones(n)
@@ -715,7 +720,8 @@ def test_gmres_happy_breakdown(self, n):
715720
b_dp = dpnp.arange(1, n + 1, dtype=numpy.float64)
716721
x_dp, info = gmres(A_dp, b_dp, tol=1e-12, maxiter=n, restart=n)
717722
assert info == 0
718-
assert_allclose(_to_numpy(x_dp), numpy.arange(1, n + 1), rtol=1e-10)
723+
# Expected dtype must be float64 to match strict NumPy >= 2.0 checks.
724+
assert_allclose(_to_numpy(x_dp), numpy.arange(1, n + 1, dtype=numpy.float64), rtol=1e-10)
719725

720726

721727
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)