Skip to content

Commit 6eb49fb

Browse files
implement derivatives for vector splines
1 parent 5bd5902 commit 6eb49fb

File tree

3 files changed

+184
-11
lines changed

3 files changed

+184
-11
lines changed

pina/_src/model/spline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,11 @@ def forward(self, x):
277277
:return: The output tensor.
278278
:rtype: torch.Tensor
279279
"""
280-
basis = self.basis(x.as_subclass(torch.Tensor))
281-
282-
return basis @ self.control_points
280+
return torch.einsum(
281+
"...bi, i -> ...b",
282+
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
283+
self.control_points,
284+
)
283285

284286
def derivative(self, x, degree):
285287
"""

pina/_src/model/vectorized_spline.py

Lines changed: 151 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,18 @@ class VectorizedSpline(torch.nn.Module):
5555
5656
This class does not represent a single multivariate spline
5757
:math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate
58-
basis. Instead, it represents a vector spline built from ``s``
58+
basis. Instead, it represents a vector of splines built from ``s``
5959
independent univariate splines, one for each input feature.
6060
61+
.. note::
62+
63+
When using the :meth:`derivative` method of this class, derivatives are
64+
computed directly in vectorized form and returned with the correct
65+
shape. In contrast, when relying on ``autograd``, derivatives must be
66+
computed separately for each output dimension of each univariate spline
67+
and then combined, since autograd does not natively handle this
68+
vectorized structure.
69+
6170
:Example:
6271
6372
>>> from pina.model import VectorizedSpline
@@ -133,7 +142,8 @@ def __init__(
133142
:raises ValueError: If ``control_points`` is not a torch.Tensor,
134143
when provided.
135144
:raises ValueError: If both ``knots`` and ``control_points`` are None.
136-
:raises ValueError: If ``knots`` is not two-dimensional.
145+
:raises ValueError: If ``knots`` is not two-dimensional, after
146+
processing.
137147
:raises ValueError: If ``control_points``, after expansion when
138148
two-dimensional, is not three-dimensional.
139149
:raises ValueError: If, for each univariate spline, the number of
@@ -180,10 +190,6 @@ def __init__(
180190
self.control_points = control_points
181191
self.aggregate_output = aggregate_output
182192

183-
# Check dimensionality of knots
184-
if self.knots.ndim != 2:
185-
raise ValueError("knots must be two-dimensional.")
186-
187193
# Check dimensionality of control points
188194
if self.control_points.ndim != 3:
189195
raise ValueError("control_points must be three-dimensional.")
@@ -218,6 +224,9 @@ def __init__(
218224
# Precompute boundary interval index
219225
self._boundary_interval_idx = self._compute_boundary_interval()
220226

227+
# Precompute denominators used in derivative formulas
228+
self._compute_derivative_denominators()
229+
221230
def _compute_boundary_interval(self):
222231
"""
223232
Precompute the index of the rightmost non-degenerate interval to improve
@@ -243,21 +252,55 @@ def _compute_boundary_interval(self):
243252
idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0
244253

245254
return idx
255+
256+
def _compute_derivative_denominators(self):
257+
"""
258+
Precompute the denominators used in the derivatives for all orders up to
259+
the spline order to avoid redundant calculations.
260+
"""
261+
# Precompute for order 2 to k
262+
for i in range(2, self.order + 1):
263+
264+
# Denominators for the derivative recurrence relations
265+
left_den = self.knots[:, i - 1 : -1] - self.knots[:, :-i]
266+
right_den = self.knots[:, i:] - self.knots[:, 1 : -i + 1]
267+
268+
# If consecutive knots are equal, set left and right factors to zero
269+
left_fac = torch.where(
270+
torch.abs(left_den) > 1e-10,
271+
(i - 1) / left_den,
272+
torch.zeros_like(left_den),
273+
)
274+
right_fac = torch.where(
275+
torch.abs(right_den) > 1e-10,
276+
(i - 1) / right_den,
277+
torch.zeros_like(right_den),
278+
)
246279

247-
def basis(self, x):
280+
# Register buffers
281+
self.register_buffer(f"_left_factor_order_{i}", left_fac)
282+
self.register_buffer(f"_right_factor_order_{i}", right_fac)
283+
284+
def basis(self, x, collection=False):
248285
"""
249286
Evaluate the B-spline basis functions for each univariate spline.
250287
251288
This method applies the Cox-de Boor recursion in vectorized form across
252289
all univariate splines of the vector spline.
253290
254291
:param torch.Tensor x: The points to be evaluated.
292+
:param bool collection: If True, returns a list of basis functions for
293+
all orders up to the spline order. Default is False.
294+
:raise ValueError: If ``collection`` is not a boolean.
255295
:raises ValueError: If ``x`` is not two-dimensional.
256296
:raises ValueError: If the number of input features does not match
257297
the number of univariate splines.
258298
:return: The basis functions evaluated at x.
259299
:rtype: torch.Tensor
260300
"""
301+
# Check consistency
302+
check_consistency(collection, bool)
303+
261304
# Ensure x is a tensor of the same dtype as knots
262305
x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype)
263306

@@ -300,6 +343,10 @@ def basis(self, x):
300343
b_idx, s_idx = torch.nonzero(at_rightmost_boundary, as_tuple=True)
301344
basis[b_idx, s_idx, self._boundary_interval_idx[s_idx]] = 1.0
302345

346+
# If returning the whole collection, initialize list
347+
if collection:
348+
basis_collection = [None, basis]
349+
303350
# Cox-de Boor recursion -- iterative case
304351
for i in range(1, self.order):
305352

@@ -322,7 +369,10 @@ def basis(self, x):
322369
# Combine terms to get the new basis
323370
basis = term1 + term2
324371

325-
return basis
372+
if collection:
373+
basis_collection.append(basis)
374+
375+
return basis_collection if collection else basis
326376

327377
def forward(self, x):
328378
"""
@@ -358,6 +408,91 @@ def forward(self, x):
358408
out = out.squeeze(-1)
359409

360410
return out
411+
412+
def derivative(self, x, degree):
413+
"""
414+
Compute the ``degree``-th derivative of each univariate spline at the
415+
given input points.
416+
417+
The output has shape ``[batch, s, o]``, where ``o`` is the output
418+
dimension of each univariate spline, unless an aggregation method is
419+
specified. If both ``s`` and ``o`` are 1, the output is aggregated
420+
across the last dimension, resulting in an output of shape
421+
``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
422+
``"sum"``, the output is aggregated across the last dimension, resulting
423+
in an output of shape ``[batch, s]``.
424+
425+
:param x: The input tensor.
426+
:type x: torch.Tensor | LabelTensor
427+
:param int degree: The derivative degree to compute.
428+
:return: The derivative tensor.
429+
:rtype: torch.Tensor
430+
"""
431+
# Check consistency
432+
check_positive_integer(degree, strict=False)
433+
434+
# Compute basis derivative
435+
der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree)
436+
437+
# Compute the output for each spline
438+
out = torch.einsum("bsc,soc->bso", der, self.control_points)
439+
440+
# Aggregate output if needed
441+
if self.aggregate_output == "mean":
442+
out = out.mean(dim=-1)
443+
elif self.aggregate_output == "sum":
444+
out = out.sum(dim=-1)
445+
elif out.shape[1] == 1 and out.shape[2] == 1:
446+
out = out.squeeze(-1)
447+
448+
return out
449+
450+
def _basis_derivative(self, x, degree):
451+
"""
452+
Compute the ``degree``-th derivative of the vectorized spline basis
453+
functions at the given input points using an iterative approach.
454+
455+
:param torch.Tensor x: The points to be evaluated.
456+
:param int degree: The derivative degree to compute.
457+
:return: The derivative of the basis functions of order ``self.order``.
458+
:rtype: torch.Tensor
459+
"""
460+
# Compute the whole basis collection
461+
basis = self.basis(x, collection=True)
462+
463+
# Derivatives initialization (dummy at index 0 for convenience)
464+
derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]
465+
466+
# Iterate over derivative degrees
467+
for _ in range(1, degree + 1):
468+
469+
# Current degree derivatives (with dummy at index 0 for convenience)
470+
current_der = [None] * (self.order + 1)
471+
current_der[1] = torch.zeros_like(derivatives[1])
472+
473+
# Iterate over basis orders
474+
for o in range(2, self.order + 1):
475+
476+
# Retrieve precomputed factors
477+
left_fac = getattr(self, f"_left_factor_order_{o}")
478+
right_fac = getattr(self, f"_right_factor_order_{o}")
479+
480+
# derivatives[o - 1] has shape [b, s, m]
481+
# Slice previous derivatives to align
482+
left_part = derivatives[o - 1][..., :-1]
483+
right_part = derivatives[o - 1][..., 1:]
484+
485+
# Broadcast factors over batch dims
486+
left_fac = left_fac.unsqueeze(0)
487+
right_fac = right_fac.unsqueeze(0)
488+
489+
# Compute current derivatives
490+
current_der[o] = left_fac * left_part - right_fac * right_part
491+
492+
# Update derivatives for next degree
493+
derivatives = current_der
494+
495+
return derivatives[self.order]
361496

362497
@property
363498
def control_points(self):
@@ -440,6 +575,7 @@ def knots(self, value):
440575
:raises ValueError: If a dictionary is provided but does not contain
441576
the required keys.
442577
:raises ValueError: If the mode specified in the dictionary is invalid.
578+
:raises ValueError: If knots is not two-dimensional after processing.
443579
"""
444580
# If a dictionary is provided, initialize knots accordingly
445581
if isinstance(value, dict):
@@ -498,6 +634,13 @@ def knots(self, value):
498634
# Set knots
499635
self.register_buffer("_knots", value.sort(dim=-1).values)
500636

637+
# Check dimensionality of knots
638+
if self.knots.ndim != 2:
639+
raise ValueError("knots must be two-dimensional.")
640+
501641
# Recompute boundary interval when knots change
502642
if hasattr(self, "_boundary_interval_idx"):
503643
self._boundary_interval_idx = self._compute_boundary_interval()
644+
645+
# Recompute derivative denominators when knots change
646+
self._compute_derivative_denominators()

tests/test_model/test_vectorized_spline.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import pytest
33
from pina.model import VectorizedSpline, Spline
4+
from pina.operator import grad
45
from pina import LabelTensor
56

67

@@ -227,6 +228,33 @@ def test_backward(args):
227228
assert model.control_points.grad.shape == model.control_points.shape
228229

229230

231+
@pytest.mark.parametrize("args", valid_args)
232+
def test_derivative(args):
233+
234+
# Define and evaluate the model
235+
model = VectorizedSpline(**args)
236+
pts.requires_grad_(True)
237+
output_ = model(pts)
238+
239+
# Compute analytical derivatives
240+
first_der = model.derivative(x=pts, degree=1)
241+
242+
# Compute autograd derivatives -- we need to loop over output dimensions
243+
# since autograd doesn't support vectorized outputs
244+
gradients = []
245+
for j in range(output_.shape[2]):
246+
out = output_[:, :, j].squeeze(-1)
247+
out = LabelTensor(out, [f"u{j}" for j in range(out.shape[1])])
248+
gradients.append(
249+
grad(out, pts)[[f"du{j}dx{j}" for j in range(pts.shape[1])]]
250+
)
251+
first_der_auto = torch.stack(gradients, dim=-1)
252+
253+
# Check shape and value
254+
assert first_der.shape == first_der_auto.shape
255+
assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4)
256+
257+
230258
def test_1d_vs_vectorized():
231259

232260
control_points = torch.rand(1, 1, n_ctrl_pts)

0 commit comments

Comments
 (0)