Skip to content

Commit b1e4706

Browse files
committed
update
1 parent 30d05d8 commit b1e4706

4 files changed

Lines changed: 417 additions & 374 deletions

File tree

SuPyMode/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@
3131
__version__ = version = "2.2.1"
3232
__version_tuple__ = version_tuple = (2, 2, 1)
3333

34-
__commit_id__ = commit_id = "gefeac8807"
34+
__commit_id__ = commit_id = "g30d05d89b"

SuPyMode/cpp/taper/interpolator.h

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ class Interpolator1D {
1111
public:
1212
Interpolator1D() = default;
1313

14-
Interpolator1D(std::vector<double> x_values,
15-
std::vector<double> y_values,
16-
bool bounds_error,
17-
double fill_value)
14+
Interpolator1D(
15+
std::vector<double> x_values,
16+
std::vector<double> y_values,
17+
bool bounds_error,
18+
double fill_value
19+
)
1820
: x_(std::move(x_values)),
1921
y_(std::move(y_values)),
2022
bounds_error_(bounds_error),
@@ -70,3 +72,70 @@ class Interpolator1D {
7072
bool bounds_error_ = false;
7173
double fill_value_ = 0.0;
7274
};
75+
76+
77+
// Linear interpolation similar to scipy interp1d(kind="linear") with
78+
// bounds_error configurable and fill_value outside bounds.
79+
class LinearInterpolator {
80+
public:
81+
LinearInterpolator() = default;
82+
83+
LinearInterpolator(
84+
std::vector<double> x_values,
85+
std::vector<double> y_values,
86+
bool bounds_error,
87+
double fill_value)
88+
: x_(std::move(x_values)),
89+
y_(std::move(y_values)),
90+
bounds_error_(bounds_error),
91+
fill_value_(fill_value)
92+
{
93+
if (x_.size() != y_.size()) throw std::invalid_argument("Interpolator: x and y size mismatch.");
94+
if (x_.size() < 2) throw std::invalid_argument("Interpolator: need at least 2 points.");
95+
if (!std::is_sorted(x_.begin(), x_.end())) {
96+
throw std::invalid_argument("Interpolator: x must be sorted ascending.");
97+
}
98+
}
99+
100+
double operator()(double x_query) const {
101+
const double x_min = x_.front();
102+
const double x_max = x_.back();
103+
104+
if (x_query < x_min || x_query > x_max) {
105+
if (bounds_error_) throw std::out_of_range("Interpolator: query out of bounds.");
106+
return fill_value_;
107+
}
108+
109+
// Handle exact endpoints quickly
110+
if (x_query == x_min) return y_.front();
111+
if (x_query == x_max) return y_.back();
112+
113+
// Find rightmost index such that x_[index] <= x_query
114+
auto it = std::upper_bound(x_.begin(), x_.end(), x_query);
115+
const std::size_t right_index = static_cast<std::size_t>(std::distance(x_.begin(), it));
116+
const std::size_t left_index = right_index - 1;
117+
118+
const double x0 = x_[left_index];
119+
const double x1 = x_[right_index];
120+
const double y0 = y_[left_index];
121+
const double y1 = y_[right_index];
122+
123+
const double dx = x1 - x0;
124+
if (dx == 0.0) return y0;
125+
126+
const double t = (x_query - x0) / dx;
127+
return y0 + t * (y1 - y0);
128+
}
129+
130+
std::vector<double> operator()(const std::vector<double>& x_queries) const {
131+
std::vector<double> y_queries(x_queries.size());
132+
for (std::size_t i = 0; i < x_queries.size(); ++i) y_queries[i] = (*this)(x_queries[i]);
133+
return y_queries;
134+
}
135+
136+
private:
137+
std::vector<double> x_;
138+
std::vector<double> y_;
139+
bool bounds_error_ = false;
140+
double fill_value_ = 0.0;
141+
};

0 commit comments

Comments
 (0)