diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index 318362652..af315a29c 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -114,9 +114,9 @@ def get_module_name(mod: ModuleType) -> str: """ if mod == np: backend = "numpy" - elif mod == cp: + elif deps.cupy_enabled and mod == cp: backend = "cupy" - elif mod == jnp: + elif deps.jax_enabled and mod == jnp: backend = "jax" else: raise ValueError("module must be numpy, cupy, or jax")