1+ // https://oj.socoding.cn/p/1778
2+ #include < bits/stdc++.h>
3+ using namespace std ;
4+ using ll = long long ;
5+ using pii = pair<int , int >;
6+ using pll = pair<ll, ll>;
7+ const ll mod = 998244353 ;
8+
9+ ll qui (ll a, ll x)
10+ {
11+ ll ret = 1 ;
12+ while (x)
13+ {
14+ if (x & 1 )
15+ ret = ret * a % mod;
16+ a = a * a % mod;
17+ x >>= 1 ;
18+ }
19+ return ret;
20+ }
21+
22+ using Poly = vector<ll>;
23+ const int BIT = 20 ;
24+ int p[1 << BIT];
25+ const ll maxn = 1e5 + 10 ;
26+ ll fac[maxn], inv[maxn];
27+
28+ Poly operator *(const Poly &a, const Poly &b)
29+ {
30+ int n = a.size () - 1 , m = b.size () - 1 ;
31+ int L, l = 0 ;
32+ for (L = 1 ; L <= n + m; l++, L = L << 1 )
33+ ;
34+
35+ vector<int > p (L);
36+
37+ for (int i = 1 ; i < L; i++)
38+ p[i] = ((p[i >> 1 ] >> 1 ) | ((i & 1 ) << (l - 1 )));
39+ auto u = a, v = b;
40+ u.resize (L, 0 ), v.resize (L, 0 );
41+ auto ntt = [&L, &l, &p](Poly &g, int type)
42+ {
43+ for (int i = 0 ; i < L; i++)
44+ if (i < p[i])
45+ swap (g[i], g[p[i]]);
46+ for (int i = 1 ; i < L; (i <<= 1 ))
47+ {
48+ ll wn = qui (3 , (mod - 1 ) / (i << 1 ));
49+ for (int j = 0 ; j < L; j += (i << 1 ))
50+ {
51+ ll w = 1 ;
52+ for (int k = j; k < j + i; w = w * wn % mod, k++)
53+ {
54+ assert (k + i < L);
55+ assert (k < L);
56+ ll t = g[k + i] * w % mod;
57+ g[k + i] = (g[k] - t + mod) % mod;
58+ g[k] = (g[k] + t) % mod;
59+ }
60+ }
61+ }
62+ if (type == 1 )
63+ return ;
64+ reverse (g.begin () + 1 , g.begin () + L);
65+ ll ni = qui (L, mod - 2 );
66+ for (int i = 0 ; i < L; i++)
67+ g[i] = g[i] * ni % mod;
68+ };
69+ ntt (u, 1 ), ntt (v, 1 );
70+
71+ Poly g (L, 0 );
72+ for (int i = 0 ; i < L; i++)
73+ g[i] = u[i] * v[i] % mod;
74+ ntt (g, -1 );
75+
76+ return g;
77+ }
78+
79+ signed main ()
80+ {
81+ cin.tie (0 )->sync_with_stdio (false );
82+ int n, k;
83+ cin >> n >> k;
84+
85+ vector<int > b (n + 1 );
86+ for (int i = 1 ; i <= n; i++)
87+ cin >> b[i];
88+
89+ function<Poly (int , int )> calc = [&](int l, int r)
90+ {
91+ if (l >= r)
92+ {
93+ assert (l == r);
94+ return Poly{1 , b[l]};
95+ }
96+ int mid = (l + r) / 2 ;
97+ return calc (l, mid) * calc (mid + 1 , r);
98+ };
99+
100+ Poly ep = calc (0 , n);
101+ cout << ep[k];
102+ return 0 ;
103+ }
0 commit comments