@@ -117,7 +117,7 @@ def __init__(self, order=4, knots=None, control_points=None):
117117 raise ValueError ("knots must be one-dimensional." )
118118
119119 # Check dimensionality of control points
120- if self .control_points .ndim > 2 :
120+ if self .control_points .ndim > 1 :
121121 raise ValueError ("control_points must be one-dimensional." )
122122
123123 # Raise error if #knots != order + #control_points
@@ -279,11 +279,6 @@ def forward(self, x):
279279 """
280280 basis = self .basis (x .as_subclass (torch .Tensor ))
281281 return basis @ self .control_points
282- return torch .einsum (
283- "...bi, ...i -> ...b" ,
284- basis ,
285- self .control_points ,
286- )
287282
288283 def derivative (self , x , degree ):
289284 """
@@ -474,171 +469,7 @@ def knots(self, value):
474469
475470 # Recompute boundary interval when knots change
476471 if hasattr (self , "_boundary_interval_idx" ):
477- self ._boundary_interval_Widx = self ._compute_boundary_interval ()
472+ self ._boundary_interval_idx = self ._compute_boundary_interval ()
478473
479474 # Recompute derivative denominators when knots change
480475 self ._compute_derivative_denominators ()
481-
482-
483- import torch
484- import torch .nn as nn
485-
486- class SplineVectorized (nn .Module ):
487- """
488- Vectorized univariate B-spline model (shared knots, many splines).
489-
490- Notation:
491- - knots: shape (m,)
492- - order: k (degree = k-1)
493- - n_ctrl = m - k
494- - control_points:
495- * (S, n_ctrl) -> S splines, scalar output each
496- * (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels)
497- Input:
498- - x: shape (...,) or (..., B)
499- Output:
500- - if control_points is (S, n_ctrl): shape (..., S)
501- - if control_points is (S, O, n_ctrl): shape (..., S, O)
502- """
503-
504- def __init__ (self , order : int , knots : torch .Tensor , control_points : torch .Tensor | None = None ):
505- super ().__init__ ()
506- if not isinstance (order , int ) or order <= 0 :
507- raise ValueError ("order must be a positive integer." )
508- if not isinstance (knots , torch .Tensor ):
509- raise ValueError ("knots must be a torch.Tensor." )
510- if knots .ndim != 1 :
511- raise ValueError ("knots must be 1D." )
512-
513- self .order = order
514-
515- # store sorted knots as buffer
516- knots_sorted = knots .sort ().values
517- self .register_buffer ("knots" , knots_sorted )
518-
519- n_ctrl = knots_sorted .numel () - order
520- if n_ctrl <= 0 :
521- raise ValueError (f"Need #knots > order. Got #knots={ knots_sorted .numel ()} order={ order } ." )
522-
523- # boundary interval idx for rightmost inclusion
524- self ._boundary_interval_idx = self ._compute_boundary_interval_idx (knots_sorted )
525-
526- # # control points init
527- # if control_points is None:
528- # # default: one spline
529- # cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device)
530- # self.control_points = nn.Parameter(cp, requires_grad=True)
531- # else:
532- # if not isinstance(control_points, torch.Tensor):
533- # raise ValueError("control_points must be a torch.Tensor or None.")
534- # if control_points.ndim not in (2, 3):
535- # raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).")
536- # if control_points.shape[-1] != n_ctrl:
537- # raise ValueError(
538- # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
539- # )
540- self .control_points = nn .Parameter (control_points , requires_grad = True )
541-
542- @staticmethod
543- def _compute_boundary_interval_idx (knots : torch .Tensor ) -> int :
544- if knots .numel () < 2 :
545- return 0
546- diffs = knots [1 :] - knots [:- 1 ]
547- valid = torch .nonzero (diffs > 0 , as_tuple = False )
548- if valid .numel () == 0 :
549- return 0
550- return int (valid [- 1 ])
551-
552- def basis (self , x : torch .Tensor ) -> torch .Tensor :
553- """
554- Compute B-spline basis functions of order self.order at x.
555-
556- Returns:
557- basis: shape (..., n_ctrl)
558- """
559- if not isinstance (x , torch .Tensor ):
560- x = torch .as_tensor (x )
561-
562- # ensure float dtype consistent
563- x = x .to (dtype = self .knots .dtype , device = self .knots .device )
564-
565- # make x shape (..., 1) for broadcasting
566- x_exp = x .unsqueeze (- 1 ) # (..., 1)
567-
568- # knots as (1, ..., 1, m) via unsqueeze to broadcast
569- # (m,) -> (1,)*x.ndim + (m,)
570- knots = self .knots .view (* ([1 ] * x .ndim ), - 1 )
571-
572- # order-1 base: indicator on intervals [t_i, t_{i+1})
573- basis = ((x_exp >= knots [..., :- 1 ]) & (x_exp < knots [..., 1 :])).to (x_exp .dtype ) # (..., m-1)
574-
575- # include rightmost boundary in the last non-degenerate interval
576- j = self ._boundary_interval_idx
577- knot_left = knots [..., j ]
578- knot_right = knots [..., j + 1 ]
579- at_right = (x >= knot_left .squeeze (- 1 )) & torch .isclose (x , knot_right .squeeze (- 1 ), rtol = 1e-8 , atol = 1e-10 )
580- if torch .any (at_right ):
581- basis_j = basis [..., j ].bool () | at_right
582- basis [..., j ] = basis_j .to (basis .dtype )
583-
584- # Cox-de Boor recursion up to order k
585- # after i-th iteration, basis has length (m-1 - i)
586- for i in range (1 , self .order ):
587- denom1 = knots [..., i :- 1 ] - knots [..., :- (i + 1 )]
588- denom2 = knots [..., i + 1 :] - knots [..., 1 :- i ]
589-
590- denom1 = torch .where (denom1 .abs () < 1e-8 , torch .ones_like (denom1 ), denom1 )
591- denom2 = torch .where (denom2 .abs () < 1e-8 , torch .ones_like (denom2 ), denom2 )
592-
593- term1 = ((x_exp - knots [..., :- (i + 1 )]) / denom1 ) * basis [..., :- 1 ]
594- term2 = ((knots [..., i + 1 :] - x_exp ) / denom2 ) * basis [..., 1 :]
595- basis = term1 + term2
596-
597- # final basis length is n_ctrl = m - order
598- return basis # (..., n_ctrl)
599-
600- def forward (self , x : torch .Tensor ) -> torch .Tensor :
601- """
602- Evaluate spline(s) at x.
603-
604- If control_points is (S, n_ctrl): output (..., S)
605- If control_points is (S, O, n_ctrl): output (..., S, O)
606- """
607- B = self .basis (x ) # (..., n_ctrl)
608-
609- cp = self .control_points
610- if cp .ndim == 2 :
611- # (S, n_ctrl)
612- # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
613- out = B @ cp .transpose (0 , 1 )
614- return out
615- else :
616- # (S, O, n_ctrl)
617- # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
618- # vectorized using einsum (yes, this one is actually appropriate)
619- # (..., n) * (S, O, n) -> (..., S, O)
620- # out = torch.einsum("...n, son -> ...so", B, cp)
621- out = torch .einsum ("bsc,sco->bso" , B , cp )
622-
623- return out
624-
625- def forward_basis (self , basis ):
626- """
627- Evaluate spline(s) given precomputed basis.
628-
629- """
630- cp = self .control_points
631- if cp .ndim == 2 :
632- # (S, n_ctrl)
633- # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
634- out = basis @ cp .transpose (0 , 1 )
635- return out
636- else :
637- # (S, O, n_ctrl)
638- # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
639- # vectorized using einsum (yes, this one is actually appropriate)
640- # (..., n) * (S, O, n) -> (..., S, O)
641- # out = torch.einsum("...n, son -> ...so", B, cp)
642- out = torch .einsum ("bsc,sco->bso" , basis , cp )
643-
644- return out
0 commit comments