@@ -55,9 +55,18 @@ class VectorizedSpline(torch.nn.Module):
5555
5656 This class does not represent a single multivariate spline
5757 :math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate
58- basis. Instead, it represents a vector spline built from ``s``
58+ basis. Instead, it represents a vector of splines built from ``s``
5959 independent univariate splines, one for each input feature.
6060
61+ .. note::
62+
63+ When using the :meth:`derivative` method of this class, derivatives are
64+ computed directly in vectorized form and returned with the correct
65+ shape. In contrast, when relying on ``autograd``, derivatives must be
66+ computed separately for each output dimension of each univariate spline
67+ and then combined, since autograd does not natively handle this
68+ vectorized structure.
69+
6170 :Example:
6271
6372 >>> from pina.model import VectorizedSpline
@@ -133,7 +142,8 @@ def __init__(
133142 :raises ValueError: If ``control_points`` is not a torch.Tensor,
134143 when provided.
135144 :raises ValueError: If both ``knots`` and ``control_points`` are None.
136- :raises ValueError: If ``knots`` is not two-dimensional.
145+ :raises ValueError: If ``knots`` is not two-dimensional, after
146+ processing.
137147 :raises ValueError: If ``control_points``, after expansion when
138148 two-dimensional, is not three-dimensional.
139149 :raises ValueError: If, for each univariate spline, the number of
@@ -180,10 +190,6 @@ def __init__(
180190 self .control_points = control_points
181191 self .aggregate_output = aggregate_output
182192
183- # Check dimensionality of knots
184- if self .knots .ndim != 2 :
185- raise ValueError ("knots must be two-dimensional." )
186-
187193 # Check dimensionality of control points
188194 if self .control_points .ndim != 3 :
189195 raise ValueError ("control_points must be three-dimensional." )
@@ -218,6 +224,9 @@ def __init__(
218224 # Precompute boundary interval index
219225 self ._boundary_interval_idx = self ._compute_boundary_interval ()
220226
227+ # Precompute denominators used in derivative formulas
228+ self ._compute_derivative_denominators ()
229+
221230 def _compute_boundary_interval (self ):
222231 """
223232 Precompute the index of the rightmost non-degenerate interval to improve
@@ -243,21 +252,55 @@ def _compute_boundary_interval(self):
243252 idx [s ] = valid_s [- 1 , 0 ] if valid_s .numel () > 0 else 0
244253
245254 return idx
255+
256+ def _compute_derivative_denominators (self ):
257+ """
258+ Precompute the denominators used in the derivatives for all orders up to
259+ the spline order to avoid redundant calculations.
260+ """
261+ # Precompute for order 2 to k
262+ for i in range (2 , self .order + 1 ):
263+
264+ # Denominators for the derivative recurrence relations
265+ left_den = self .knots [:, i - 1 : - 1 ] - self .knots [:, :- i ]
266+ right_den = self .knots [:, i :] - self .knots [:, 1 : - i + 1 ]
267+
268+ # If consecutive knots are equal, set left and right factors to zero
269+ left_fac = torch .where (
270+ torch .abs (left_den ) > 1e-10 ,
271+ (i - 1 ) / left_den ,
272+ torch .zeros_like (left_den ),
273+ )
274+ right_fac = torch .where (
275+ torch .abs (right_den ) > 1e-10 ,
276+ (i - 1 ) / right_den ,
277+ torch .zeros_like (right_den ),
278+ )
246279
247- def basis (self , x ):
280+ # Register buffers
281+ self .register_buffer (f"_left_factor_order_{ i } " , left_fac )
282+ self .register_buffer (f"_right_factor_order_{ i } " , right_fac )
283+
284+ def basis (self , x , collection = False ):
248285 """
249286 Evaluate the B-spline basis functions for each univariate spline.
250287
251288 This method applies the Cox-de Boor recursion in vectorized form across
252289 all univariate splines of the vector spline.
253290
254291 :param torch.Tensor x: The points to be evaluated.
292+ :param bool collection: If True, returns a list of basis functions for
293+ all orders up to the spline order. Default is False.
294+ :raise ValueError: If ``collection`` is not a boolean.
255295 :raises ValueError: If ``x`` is not two-dimensional.
256296 :raises ValueError: If the number of input features does not match
257297 the number of univariate splines.
258298 :return: The basis functions evaluated at x.
259299 :rtype: torch.Tensor
260300 """
301+ # Check consistency
302+ check_consistency (collection , bool )
303+
261304 # Ensure x is a tensor of the same dtype as knots
262305 x = x .as_subclass (torch .Tensor ).to (dtype = self .knots .dtype )
263306
@@ -300,6 +343,10 @@ def basis(self, x):
300343 b_idx , s_idx = torch .nonzero (at_rightmost_boundary , as_tuple = True )
301344 basis [b_idx , s_idx , self ._boundary_interval_idx [s_idx ]] = 1.0
302345
346+ # If returning the whole collection, initialize list
347+ if collection :
348+ basis_collection = [None , basis ]
349+
303350 # Cox-de Boor recursion -- iterative case
304351 for i in range (1 , self .order ):
305352
@@ -322,7 +369,10 @@ def basis(self, x):
322369 # Combine terms to get the new basis
323370 basis = term1 + term2
324371
325- return basis
372+ if collection :
373+ basis_collection .append (basis )
374+
375+ return basis_collection if collection else basis
326376
327377 def forward (self , x ):
328378 """
@@ -358,6 +408,91 @@ def forward(self, x):
358408 out = out .squeeze (- 1 )
359409
360410 return out
411+
412+ def derivative (self , x , degree ):
413+ """
414+ Compute the ``degree``-th derivative of each univariate spline at the
415+ given input points.
416+
417+ The output has shape ``[batch, s, o]``, where ``o`` is the output
418+ dimension of each univariate spline, unless an aggregation method is
419+ specified. If both ``s`` and ``o`` are 1, the output is aggregated
420+ across the last dimension, resulting in an output of shape
421+ ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
422+ ``"sum"``, the output is aggregated across the last dimension, resulting
423+ in an output of shape ``[batch, s]``.
424+
425+ :param x: The input tensor.
426+ :type x: torch.Tensor | LabelTensor
427+ :param int degree: The derivative degree to compute.
428+ :return: The derivative tensor.
429+ :rtype: torch.Tensor
430+ """
431+ # Check consistency
432+ check_positive_integer (degree , strict = False )
433+
434+ # Compute basis derivative
435+ der = self ._basis_derivative (x .as_subclass (torch .Tensor ), degree = degree )
436+
437+ # Compute the output for each spline
438+ out = torch .einsum ("bsc,soc->bso" , der , self .control_points )
439+
440+ # Aggregate output if needed
441+ if self .aggregate_output == "mean" :
442+ out = out .mean (dim = - 1 )
443+ elif self .aggregate_output == "sum" :
444+ out = out .sum (dim = - 1 )
445+ elif out .shape [1 ] == 1 and out .shape [2 ] == 1 :
446+ out = out .squeeze (- 1 )
447+
448+ return out
449+
450+ def _basis_derivative (self , x , degree ):
451+ """
452+ Compute the ``degree``-th derivative of the vectorized spline basis
453+ functions at the given input points using an iterative approach.
454+
455+ :param torch.Tensor x: The points to be evaluated.
456+ :param int degree: The derivative degree to compute.
457+ :return: The derivative of the basis functions of order ``self.order``.
458+ :rtype: torch.Tensor
459+ """
460+ # Compute the whole basis collection
461+ basis = self .basis (x , collection = True )
462+
463+ # Derivatives initialization (dummy at index 0 for convenience)
464+ derivatives = [None ] + [basis [o ] for o in range (1 , self .order + 1 )]
465+
466+ # Iterate over derivative degrees
467+ for _ in range (1 , degree + 1 ):
468+
469+ # Current degree derivatives (with dummy at index 0 for convenience)
470+ current_der = [None ] * (self .order + 1 )
471+ current_der [1 ] = torch .zeros_like (derivatives [1 ])
472+
473+ # Iterate over basis orders
474+ for o in range (2 , self .order + 1 ):
475+
476+ # Retrieve precomputed factors
477+ left_fac = getattr (self , f"_left_factor_order_{ o } " )
478+ right_fac = getattr (self , f"_right_factor_order_{ o } " )
479+
480+ # derivatives[o - 1] has shape [b, s, m]
481+ # Slice previous derivatives to align
482+ left_part = derivatives [o - 1 ][..., :- 1 ]
483+ right_part = derivatives [o - 1 ][..., 1 :]
484+
485+ # Broadcast factors over batch dims
486+ left_fac = left_fac .unsqueeze (0 )
487+ right_fac = right_fac .unsqueeze (0 )
488+
489+ # Compute current derivatives
490+ current_der [o ] = left_fac * left_part - right_fac * right_part
491+
492+ # Update derivatives for next degree
493+ derivatives = current_der
494+
495+ return derivatives [self .order ]
361496
362497 @property
363498 def control_points (self ):
@@ -440,6 +575,7 @@ def knots(self, value):
440575 :raises ValueError: If a dictionary is provided but does not contain
441576 the required keys.
442577 :raises ValueError: If the mode specified in the dictionary is invalid.
578+ :raises ValueError: If knots is not two-dimensional after processing.
443579 """
444580 # If a dictionary is provided, initialize knots accordingly
445581 if isinstance (value , dict ):
@@ -498,6 +634,13 @@ def knots(self, value):
498634 # Set knots
499635 self .register_buffer ("_knots" , value .sort (dim = - 1 ).values )
500636
637+ # Check dimensionality of knots
638+ if self .knots .ndim != 2 :
639+ raise ValueError ("knots must be two-dimensional." )
640+
501641 # Recompute boundary interval when knots change
502642 if hasattr (self , "_boundary_interval_idx" ):
503643 self ._boundary_interval_idx = self ._compute_boundary_interval ()
644+
645+ # Recompute derivative denominators when knots change
646+ self ._compute_derivative_denominators ()
0 commit comments