Skip to content

Commit bcf2522

Browse files
committed
minor fix and rebase
1 parent 32acffd commit bcf2522

File tree

6 files changed

+174
-174
lines changed

6 files changed

+174
-174
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import torch
33
import numpy as np
44

5-
from pina.model.spline import Spline
5+
from pina._src.model.spline import Spline
66

77

88
class KANBlock(torch.nn.Module):
99
"""define a KAN layer using splines"""
10-
def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None:
10+
def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True):
1111
"""
1212
Initialize the KAN layer.
1313

pina/model/kolmogorov_arnold_network.py renamed to pina/_src/model/kolmogorov_arnold_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44
from typing import List
55

6-
from pina.model.block import KANBlock
6+
from pina._src.model.block.kan_block import KANBlock
77

88
class KolmogorovArnoldNetwork(torch.nn.Module):
99
"""

pina/_src/model/spline.py

Lines changed: 2 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, order=4, knots=None, control_points=None):
117117
raise ValueError("knots must be one-dimensional.")
118118

119119
# Check dimensionality of control points
120-
if self.control_points.ndim > 2:
120+
if self.control_points.ndim > 1:
121121
raise ValueError("control_points must be one-dimensional.")
122122

123123
# Raise error if #knots != order + #control_points
@@ -279,11 +279,6 @@ def forward(self, x):
279279
"""
280280
basis = self.basis(x.as_subclass(torch.Tensor))
281281
return basis @ self.control_points
282-
return torch.einsum(
283-
"...bi, ...i -> ...b",
284-
basis,
285-
self.control_points,
286-
)
287282

288283
def derivative(self, x, degree):
289284
"""
@@ -474,171 +469,7 @@ def knots(self, value):
474469

475470
# Recompute boundary interval when knots change
476471
if hasattr(self, "_boundary_interval_idx"):
477-
self._boundary_interval_Widx = self._compute_boundary_interval()
472+
self._boundary_interval_idx = self._compute_boundary_interval()
478473

479474
# Recompute derivative denominators when knots change
480475
self._compute_derivative_denominators()
481-
482-
483-
import torch
484-
import torch.nn as nn
485-
486-
class SplineVectorized(nn.Module):
487-
"""
488-
Vectorized univariate B-spline model (shared knots, many splines).
489-
490-
Notation:
491-
- knots: shape (m,)
492-
- order: k (degree = k-1)
493-
- n_ctrl = m - k
494-
- control_points:
495-
* (S, n_ctrl) -> S splines, scalar output each
496-
* (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels)
497-
Input:
498-
- x: shape (...,) or (..., B)
499-
Output:
500-
- if control_points is (S, n_ctrl): shape (..., S)
501-
- if control_points is (S, O, n_ctrl): shape (..., S, O)
502-
"""
503-
504-
def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None):
505-
super().__init__()
506-
if not isinstance(order, int) or order <= 0:
507-
raise ValueError("order must be a positive integer.")
508-
if not isinstance(knots, torch.Tensor):
509-
raise ValueError("knots must be a torch.Tensor.")
510-
if knots.ndim != 1:
511-
raise ValueError("knots must be 1D.")
512-
513-
self.order = order
514-
515-
# store sorted knots as buffer
516-
knots_sorted = knots.sort().values
517-
self.register_buffer("knots", knots_sorted)
518-
519-
n_ctrl = knots_sorted.numel() - order
520-
if n_ctrl <= 0:
521-
raise ValueError(f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}.")
522-
523-
# boundary interval idx for rightmost inclusion
524-
self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted)
525-
526-
# # control points init
527-
# if control_points is None:
528-
# # default: one spline
529-
# cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device)
530-
# self.control_points = nn.Parameter(cp, requires_grad=True)
531-
# else:
532-
# if not isinstance(control_points, torch.Tensor):
533-
# raise ValueError("control_points must be a torch.Tensor or None.")
534-
# if control_points.ndim not in (2, 3):
535-
# raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).")
536-
# if control_points.shape[-1] != n_ctrl:
537-
# raise ValueError(
538-
# f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
539-
# )
540-
self.control_points = nn.Parameter(control_points, requires_grad=True)
541-
542-
@staticmethod
543-
def _compute_boundary_interval_idx(knots: torch.Tensor) -> int:
544-
if knots.numel() < 2:
545-
return 0
546-
diffs = knots[1:] - knots[:-1]
547-
valid = torch.nonzero(diffs > 0, as_tuple=False)
548-
if valid.numel() == 0:
549-
return 0
550-
return int(valid[-1])
551-
552-
def basis(self, x: torch.Tensor) -> torch.Tensor:
553-
"""
554-
Compute B-spline basis functions of order self.order at x.
555-
556-
Returns:
557-
basis: shape (..., n_ctrl)
558-
"""
559-
if not isinstance(x, torch.Tensor):
560-
x = torch.as_tensor(x)
561-
562-
# ensure float dtype consistent
563-
x = x.to(dtype=self.knots.dtype, device=self.knots.device)
564-
565-
# make x shape (..., 1) for broadcasting
566-
x_exp = x.unsqueeze(-1) # (..., 1)
567-
568-
# knots as (1, ..., 1, m) via unsqueeze to broadcast
569-
# (m,) -> (1,)*x.ndim + (m,)
570-
knots = self.knots.view(*([1] * x.ndim), -1)
571-
572-
# order-1 base: indicator on intervals [t_i, t_{i+1})
573-
basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1)
574-
575-
# include rightmost boundary in the last non-degenerate interval
576-
j = self._boundary_interval_idx
577-
knot_left = knots[..., j]
578-
knot_right = knots[..., j + 1]
579-
at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10)
580-
if torch.any(at_right):
581-
basis_j = basis[..., j].bool() | at_right
582-
basis[..., j] = basis_j.to(basis.dtype)
583-
584-
# Cox-de Boor recursion up to order k
585-
# after i-th iteration, basis has length (m-1 - i)
586-
for i in range(1, self.order):
587-
denom1 = knots[..., i:-1] - knots[..., :-(i + 1)]
588-
denom2 = knots[..., i + 1:] - knots[..., 1:-i]
589-
590-
denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1)
591-
denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2)
592-
593-
term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1]
594-
term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:]
595-
basis = term1 + term2
596-
597-
# final basis length is n_ctrl = m - order
598-
return basis # (..., n_ctrl)
599-
600-
def forward(self, x: torch.Tensor) -> torch.Tensor:
601-
"""
602-
Evaluate spline(s) at x.
603-
604-
If control_points is (S, n_ctrl): output (..., S)
605-
If control_points is (S, O, n_ctrl): output (..., S, O)
606-
"""
607-
B = self.basis(x) # (..., n_ctrl)
608-
609-
cp = self.control_points
610-
if cp.ndim == 2:
611-
# (S, n_ctrl)
612-
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
613-
out = B @ cp.transpose(0, 1)
614-
return out
615-
else:
616-
# (S, O, n_ctrl)
617-
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
618-
# vectorized using einsum (yes, this one is actually appropriate)
619-
# (..., n) * (S, O, n) -> (..., S, O)
620-
# out = torch.einsum("...n, son -> ...so", B, cp)
621-
out = torch.einsum("bsc,sco->bso", B, cp)
622-
623-
return out
624-
625-
def forward_basis(self, basis):
626-
"""
627-
Evaluate spline(s) given precomputed basis.
628-
629-
"""
630-
cp = self.control_points
631-
if cp.ndim == 2:
632-
# (S, n_ctrl)
633-
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
634-
out = basis @ cp.transpose(0, 1)
635-
return out
636-
else:
637-
# (S, O, n_ctrl)
638-
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
639-
# vectorized using einsum (yes, this one is actually appropriate)
640-
# (..., n) * (S, O, n) -> (..., S, O)
641-
# out = torch.einsum("...n, son -> ...so", B, cp)
642-
out = torch.einsum("bsc,sco->bso", basis, cp)
643-
644-
return out
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Vectorized univariate B-spline model."""
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
class VectorizedSpline(nn.Module):
7+
"""
8+
Vectorized univariate B-spline model (shared knots, many splines).
9+
10+
Notation:
11+
- knots: shape (m,)
12+
- order: k (degree = k-1)
13+
- n_ctrl = m - k
14+
- control_points:
15+
* (S, n_ctrl) -> S splines, scalar output each
16+
* (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels)
17+
Input:
18+
- x: shape (...,) or (..., B)
19+
Output:
20+
- if control_points is (S, n_ctrl): shape (..., S)
21+
- if control_points is (S, O, n_ctrl): shape (..., S, O)
22+
"""
23+
24+
def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None):
25+
super().__init__()
26+
if not isinstance(order, int) or order <= 0:
27+
raise ValueError("order must be a positive integer.")
28+
if not isinstance(knots, torch.Tensor):
29+
raise ValueError("knots must be a torch.Tensor.")
30+
if knots.ndim != 1:
31+
raise ValueError("knots must be 1D.")
32+
33+
self.order = order
34+
35+
# store sorted knots as buffer
36+
knots_sorted = knots.sort().values
37+
self.register_buffer("knots", knots_sorted)
38+
39+
n_ctrl = knots_sorted.numel() - order
40+
if n_ctrl <= 0:
41+
raise ValueError(f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}.")
42+
43+
# boundary interval idx for rightmost inclusion
44+
self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted)
45+
46+
# # control points init
47+
# if control_points is None:
48+
# # default: one spline
49+
# cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device)
50+
# self.control_points = nn.Parameter(cp, requires_grad=True)
51+
# else:
52+
# if not isinstance(control_points, torch.Tensor):
53+
# raise ValueError("control_points must be a torch.Tensor or None.")
54+
# if control_points.ndim not in (2, 3):
55+
# raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).")
56+
# if control_points.shape[-1] != n_ctrl:
57+
# raise ValueError(
58+
# f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
59+
# )
60+
self.control_points = nn.Parameter(control_points, requires_grad=True)
61+
62+
@staticmethod
63+
def _compute_boundary_interval_idx(knots: torch.Tensor) -> int:
64+
if knots.numel() < 2:
65+
return 0
66+
diffs = knots[1:] - knots[:-1]
67+
valid = torch.nonzero(diffs > 0, as_tuple=False)
68+
if valid.numel() == 0:
69+
return 0
70+
return int(valid[-1])
71+
72+
def basis(self, x: torch.Tensor) -> torch.Tensor:
73+
"""
74+
Compute B-spline basis functions of order self.order at x.
75+
76+
Returns:
77+
basis: shape (..., n_ctrl)
78+
"""
79+
if not isinstance(x, torch.Tensor):
80+
x = torch.as_tensor(x)
81+
82+
# ensure float dtype consistent
83+
x = x.to(dtype=self.knots.dtype, device=self.knots.device)
84+
85+
# make x shape (..., 1) for broadcasting
86+
x_exp = x.unsqueeze(-1) # (..., 1)
87+
88+
# knots as (1, ..., 1, m) via unsqueeze to broadcast
89+
# (m,) -> (1,)*x.ndim + (m,)
90+
knots = self.knots.view(*([1] * x.ndim), -1)
91+
92+
# order-1 base: indicator on intervals [t_i, t_{i+1})
93+
basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1)
94+
95+
# include rightmost boundary in the last non-degenerate interval
96+
j = self._boundary_interval_idx
97+
knot_left = knots[..., j]
98+
knot_right = knots[..., j + 1]
99+
at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10)
100+
if torch.any(at_right):
101+
basis_j = basis[..., j].bool() | at_right
102+
basis[..., j] = basis_j.to(basis.dtype)
103+
104+
# Cox-de Boor recursion up to order k
105+
# after i-th iteration, basis has length (m-1 - i)
106+
for i in range(1, self.order):
107+
denom1 = knots[..., i:-1] - knots[..., :-(i + 1)]
108+
denom2 = knots[..., i + 1:] - knots[..., 1:-i]
109+
110+
denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1)
111+
denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2)
112+
113+
term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1]
114+
term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:]
115+
basis = term1 + term2
116+
117+
# final basis length is n_ctrl = m - order
118+
return basis # (..., n_ctrl)
119+
120+
def forward(self, x: torch.Tensor) -> torch.Tensor:
121+
"""
122+
Evaluate spline(s) at x.
123+
124+
If control_points is (S, n_ctrl): output (..., S)
125+
If control_points is (S, O, n_ctrl): output (..., S, O)
126+
"""
127+
B = self.basis(x) # (..., n_ctrl)
128+
129+
cp = self.control_points
130+
if cp.ndim == 2:
131+
# (S, n_ctrl)
132+
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
133+
out = B @ cp.transpose(0, 1)
134+
return out
135+
else:
136+
# (S, O, n_ctrl)
137+
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
138+
# vectorized using einsum (yes, this one is actually appropriate)
139+
# (..., n) * (S, O, n) -> (..., S, O)
140+
# out = torch.einsum("...n, son -> ...so", B, cp)
141+
out = torch.einsum("bsc,sco->bso", B, cp)
142+
143+
return out
144+
145+
def forward_basis(self, basis):
146+
"""
147+
Evaluate spline(s) given precomputed basis.
148+
149+
"""
150+
cp = self.control_points
151+
if cp.ndim == 2:
152+
# (S, n_ctrl)
153+
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
154+
out = basis @ cp.transpose(0, 1)
155+
return out
156+
else:
157+
# (S, O, n_ctrl)
158+
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
159+
# vectorized using einsum (yes, this one is actually appropriate)
160+
# (..., n) * (S, O, n) -> (..., S, O)
161+
# out = torch.einsum("...n, son -> ...so", B, cp)
162+
out = torch.einsum("bsc,sco->bso", basis, cp)
163+
164+
return out

pina/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"EquivariantGraphNeuralOperator",
1818
"SINDy",
1919
"SplineSurface",
20+
"VectorizedSpline",
21+
"KolmogorovArnoldNetwork",
2022
]
2123

2224
from pina._src.model.feed_forward import FeedForward, ResidualFeedForward
@@ -34,3 +36,5 @@
3436
EquivariantGraphNeuralOperator,
3537
)
3638
from pina._src.model.sindy import SINDy
39+
from pina._src.model.vectorized_spline import VectorizedSpline
40+
from pina._src.model.kolmogorov_arnold_network import KolmogorovArnoldNetwork

0 commit comments

Comments
 (0)