Skip to content

Commit 401dfa5

Browse files
committed
Made lagrange a class.
1 parent f0014a5 commit 401dfa5

1 file changed

Lines changed: 63 additions & 48 deletions

File tree

src/alfred/math/lagrange.hpp

Lines changed: 63 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,78 @@
22
#define AFMT_LAGRANGE
33

44
#include "comb.hpp"
5+
#include "vec-inv.hpp"
56
#include <cassert>
67
#include <iostream>
78
#include <vector>
89

910
// TODO: write lagrange as a class, supporting: O(n^2) init, O(n) query
1011

1112
template <class mint>
12-
inline mint lagrange(std::vector<mint> x, std::vector<mint> y, mint k) {
13-
mint ans = 0, cur;
14-
const int n = x.size();
15-
for (int i = 0; i < n; i++) {
16-
cur = y[i];
17-
for (int j = 0; j < n; j++) {
18-
if (j == i) continue;
19-
cur *= (k - x[j]) / (x[i] - x[j]);
13+
class Lagrange {
14+
private:
15+
std::vector<mint> x, y, b;
16+
17+
public:
18+
Lagrange(void) = default;
19+
Lagrange(std::vector<mint> _x, std::vector<mint> _y) : x(_x), y(_y) {
20+
for (size_t i = 0; i < x.size(); i++) insert(x[i], y[i]);
21+
}
22+
inline void insert(mint x0, mint y0) {
23+
b.push_back(y0);
24+
std::vector<mint> tmp(x.size());
25+
for (size_t i = 0; i < x.size(); i++) {
26+
tmp[i] = x0 - x[i];
27+
}
28+
VecInv<mint> inv(tmp);
29+
for (size_t i = 0; i < x.size(); i++) {
30+
b.back() *= inv[i], b[i] *= -inv[i];
2031
}
21-
ans += cur;
32+
x.push_back(x0), y.push_back(y0);
2233
}
23-
return ans;
24-
}
34+
inline mint query(mint k) {
35+
mint ans = 0, tot = 1;
36+
const int n = x.size();
37+
std::vector<mint> tmp(n);
38+
for (int i = 0; i < n; i++) {
39+
if (x[i] == k) return y[i];
40+
tmp[i] = k - x[i], tot *= tmp[i];
41+
}
42+
VecInv<mint> inv(tmp);
43+
for (int i = 0; i < n; i++) {
44+
ans += b[i] * tot * inv[i];
45+
}
46+
return ans;
47+
}
48+
std::vector<mint> coefficient(void) { // now O(n^2), TODO: maintain it dynamically.
49+
int n = x.size(), i;
50+
// F(k) = \prod (k - x_i): degree = n, n + 1 coefficients.
51+
std::vector<mint> F(n + 1);
52+
for (i = 0, F[0] = 1; i < n; i++) {
53+
for (int j = i + 1; j >= 0; j--) {
54+
F[j] *= -x[i];
55+
if (j) F[j] += F[j - 1];
56+
}
57+
}
58+
mint delta, c;
59+
std::vector<mint> ans(n), res(n);
60+
auto div = [&](mint xi) {
61+
delta = 0;
62+
for (int i = n; i > 0; i--) {
63+
res[i - 1] = F[i] + delta;
64+
delta = (F[i] + delta) * xi;
65+
}
66+
};
67+
for (int i = 0; i < n; i++) {
68+
c = b[i], div(x[i]);
69+
for (int j = 0; j < n; j++) {
70+
ans[j] += c * res[j];
71+
}
72+
}
73+
std::reverse(ans.begin(), ans.end());
74+
return ans;
75+
}
76+
};
2577

2678
// y[0] is placeholder.
2779
// If for all integer x_i in [1, n], we have f(x_i) = y_i (mod p), find f(k) mod p.
@@ -52,41 +104,4 @@ inline mint sum_of_kth_powers(mint n, int k) {
52104
return cont_lagrange(Y, n);
53105
}
54106

55-
template <class mint>
56-
std::vector<mint> find_coefficient(
57-
std::vector<mint> x, std::vector<mint> y
58-
) {
59-
// F(k) = \prod (k - x_i): n degree, n + 1 coefficients.
60-
int n = x.size(), i;
61-
std::vector<mint> F(n + 1);
62-
assert(n == (int)y.size());
63-
for (i = 0, F[0] = 1; i < n; i++) {
64-
for (int j = i + 1; j >= 0; j--) {
65-
F[j] *= -x[i];
66-
if (j) F[j] += F[j - 1];
67-
}
68-
}
69-
mint delta, c;
70-
std::vector<mint> ans(n), res(n);
71-
auto div = [&](mint xi) {
72-
delta = 0;
73-
for (int i = n; i > 0; i--) {
74-
res[i - 1] = F[i] + delta;
75-
delta = (F[i] + delta) * xi;
76-
}
77-
return res;
78-
};
79-
for (int i = 0; i < n; i++) {
80-
c = y[i];
81-
for (int j = 0; j < n; j++) {
82-
if (i != j) c /= x[i] - x[j];
83-
}
84-
div(x[i]);
85-
for (int j = 0; j < n; j++) {
86-
ans[j] += c * res[j];
87-
}
88-
}
89-
return ans;
90-
}
91-
92107
#endif // AFMT_LAGRANGE

0 commit comments

Comments
 (0)