Skip to content

Commit df1c0f2

Browse files
authored
Merge pull request #750 from mrava87/feat-torchopadjoint
Feat: added forward/adjoint to TorchOperator
2 parents ebe4e15 + 71e426c commit df1c0f2

3 files changed

Lines changed: 124 additions & 1 deletion

File tree

pylops/torchoperator.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
]
44

55

6+
from math import prod
7+
68
import numpy as np
79

810
from pylops import LinearOperator
@@ -91,6 +93,15 @@ def __init__(
9193
def __call__(self, x):
9294
return self.apply(x)
9395

96+
def __repr__(self):
97+
M, N = prod(self.dimsd), prod(self.dims)
98+
if self.dtype is None:
99+
dt = "unspecified dtype"
100+
else:
101+
dt = "dtype=" + str(self.dtype)
102+
103+
return "<%dx%d %s with %s>" % (M, N, self.__class__.__name__, dt)
104+
94105
def apply(self, x: TensorTypeLike) -> TensorTypeLike:
95106
"""Apply forward pass to input vector
96107
@@ -106,3 +117,22 @@ def apply(self, x: TensorTypeLike) -> TensorTypeLike:
106117
107118
"""
108119
return self.Top(x, self.matvec, self.rmatvec, self.device, self.devicetorch)
120+
121+
# alias for forward pass
122+
forward = apply
123+
124+
def adjoint(self, x: TensorTypeLike) -> TensorTypeLike:
125+
"""Apply adjoint pass to input vector
126+
127+
Parameters
128+
----------
129+
x : :obj:`torch.Tensor`
130+
Input array
131+
132+
Returns
133+
-------
134+
y : :obj:`torch.Tensor`
135+
Output array resulting from the application of the adjoint operator to ``x``.
136+
137+
"""
138+
return self.Top(x, self.rmatvec, self.matvec, self.device, self.devicetorch)

pytests/test_torchoperator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,40 @@ def test_TorchOperator_batch_nd(par, dtype):
109109

110110
assert yt.dtype == dtype
111111
assert_array_equal(y, yt)
112+
113+
114+
@pytest.mark.skipif(platform.system() == "Darwin", reason="Not OSX enabled")
115+
@pytest.mark.parametrize("par", [(par1)])
116+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
117+
def test_TorchOperator_forward_adjoint(par, dtype):
118+
"""Compute gradient of L2 norm (chain of forward and adjoint) and
119+
compare Jacobian vector product with analytical solution"""
120+
device = "cpu" if backend == "numpy" else "cuda"
121+
122+
Dop = MatrixMult(
123+
np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(dtype), dtype=dtype
124+
)
125+
Top = TorchOperator(Dop, batch=False, device="cpu" if backend == "numpy" else "gpu")
126+
127+
x = np.random.normal(0.0, 1.0, par["nx"]).astype(dtype)
128+
xt = torch.from_numpy(to_numpy(x)).to(device).view(-1)
129+
xt.requires_grad = True
130+
y = -2 * np.arange(par["ny"], dtype=dtype)
131+
yt = torch.from_numpy(to_numpy(y)).to(device).view(-1)
132+
v = np.random.normal(0.0, 1.0, par["ny"]).astype(dtype)
133+
vt = torch.from_numpy(to_numpy(v)).to(device).view(-1)
134+
135+
# pylops operator
136+
f = Dop.H * (y - Dop * x)
137+
jvt = -Dop.H * Dop * v
138+
139+
# torch operator
140+
ft = Top.adjoint(yt - Top.forward(xt))
141+
ft.backward(vt, retain_graph=True)
142+
jvtt = xt.grad.cpu().numpy()
143+
ft = ft.detach().cpu().numpy()
144+
145+
assert ft.dtype == x.dtype
146+
assert jvtt.dtype == x.dtype
147+
assert_array_equal(f, ft)
148+
assert_array_equal(jvt, jvtt)

tutorials/torchop.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,13 @@ def forward(self, x):
149149
plt.tight_layout()
150150

151151
###############################################################################
152-
# And finally we do the same with a batch of 3 training samples.
152+
# And we can do the same with a batch of 3 training samples. Note that under
153+
# the hood, this effectively calls the matrix-matrix version of the forward
154+
# and adjoint operator (i.e., `matmat` and `rmatmat`); for operators that do
155+
# not implement these methods directly, this is simply implemented by calling
156+
# the matrix-vector of the forward and adjoint operator (i.e., `matvec` and
157+
# `rmatvec`)multiple times, which is less efficient.
158+
153159
net = Network(4)
154160
Cop = pylops.TorchOperator(pylops.Smoothing2D((5, 5), dims=(32, 32)), batch=True)
155161

@@ -169,3 +175,53 @@ def forward(self, x):
169175
axs[1].set_title("Gradient")
170176
axs[1].axis("tight")
171177
plt.tight_layout()
178+
179+
###############################################################################
180+
# Finally, whilst :class:`pylops.TorchOperator` is designed such that
181+
# when a PyLops linear operator is inserted into a Torch graph, the backward
182+
# pass will automatically call the adjoint of the operator, it is also possible to
183+
# explicitly call the forward and adjoint of the operator in the forward pass of
184+
# an AD chain. This can be useful in some scenarios, for example in the
185+
# implementation of so-called unrolled networks. In this case, we can simply
186+
# use the ``forward`` and ``adjoint`` methods of the :class:`pylops.TorchOperator`
187+
# class; Torch's AD will instead call the two methods swapped, namely ``adjoint``
188+
# and ``forward``.
189+
#
190+
# Let's consider the following example:
191+
#
192+
# .. math::
193+
# \mathbf{y}=\textbf{A}^H (\textbf{A} \mathbf{x} - \mathbf{d})
194+
#
195+
# whose Jacobian is given by:
196+
#
197+
# .. math::
198+
# \mathbf{J}=-\textbf{A}^H \textbf{A}
199+
#
200+
# Let's once again verify that the result of the product between
201+
# the transposed Jacobian and a vector :math:`\mathbf{v}` matches
202+
# with the analytical one.
203+
204+
nx, ny = 10, 6
205+
xt0 = torch.arange(nx, dtype=torch.double, requires_grad=True)
206+
x0 = xt0.detach().numpy()
207+
yt0 = -2 * torch.arange(ny, dtype=torch.double)
208+
y0 = xt0.detach().numpy()
209+
210+
# Forward
211+
A = np.random.normal(0.0, 1.0, (ny, nx))
212+
At = torch.from_numpy(A)
213+
Atop = pylops.TorchOperator(pylops.MatrixMult(A))
214+
yt = Atop.adjoint(yt0 - Atop.forward(xt0))
215+
216+
# AD
217+
v = torch.ones(nx, dtype=torch.double)
218+
yt.backward(v, retain_graph=True)
219+
adgrad = xt0.grad
220+
221+
# Analytical
222+
JT = -At.T @ At
223+
anagrad = torch.matmul(JT, v)
224+
225+
print("Input: ", x0)
226+
print("AD gradient: ", adgrad)
227+
print("Analytical gradient: ", anagrad)

0 commit comments

Comments
 (0)