@@ -278,6 +278,7 @@ def forward(self, x):
278278 :rtype: torch.Tensor
279279 """
280280 basis = self .basis (x .as_subclass (torch .Tensor ))
281+ return basis @ self .control_points
281282 return torch .einsum (
282283 "...bi, ...i -> ...b" ,
283284 basis ,
@@ -473,7 +474,171 @@ def knots(self, value):
473474
474475 # Recompute boundary interval when knots change
475476 if hasattr (self , "_boundary_interval_idx" ):
476- self ._boundary_interval_idx = self ._compute_boundary_interval ()
477+ self ._boundary_interval_Widx = self ._compute_boundary_interval ()
477478
478479 # Recompute derivative denominators when knots change
479- self ._compute_derivative_denominators ()
480+ self ._compute_derivative_denominators ()
481+
482+
483+ import torch
484+ import torch .nn as nn
485+
486+ class SplineVectorized (nn .Module ):
487+ """
488+ Vectorized univariate B-spline model (shared knots, many splines).
489+
490+ Notation:
491+ - knots: shape (m,)
492+ - order: k (degree = k-1)
493+ - n_ctrl = m - k
494+ - control_points:
495+ * (S, n_ctrl) -> S splines, scalar output each
496+ * (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels)
497+ Input:
498+ - x: shape (...,) or (..., B)
499+ Output:
500+ - if control_points is (S, n_ctrl): shape (..., S)
501+ - if control_points is (S, O, n_ctrl): shape (..., S, O)
502+ """
503+
504+ def __init__ (self , order : int , knots : torch .Tensor , control_points : torch .Tensor | None = None ):
505+ super ().__init__ ()
506+ if not isinstance (order , int ) or order <= 0 :
507+ raise ValueError ("order must be a positive integer." )
508+ if not isinstance (knots , torch .Tensor ):
509+ raise ValueError ("knots must be a torch.Tensor." )
510+ if knots .ndim != 1 :
511+ raise ValueError ("knots must be 1D." )
512+
513+ self .order = order
514+
515+ # store sorted knots as buffer
516+ knots_sorted = knots .sort ().values
517+ self .register_buffer ("knots" , knots_sorted )
518+
519+ n_ctrl = knots_sorted .numel () - order
520+ if n_ctrl <= 0 :
521+ raise ValueError (f"Need #knots > order. Got #knots={ knots_sorted .numel ()} order={ order } ." )
522+
523+ # boundary interval idx for rightmost inclusion
524+ self ._boundary_interval_idx = self ._compute_boundary_interval_idx (knots_sorted )
525+
526+ # # control points init
527+ # if control_points is None:
528+ # # default: one spline
529+ # cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device)
530+ # self.control_points = nn.Parameter(cp, requires_grad=True)
531+ # else:
532+ # if not isinstance(control_points, torch.Tensor):
533+ # raise ValueError("control_points must be a torch.Tensor or None.")
534+ # if control_points.ndim not in (2, 3):
535+ # raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).")
536+ # if control_points.shape[-1] != n_ctrl:
537+ # raise ValueError(
538+ # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
539+ # )
540+ self .control_points = nn .Parameter (control_points , requires_grad = True )
541+
542+ @staticmethod
543+ def _compute_boundary_interval_idx (knots : torch .Tensor ) -> int :
544+ if knots .numel () < 2 :
545+ return 0
546+ diffs = knots [1 :] - knots [:- 1 ]
547+ valid = torch .nonzero (diffs > 0 , as_tuple = False )
548+ if valid .numel () == 0 :
549+ return 0
550+ return int (valid [- 1 ])
551+
552+ def basis (self , x : torch .Tensor ) -> torch .Tensor :
553+ """
554+ Compute B-spline basis functions of order self.order at x.
555+
556+ Returns:
557+ basis: shape (..., n_ctrl)
558+ """
559+ if not isinstance (x , torch .Tensor ):
560+ x = torch .as_tensor (x )
561+
562+ # ensure float dtype consistent
563+ x = x .to (dtype = self .knots .dtype , device = self .knots .device )
564+
565+ # make x shape (..., 1) for broadcasting
566+ x_exp = x .unsqueeze (- 1 ) # (..., 1)
567+
568+ # knots as (1, ..., 1, m) via unsqueeze to broadcast
569+ # (m,) -> (1,)*x.ndim + (m,)
570+ knots = self .knots .view (* ([1 ] * x .ndim ), - 1 )
571+
572+ # order-1 base: indicator on intervals [t_i, t_{i+1})
573+ basis = ((x_exp >= knots [..., :- 1 ]) & (x_exp < knots [..., 1 :])).to (x_exp .dtype ) # (..., m-1)
574+
575+ # include rightmost boundary in the last non-degenerate interval
576+ j = self ._boundary_interval_idx
577+ knot_left = knots [..., j ]
578+ knot_right = knots [..., j + 1 ]
579+ at_right = (x >= knot_left .squeeze (- 1 )) & torch .isclose (x , knot_right .squeeze (- 1 ), rtol = 1e-8 , atol = 1e-10 )
580+ if torch .any (at_right ):
581+ basis_j = basis [..., j ].bool () | at_right
582+ basis [..., j ] = basis_j .to (basis .dtype )
583+
584+ # Cox-de Boor recursion up to order k
585+ # after i-th iteration, basis has length (m-1 - i)
586+ for i in range (1 , self .order ):
587+ denom1 = knots [..., i :- 1 ] - knots [..., :- (i + 1 )]
588+ denom2 = knots [..., i + 1 :] - knots [..., 1 :- i ]
589+
590+ denom1 = torch .where (denom1 .abs () < 1e-8 , torch .ones_like (denom1 ), denom1 )
591+ denom2 = torch .where (denom2 .abs () < 1e-8 , torch .ones_like (denom2 ), denom2 )
592+
593+ term1 = ((x_exp - knots [..., :- (i + 1 )]) / denom1 ) * basis [..., :- 1 ]
594+ term2 = ((knots [..., i + 1 :] - x_exp ) / denom2 ) * basis [..., 1 :]
595+ basis = term1 + term2
596+
597+ # final basis length is n_ctrl = m - order
598+ return basis # (..., n_ctrl)
599+
600+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
601+ """
602+ Evaluate spline(s) at x.
603+
604+ If control_points is (S, n_ctrl): output (..., S)
605+ If control_points is (S, O, n_ctrl): output (..., S, O)
606+ """
607+ B = self .basis (x ) # (..., n_ctrl)
608+
609+ cp = self .control_points
610+ if cp .ndim == 2 :
611+ # (S, n_ctrl)
612+ # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
613+ out = B @ cp .transpose (0 , 1 )
614+ return out
615+ else :
616+ # (S, O, n_ctrl)
617+ # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
618+ # vectorized using einsum (yes, this one is actually appropriate)
619+ # (..., n) * (S, O, n) -> (..., S, O)
620+ # out = torch.einsum("...n, son -> ...so", B, cp)
621+ out = torch .einsum ("bsc,sco->bso" , B , cp )
622+
623+ return out
624+
625+ def forward_basis (self , basis ):
626+ """
627+ Evaluate spline(s) given precomputed basis.
628+
629+ """
630+ cp = self .control_points
631+ if cp .ndim == 2 :
632+ # (S, n_ctrl)
633+ # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
634+ out = basis @ cp .transpose (0 , 1 )
635+ return out
636+ else :
637+ # (S, O, n_ctrl)
638+ # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
639+ # vectorized using einsum (yes, this one is actually appropriate)
640+ # (..., n) * (S, O, n) -> (..., S, O)
641+ # out = torch.einsum("...n, son -> ...so", B, cp)
642+ out = torch .einsum ("bsc,sco->bso" , basis , cp )
643+
644+ return out
0 commit comments