@@ -49,18 +49,61 @@ def _get_lapack_funcs(dtype, names):
4949###############################################################################
5050# linalg.svd and linalg.pinv2
5151
52+ # TODO VERSION these can be removed once we use SciPy 1.18+ with batched processing
53+ # (otherwise call overhead for `lwork` is too high). Can't use NumPy because
54+ # it doesn't provide an interface to GESVD (only uses GESDD, which can be buggy).
55+
56+
57+ # Vendored from scipy: _compute_lwork and _check_work_float
58+
59+
60+ def _compute_lwork (routine , * args , ** kwargs ): # pragma: no cover
61+ dtype = getattr (routine , "dtype" , None )
62+ int_dtype = getattr (routine , "int_dtype" , None )
63+ ret = routine (* args , ** kwargs )
64+ if ret [- 1 ] != 0 :
65+ raise ValueError (f"Internal work array size computation failed: { ret [- 1 ]} " )
66+ if len (ret ) == 2 :
67+ return _check_work_float (ret [0 ].real , dtype , int_dtype )
68+ else :
69+ return tuple (_check_work_float (x .real , dtype , int_dtype ) for x in ret [:- 1 ])
70+
71+
72+ _int32_max = np .iinfo (np .int32 ).max
73+ _int64_max = np .iinfo (np .int64 ).max
74+
75+
76+ def _check_work_float (value , dtype , int_dtype ): # pragma: no cover
77+ if dtype == np .float32 or dtype == np .complex64 :
78+ # Single-precision routine -- take next fp value to work
79+ # around possible truncation in LAPACK code
80+ value = np .nextafter (value , np .inf , dtype = np .float32 )
81+
82+ value = int (value )
83+ if int_dtype .itemsize == 4 :
84+ if value < 0 or value > _int32_max :
85+ raise ValueError (
86+ "Too large work array required -- computation "
87+ "cannot be performed with standard 32-bit"
88+ " LAPACK."
89+ )
90+ elif int_dtype .itemsize == 8 :
91+ if value < 0 or value > _int64_max :
92+ raise ValueError (
93+ "Too large work array required -- computation"
94+ " cannot be performed with standard 64-bit"
95+ " LAPACK."
96+ )
97+ return value
98+
5299
53100def _svd_lwork (shape , dtype = np .float64 ):
54101 """Set up SVD calculations on identical-shape float64/complex128 arrays."""
55- try :
56- ds = linalg ._decomp_svd
57- except AttributeError : # < 1.8.0
58- ds = linalg .decomp_svd
59102 gesdd_lwork , gesvd_lwork = _get_lapack_funcs (dtype , ("gesdd_lwork" , "gesvd_lwork" ))
60- sdd_lwork = ds . _compute_lwork (
103+ sdd_lwork = _compute_lwork (
61104 gesdd_lwork , * shape , compute_uv = True , full_matrices = False
62105 )
63- svd_lwork = ds . _compute_lwork (
106+ svd_lwork = _compute_lwork (
64107 gesvd_lwork , * shape , compute_uv = True , full_matrices = False
65108 )
66109 return sdd_lwork , svd_lwork
0 commit comments