Skip to content

Commit 3ad5a68

Browse files
committed
vectorized spline
1 parent 33d3862 commit 3ad5a68

File tree

4 files changed

+254
-25
lines changed

4 files changed

+254
-25
lines changed

pina/model/block/kan_block.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_
2222
self.grid_eps = grid_eps
2323
self.grid_range = grid_range
2424
self.grid_extension = grid_extension
25+
self.vec = True
26+
# self.vec = False
2527

2628
if sparse_init:
2729
self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False)
@@ -43,19 +45,35 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_
4345
# torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale
4446
# )
4547
# print(control_points.shape)
46-
spline_q = []
47-
for q in range(self.output_dimensions):
48-
spline_p = []
49-
for p in range(self.input_dimensions):
50-
spline_ = Spline(
51-
order=self.k,
52-
knots=knots,
53-
control_points=torch.randn(n_control_points)
54-
)
55-
spline_p.append(spline_)
56-
spline_p = torch.nn.ModuleList(spline_p)
57-
spline_q.append(spline_p)
58-
self.spline_q = torch.nn.ModuleList(spline_q)
48+
if self.vec:
49+
from pina.model.spline import SplineVectorized as VectorizedSpline
50+
control_points = torch.randn(self.input_dimensions * self.output_dimensions, n_control_points)
51+
print('control points', control_points.shape)
52+
control_points = torch.stack([
53+
torch.randn(n_control_points)
54+
for _ in range(self.input_dimensions * self.output_dimensions)
55+
])
56+
print('control points', control_points.shape)
57+
self.spline_q = VectorizedSpline(
58+
order=self.k,
59+
knots=knots,
60+
control_points=control_points
61+
)
62+
63+
else:
64+
spline_q = []
65+
for q in range(self.output_dimensions):
66+
spline_p = []
67+
for p in range(self.input_dimensions):
68+
spline_ = Spline(
69+
order=self.k,
70+
knots=knots,
71+
control_points=torch.randn(n_control_points)
72+
)
73+
spline_p.append(spline_)
74+
spline_p = torch.nn.ModuleList(spline_p)
75+
spline_q.append(spline_p)
76+
self.spline_q = torch.nn.ModuleList(spline_q)
5977

6078

6179
# control_points = torch.nn.Parameter(
@@ -99,16 +117,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99117
x_tensor = x.tensor
100118
else:
101119
x_tensor = x
120+
102121

103-
y = []
104-
for q in range(self.output_dimensions):
105-
y_q = []
106-
for p in range(self.input_dimensions):
107-
spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions)
108-
base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions)
109-
y_q.append(spline_out + base_out)
110-
y.append(torch.stack(y_q, dim=1).sum(dim=1))
111-
y = torch.stack(y, dim=1)
122+
if self.vec:
123+
y = self.spline_q.forward(x_tensor) # (batch, output_dimensions, input_dimensions)
124+
y = y.reshape(y.shape[0], y.shape[1], self.output_dimensions, self.input_dimensions)
125+
base_out = self.base_function(x_tensor) # (batch, input_dimensions)
126+
y = y + base_out[:, :, None, None]
127+
y = y.sum(dim=3).sum(dim=1) # sum over input dimensions
128+
else:
129+
y = []
130+
for q in range(self.output_dimensions):
131+
y_q = []
132+
for p in range(self.input_dimensions):
133+
spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions)
134+
base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions)
135+
y_q.append(spline_out + base_out)
136+
y.append(torch.stack(y_q, dim=1).sum(dim=1))
137+
y = torch.stack(y, dim=1)
112138

113139
return y
114140

pina/model/kolmogorov_arnold_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
108108

109109
for i, layer in enumerate(self.kan_layers):
110110
current = layer(current)
111-
current = torch.nn.functional.sigmoid(current)
111+
# current = torch.nn.functional.sigmoid(current)
112112

113113
if self.save_act:
114114
self.acts.append(current.detach())

pina/model/spline.py

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def forward(self, x):
278278
:rtype: torch.Tensor
279279
"""
280280
basis = self.basis(x.as_subclass(torch.Tensor))
281+
return basis @ self.control_points
281282
return torch.einsum(
282283
"...bi, ...i -> ...b",
283284
basis,
@@ -473,7 +474,171 @@ def knots(self, value):
473474

474475
# Recompute boundary interval when knots change
475476
if hasattr(self, "_boundary_interval_idx"):
476-
self._boundary_interval_idx = self._compute_boundary_interval()
477+
self._boundary_interval_Widx = self._compute_boundary_interval()
477478

478479
# Recompute derivative denominators when knots change
479-
self._compute_derivative_denominators()
480+
self._compute_derivative_denominators()
481+
482+
483+
import torch
484+
import torch.nn as nn
485+
486+
class SplineVectorized(nn.Module):
487+
"""
488+
Vectorized univariate B-spline model (shared knots, many splines).
489+
490+
Notation:
491+
- knots: shape (m,)
492+
- order: k (degree = k-1)
493+
- n_ctrl = m - k
494+
- control_points:
495+
* (S, n_ctrl) -> S splines, scalar output each
496+
* (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels)
497+
Input:
498+
- x: shape (...,) or (..., B)
499+
Output:
500+
- if control_points is (S, n_ctrl): shape (..., S)
501+
- if control_points is (S, O, n_ctrl): shape (..., S, O)
502+
"""
503+
504+
def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None):
505+
super().__init__()
506+
if not isinstance(order, int) or order <= 0:
507+
raise ValueError("order must be a positive integer.")
508+
if not isinstance(knots, torch.Tensor):
509+
raise ValueError("knots must be a torch.Tensor.")
510+
if knots.ndim != 1:
511+
raise ValueError("knots must be 1D.")
512+
513+
self.order = order
514+
515+
# store sorted knots as buffer
516+
knots_sorted = knots.sort().values
517+
self.register_buffer("knots", knots_sorted)
518+
519+
n_ctrl = knots_sorted.numel() - order
520+
if n_ctrl <= 0:
521+
raise ValueError(f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}.")
522+
523+
# boundary interval idx for rightmost inclusion
524+
self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted)
525+
526+
# # control points init
527+
# if control_points is None:
528+
# # default: one spline
529+
# cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device)
530+
# self.control_points = nn.Parameter(cp, requires_grad=True)
531+
# else:
532+
# if not isinstance(control_points, torch.Tensor):
533+
# raise ValueError("control_points must be a torch.Tensor or None.")
534+
# if control_points.ndim not in (2, 3):
535+
# raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).")
536+
# if control_points.shape[-1] != n_ctrl:
537+
# raise ValueError(
538+
# f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
539+
# )
540+
self.control_points = nn.Parameter(control_points, requires_grad=True)
541+
542+
@staticmethod
543+
def _compute_boundary_interval_idx(knots: torch.Tensor) -> int:
544+
if knots.numel() < 2:
545+
return 0
546+
diffs = knots[1:] - knots[:-1]
547+
valid = torch.nonzero(diffs > 0, as_tuple=False)
548+
if valid.numel() == 0:
549+
return 0
550+
return int(valid[-1])
551+
552+
def basis(self, x: torch.Tensor) -> torch.Tensor:
553+
"""
554+
Compute B-spline basis functions of order self.order at x.
555+
556+
Returns:
557+
basis: shape (..., n_ctrl)
558+
"""
559+
if not isinstance(x, torch.Tensor):
560+
x = torch.as_tensor(x)
561+
562+
# ensure float dtype consistent
563+
x = x.to(dtype=self.knots.dtype, device=self.knots.device)
564+
565+
# make x shape (..., 1) for broadcasting
566+
x_exp = x.unsqueeze(-1) # (..., 1)
567+
568+
# knots as (1, ..., 1, m) via unsqueeze to broadcast
569+
# (m,) -> (1,)*x.ndim + (m,)
570+
knots = self.knots.view(*([1] * x.ndim), -1)
571+
572+
# order-1 base: indicator on intervals [t_i, t_{i+1})
573+
basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1)
574+
575+
# include rightmost boundary in the last non-degenerate interval
576+
j = self._boundary_interval_idx
577+
knot_left = knots[..., j]
578+
knot_right = knots[..., j + 1]
579+
at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10)
580+
if torch.any(at_right):
581+
basis_j = basis[..., j].bool() | at_right
582+
basis[..., j] = basis_j.to(basis.dtype)
583+
584+
# Cox-de Boor recursion up to order k
585+
# after i-th iteration, basis has length (m-1 - i)
586+
for i in range(1, self.order):
587+
denom1 = knots[..., i:-1] - knots[..., :-(i + 1)]
588+
denom2 = knots[..., i + 1:] - knots[..., 1:-i]
589+
590+
denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1)
591+
denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2)
592+
593+
term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1]
594+
term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:]
595+
basis = term1 + term2
596+
597+
# final basis length is n_ctrl = m - order
598+
return basis # (..., n_ctrl)
599+
600+
def forward(self, x: torch.Tensor) -> torch.Tensor:
601+
"""
602+
Evaluate spline(s) at x.
603+
604+
If control_points is (S, n_ctrl): output (..., S)
605+
If control_points is (S, O, n_ctrl): output (..., S, O)
606+
"""
607+
B = self.basis(x) # (..., n_ctrl)
608+
609+
cp = self.control_points
610+
if cp.ndim == 2:
611+
# (S, n_ctrl)
612+
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
613+
out = B @ cp.transpose(0, 1)
614+
return out
615+
else:
616+
# (S, O, n_ctrl)
617+
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
618+
# vectorized using einsum (yes, this one is actually appropriate)
619+
# (..., n) * (S, O, n) -> (..., S, O)
620+
# out = torch.einsum("...n, son -> ...so", B, cp)
621+
out = torch.einsum("bsc,sco->bso", B, cp)
622+
623+
return out
624+
625+
def forward_basis(self, basis):
626+
"""
627+
Evaluate spline(s) given precomputed basis.
628+
629+
"""
630+
cp = self.control_points
631+
if cp.ndim == 2:
632+
# (S, n_ctrl)
633+
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
634+
out = basis @ cp.transpose(0, 1)
635+
return out
636+
else:
637+
# (S, O, n_ctrl)
638+
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
639+
# vectorized using einsum (yes, this one is actually appropriate)
640+
# (..., n) * (S, O, n) -> (..., S, O)
641+
# out = torch.einsum("...n, son -> ...so", B, cp)
642+
out = torch.einsum("bsc,sco->bso", basis, cp)
643+
644+
return out

tests/test_model/test_spline.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,41 @@ def test_derivative(args, pts):
192192
# Check shape and value
193193
assert first_der.shape == pts.shape
194194
assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4)
195+
196+
197+
@pytest.mark.parametrize("out_dim", [1, 3, 5])
198+
def test_vectorized(out_dim):
199+
200+
N = 7
201+
cps = []
202+
splines = []
203+
for i in range(N):
204+
cp = torch.rand(n_ctrl_pts, 3)
205+
cps.append(cp)
206+
spline = Spline(
207+
order=order,
208+
control_points=cp
209+
)
210+
splines.append(spline)
211+
212+
from pina.model.spline import SplineVectorized as VectorizedSpline
213+
unique_cps = torch.stack(cps, dim=0)
214+
print(unique_cps.shape)
215+
print(cps[0].shape)
216+
# Vectorized control points
217+
vectorized_spline = VectorizedSpline(
218+
order=order,
219+
knots=splines[0].knots,
220+
control_points=torch.stack(cps, dim=0)
221+
)
222+
223+
x = torch.rand(100, 1)
224+
225+
result_single = torch.stack([
226+
splines[i](x) for i in range(N)
227+
])
228+
result_single = result_single.squeeze(2).permute(1, 0, 2)
229+
out_vectorized = vectorized_spline(x)
230+
print(out_vectorized.shape)
231+
print(result_single.shape)
232+
assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5)

0 commit comments

Comments
 (0)