diff --git a/pylops/waveeqprocessing/mdd.py b/pylops/waveeqprocessing/mdd.py index 90e2cba3b..02f61b3d3 100644 --- a/pylops/waveeqprocessing/mdd.py +++ b/pylops/waveeqprocessing/mdd.py @@ -405,17 +405,19 @@ def MDD( # Add negative part to data and model if twosided and add_negative: - G = np.concatenate((ncp.zeros((ns, nr, nt - 1)), G), axis=-1) - d = np.concatenate((np.squeeze(np.zeros((ns, nv, nt - 1))), d), axis=-1) + G = ncp.concatenate((ncp.zeros((ns, nr, nt - 1), dtype=G.dtype), G), axis=-1) + d = ncp.concatenate( + (ncp.squeeze(ncp.zeros((ns, nv, nt - 1), dtype=d.dtype)), d), axis=-1 + ) # Bring kernel to frequency domain - Gfft = np.fft.rfft(G, nt2, axis=-1) + Gfft = ncp.fft.rfft(G, nt2, axis=-1) Gfft = Gfft[..., :nfmax] # Bring frequency/time to first dimension - Gfft = np.moveaxis(Gfft, -1, 0) - d = np.moveaxis(d, -1, 0) + Gfft = ncp.moveaxis(Gfft, -1, 0) + d = ncp.moveaxis(d, -1, 0) if psf: - G = np.moveaxis(G, -1, 0) + G = ncp.moveaxis(G, -1, 0) # Define MDC linear operator MDCop = MDC( @@ -455,12 +457,12 @@ def MDD( # Adjoint if adjoint: madj = MDCop.H * d.ravel() - madj = np.squeeze(madj.reshape(nt2, nr, nv)) - madj = np.moveaxis(madj, 0, -1) + madj = ncp.squeeze(madj.reshape(nt2, nr, nv)) + madj = ncp.moveaxis(madj, 0, -1) if psf: psfadj = PSFop.H * G.ravel() - psfadj = np.squeeze(psfadj.reshape(nt2, nr, nr)) - psfadj = np.moveaxis(psfadj, 0, -1) + psfadj = ncp.squeeze(psfadj.reshape(nt2, nr, nr)) + psfadj = ncp.moveaxis(psfadj, 0, -1) # Inverse if twosided and causality_precond: @@ -481,8 +483,8 @@ def MDD( ncp.zeros(int(MDCop.shape[1]), dtype=MDCop.dtype), **kwargs_solver )[0] - minv = np.squeeze(minv.reshape(nt2, nr, nv)) - minv = np.moveaxis(minv, 0, -1) + minv = ncp.squeeze(minv.reshape(nt2, nr, nv)) + minv = ncp.moveaxis(minv, 0, -1) if wav is not None: wav1 = wav.copy() @@ -500,8 +502,8 @@ def MDD( ncp.zeros(int(PSFop.shape[1]), dtype=PSFop.dtype), **kwargs_solver )[0] - psfinv = np.squeeze(psfinv.reshape(nt2, nr, nr)) - psfinv = np.moveaxis(psfinv, 0, -1) + psfinv = ncp.squeeze(psfinv.reshape(nt2, nr, nr)) + psfinv = ncp.moveaxis(psfinv, 0, -1) if wav is not None: wav1 = wav.copy() for _ in range(psfinv.ndim - 1):