Skip to content

Commit 0064d3e

Browse files
committed
Add high precision for the four operations, exp and log.
1 parent 4b92e80 commit 0064d3e

2 files changed

Lines changed: 902 additions & 0 deletions

File tree

src/shaders/embedded_shaders.h

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,382 @@
11
#pragma once
22
#include <string>
33

4+
inline std::string SRC_HIGH_PRECISION_FRAG = R"(#version 300 es
5+
6+
// The limb arithmetic below relies on full 32-bit unsigned integers. In an
// ESSL 3.00 fragment shader the default integer precision is mediump, which
// may be narrower than 32 bits, so request highp explicitly.
precision highp int;

// Number of 32-bit limbs stored in one number.
const int LIMB_SIZE = 32;
// Index of the limb whose unit has weight 1 (number_one sets exactly this
// limb); limbs below this index hold the fractional part.
const int FRACTIONAL_SIZE = LIMB_SIZE/2;
8+
9+
// Sign-magnitude fixed-point value built from 32-bit limbs.
struct number {
    // Magnitude, least significant limb first (limb[0] is the lowest).
    uint limb[LIMB_SIZE];
    // Sign multiplier: 1 or -1; mult stores 0 here when its product is zero.
    int sign;
    // Marks the saturated value produced by infinite_number.
    bool is_infinite;
};
14+
15+
// Pair of high-precision numbers (x and y components).
struct hp_vec2 {
    number x;
    number y;
};
19+
20+
// Returns zero: every limb cleared, sign +1, not infinite.
number null_number(){
    number res;
    for(int i = 0; i < LIMB_SIZE; ++i){
        res.limb[i] = 0u;
    }
    res.sign = 1;
    // Members that are never written hold undefined values in GLSL, and this
    // flag is read by mult and div, so it has to be cleared explicitly.
    res.is_infinite = false;
    return res;
}
28+
29+
// Returns the saturated value used to represent infinity: every limb is all
// ones, the sign is +1 and is_infinite is set.
number infinite_number(){
    number res;
    for(int i = 0; i < LIMB_SIZE; ++i){
        // Written as a literal: shifting a 32-bit value by 32 bits is
        // undefined in GLSL, so (1u << 32u) - 1u cannot be relied on.
        res.limb[i] = 0xFFFFFFFFu;
    }
    res.sign = 1;
    res.is_infinite = true;
    return res;
}
38+
39+
// Returns the fixed-point constant 1: zero with only the limb at index
// FRACTIONAL_SIZE set to 1.
number number_one() {
    number one = null_number();
    one.limb[FRACTIONAL_SIZE] = 1u;
    return one;
}
44+
45+
// Returns a copy of a with its sign flipped; the limbs are left untouched.
number neg(number a){
    a.sign = -a.sign;
    return a;
}
49+
50+
// Upper 16 bits of a, moved down into the low half of the result.
uint hi(uint a){
    uint upper = a >> 16u;
    return upper;
}
53+
54+
// Lower 16 bits of a.
uint lo(uint a){
    const uint LOW_HALF_MASK = 0xFFFFu;
    return a & LOW_HALF_MASK;
}
57+
58+
59+
// Compares the magnitudes of a and b, ignoring their signs.
// Returns 1 if |a| > |b|, -1 if |a| < |b| and 0 if the limbs are identical.
int compare_abs(number a, number b) {
    // Walk from the most significant limb down; the first difference decides.
    int idx = LIMB_SIZE;
    while (idx > 0) {
        --idx;
        uint lhs = a.limb[idx];
        uint rhs = b.limb[idx];
        if (lhs != rhs) {
            return (lhs > rhs) ? 1 : -1;
        }
    }
    return 0;
}
66+
67+
// True when every limb of a is zero; sign and is_infinite are not inspected.
bool is_zero(number a){
    for (int k = 0; k < LIMB_SIZE; ++k) {
        if (a.limb[k] != 0u) {
            return false;
        }
    }
    return true;
}
74+
75+
// Adds two 32-bit values.
// Returns (sum modulo 2^32, carry), where carry is 1 if the addition wrapped.
uvec2 add_with_carry(uint a, uint b){
    uint sum = a + b;
    // A wrapped sum is smaller than its operands.
    bool wrapped = (sum < a) || (sum < b);
    return uvec2(sum, wrapped ? 1u : 0u);
}
81+
// Adds the magnitudes of a and b limb by limb, ignoring their signs.
// The result starts from null_number, so its sign is +1; callers set the
// sign they need. A carry out of the top limb is discarded.
number abs_add(number a, number b){
    number c = null_number();
    uint carry = 0u;
    for(int i = 0; i < LIMB_SIZE; ++i){
        uvec2 limb_sum = add_with_carry(a.limb[i], b.limb[i]);
        uvec2 with_carry = add_with_carry(limb_sum.x, carry);
        c.limb[i] = with_carry.x;
        // Both additions can overflow (never both at once), so the outgoing
        // carry has to combine the two flags; keeping only the second one
        // loses the carry of a.limb[i] + b.limb[i].
        carry = limb_sum.y + with_carry.y;
    }
    return c;
}
92+
// Subtracts the magnitude of b from the magnitude of a limb by limb with
// borrow propagation. Signs are ignored and the result has sign +1 (from
// null_number). The callers visible in add only pass |a| >= |b|.
number abs_sub(number a, number b) {
    number diff = null_number();
    uint borrow_in = 0u;
    for (int k = 0; k < LIMB_SIZE; ++k) {
        uint minuend = a.limb[k];
        uint subtrahend = b.limb[k];
        diff.limb[k] = minuend - subtrahend - borrow_in;
        // A borrow is needed when the limb is too small on its own, or when
        // the limbs are equal and a borrow is already pending.
        bool borrow_out = (minuend < subtrahend)
            || (minuend == subtrahend && borrow_in > 0u);
        borrow_in = borrow_out ? 1u : 0u;
    }
    return diff;
}
102+
103+
// Signed addition of two sign-magnitude numbers.
// Equal signs: magnitudes are added and the common sign kept.
// Different signs: the smaller magnitude is subtracted from the larger one
// and the result takes the sign of the larger operand (of a on a tie).
number add(number a, number b){
    number result;
    if (a.sign == b.sign) {
        result = abs_add(a, b);
        result.sign = a.sign;
    } else if (compare_abs(a, b) >= 0) {
        result = abs_sub(a, b);
        result.sign = a.sign;
    } else {
        result = abs_sub(b, a);
        result.sign = b.sign;
    }
    return result;
}
120+
121+
122+
123+
// Signed subtraction a - b, computed as a + (-b).
number sub(number a, number b){
    number negated = neg(b);
    return add(a, negated);
}
126+
127+
// Full 32 x 32 -> 64 bit unsigned multiplication.
// Returns (low 32 bits, high 32 bits) of a * b. The high word is assembled
// from the four 16 x 16 partial products of the operand halves.
uvec2 multiply_with_remainder(uint a, uint b) {
    uint a_low = lo(a);
    uint a_high = hi(a);
    uint b_low = lo(b);
    uint b_high = hi(b);

    uint low_low = a_low * b_low;
    uint low_high = a_low * b_high;
    uint high_low = a_high * b_low;
    uint high_high = a_high * b_high;

    // Middle column: both cross products plus the upper half of low_low.
    uvec2 cross = add_with_carry(low_high, high_low);
    uvec2 middle = add_with_carry(cross.x, hi(low_low));

    // Each overflow of the middle column is worth 2^32 there, i.e. 2^16 in
    // the high word.
    uint overflow_count = cross.y + middle.y;
    uint upper = high_high + hi(middle.x) + (overflow_count << 16);

    // The low word is simply the wrapped 32-bit product.
    return uvec2(a * b, upper);
}
147+
148+
// Signed fixed-point multiplication.
// Every pair of limbs is multiplied and accumulated at limb index
// i + j - FRACTIONAL_SIZE, which rescales the product back to the fixed-point
// format. Partial products that fall below limb 0 are not stored; only the
// carry they push upwards within the current row is kept. Anything that would
// land at or above LIMB_SIZE is dropped (overflow).
// Returns: a zero-limbed number if either operand is zero; an infinite number
// carrying the product sign if either operand is infinite; otherwise the
// product, with sign forced to 0 when all result limbs are zero.
number mult(number a, number b) {
    number c = null_number();
    c.sign = a.sign * b.sign;
    c.is_infinite = false;

    if (is_zero(a) || is_zero(b)) return c;
    if (a.is_infinite || b.is_infinite) {
        // Take the limbs and the is_infinite flag from infinite_number so the
        // result is recognised as infinite by later operations.
        number inf = infinite_number();
        inf.sign = c.sign;
        return inf;
    }
    for (int i = 0; i < LIMB_SIZE; ++i) {
        if (a.limb[i] == 0u) continue;
        uint carry = 0u;

        for (int j = 0; j < LIMB_SIZE; ++j) {
            int target = i + j - FRACTIONAL_SIZE;

            if (target >= LIMB_SIZE) break;

            uvec2 prod = multiply_with_remainder(a.limb[i], b.limb[j]);
            if (target < 0) {
                // Below the stored range: keep only the carry.
                uvec2 add2 = add_with_carry(prod.x, carry);
                carry = prod.y + add2.y;
            } else {
                uvec2 add1 = add_with_carry(c.limb[target], prod.x);
                uvec2 add2 = add_with_carry(add1.x, carry);
                c.limb[target] = add2.x;
                carry = prod.y + add1.y + add2.y;
            }
        }

        // After the last partial product of this row the pending carry
        // belongs to the next limb up. When that limb exists it must be added
        // (and rippled further while it keeps overflowing); otherwise the high
        // word of the top partial product is silently lost. When the inner
        // loop stopped early the start index is already out of range, so the
        // overflowed carry is correctly ignored.
        for (int t = i + LIMB_SIZE - FRACTIONAL_SIZE; t < LIMB_SIZE && carry != 0u; ++t) {
            uvec2 ripple = add_with_carry(c.limb[t], carry);
            c.limb[t] = ripple.x;
            carry = ripple.y;
        }
    }
    // A product that rounded down to zero gets sign 0.
    c.sign = c.sign * int(!is_zero(c));
    return c;
}
181+
182+
// Shifts the magnitude of a left by shift bits (towards higher limbs).
// Bits moved past the top limb are dropped. A non-positive shift returns a
// unchanged; a shift of LIMB_SIZE * 32 bits or more returns null_number().
// Otherwise the sign of a is kept.
number shift_left(number a, int shift) {
    if (shift <= 0) return a;

    int whole_limbs = shift / 32;
    int bits = shift % 32;
    if (whole_limbs >= LIMB_SIZE) return null_number();

    number shifted = null_number();
    shifted.sign = a.sign;

    for (int src = 0; src + whole_limbs < LIMB_SIZE; ++src) {
        uint word = a.limb[src] << bits;
        // Pull in the bits that spill over from the limb below. Skipped for
        // bits == 0, where the complementary shift would be by 32.
        if (bits > 0 && src > 0) {
            word |= a.limb[src - 1] >> (32 - bits);
        }
        shifted.limb[src + whole_limbs] = word;
    }

    return shifted;
}
208+
209+
// Shifts the magnitude of a right by shift bits (towards lower limbs).
// Bits moved below limb 0 are dropped. A non-positive shift returns a
// unchanged; a shift of LIMB_SIZE * 32 bits or more returns null_number().
// Otherwise the sign of a is kept.
number shift_right(number a, int shift) {
    if (shift <= 0) return a;

    int whole_limbs = shift / 32;
    int bits = shift % 32;
    if (whole_limbs >= LIMB_SIZE) return null_number();

    number shifted = null_number();
    shifted.sign = a.sign;

    for (int src = whole_limbs; src < LIMB_SIZE; ++src) {
        uint word = a.limb[src] >> bits;
        // Pull in the bits that spill over from the limb above. Skipped for
        // bits == 0, where the complementary shift would be by 32.
        if (bits > 0 && src < LIMB_SIZE - 1) {
            word |= a.limb[src + 1] << (32 - bits);
        }
        shifted.limb[src - whole_limbs] = word;
    }
    return shifted;
}
232+
233+
// Returns the bit index of the highest set bit in the magnitude of a, counted
// from bit 0 of limb[0], or -1 when every limb is zero.
int find_msb(number a) {
    for (int idx = LIMB_SIZE - 1; idx >= 0; --idx) {
        uint word = a.limb[idx];
        if (word == 0u) continue;

        // Binary search inside the word: halve the window each step and keep
        // the upper part whenever it is non-empty.
        int bit = 0;
        for (int width = 16; width > 0; width >>= 1) {
            if ((word >> width) != 0u) {
                bit += width;
                word >>= width;
            }
        }
        return idx * 32 + bit;
    }
    return -1;
}
248+
249+
// Reads the 16-bit half-limb at position index, where half 0 is the low half
// of limb[0], half 1 the high half of limb[0], half 2 the low half of limb[1]
// and so on. This is the same layout set_half writes and the one implied by
// find_msb(...) >> 4 in div.
// Indices outside [0, 2 * LIMB_SIZE) read as 0, mirroring the range check in
// set_half; div asks for one half past its operand length.
uint get_half(number a, int index){
    if (index < 0 || index >= LIMB_SIZE * 2) return 0u;
    uint l = a.limb[index / 2];
    return (index % 2 == 1) ? hi(l) : lo(l);
}
253+
254+
// Writes the 16-bit half-limb at position index of a: even indices address
// the low half of limb[index / 2], odd indices the high half. The other half
// of the limb is preserved. Indices outside [0, 2 * LIMB_SIZE) are ignored.
void set_half(inout number a, int index, uint val) {
    bool in_range = (index >= 0) && (index < LIMB_SIZE * 2);
    if (!in_range) return;

    int limb_idx = index / 2;
    uint previous = a.limb[limb_idx];
    uint updated;
    if (index % 2 == 0) {
        updated = (previous & 0xFFFF0000u) | (val & 0xFFFFu);
    } else {
        updated = (previous & 0x0000FFFFu) | (val << 16);
    }
    a.limb[limb_idx] = updated;
}
263+
264+
// Multiplies the magnitude of a by the unsigned scalar b_16, limb by limb,
// keeping the sign of a. No fixed-point rescaling is applied, and a carry out
// of the top limb is discarded.
number mult_scalar_16(number a, uint b_16) {
    number scaled = null_number();
    scaled.sign = a.sign;
    uint carry = 0u;

    for (int k = 0; k < LIMB_SIZE; ++k) {
        // A zero limb with no pending carry contributes nothing.
        bool nothing_to_do = (a.limb[k] == 0u) && (carry == 0u);
        if (!nothing_to_do) {
            uvec2 wide = multiply_with_remainder(a.limb[k], b_16);
            uvec2 low = add_with_carry(wide.x, carry);

            scaled.limb[k] = low.x;
            carry = wide.y + low.y;
        }
    }
    return scaled;
}
280+
281+
// Signed fixed-point division n / d.
// The numerator is pre-scaled by 2^(32 * FRACTIONAL_SIZE) and then divided by
// long division on 16-bit half-limbs (estimate a quotient digit from the top
// halves, multiply-subtract, add back on overshoot).
// Returns: an infinite number carrying the quotient sign when d is zero; a
// zero-limbed number when d is infinite, when the scaled numerator is zero or
// when it is shorter than the divisor; otherwise the quotient.
// The remainder is discarded.
number div(number n, number d){
    number q = null_number();
    q.sign = n.sign * d.sign;
    q.is_infinite = false;

    // Division by zero saturates to infinity; keep the sign computed above
    // instead of the +1 that infinite_number returns.
    if(is_zero(d)){
        number inf = infinite_number();
        inf.sign = q.sign;
        return inf;
    }
    if(d.is_infinite) return q;

    // Pre-scale so the integer quotient lands in fixed-point position.
    // NOTE(review): shift_left drops whatever moves past the top limb, so the
    // upper FRACTIONAL_SIZE limbs of n are discarded here - confirm that this
    // range limit is intended.
    n = shift_left(n, FRACTIONAL_SIZE * 32);

    int msb_d = find_msb(d);
    if(msb_d == -1) return q;

    int msb_n = find_msb(n);
    if(msb_n == -1) return q;

    // Operand lengths measured in 16-bit halves.
    int len_v = (msb_d >> 4) + 1;
    int len_u = (msb_n >> 4) + 1;

    if (len_u < len_v) return q;

    // Single-half divisor: plain short division, one half at a time.
    if (len_v == 1) {
        uint v0 = get_half(d, 0);
        uint rem = 0u;
        for (int i = len_u - 1; i >= 0; --i) {
            uint dividend = (rem << 16) | get_half(n, i);
            uint q_i = dividend / v0;
            rem = dividend % v0;
            set_half(q, i, q_i);
        }
        return q;
    }

    // Normalise: shift both operands so the top half of the divisor has its
    // highest bit set, which keeps the quotient digit estimate close.
    int shift = 15 - (msb_d % 16);
    number u = shift_left(n, shift);
    number v = shift_left(d, shift);

    int m = len_u - len_v;
    int n_len = len_v;

    uint v_n1 = get_half(v, n_len - 1);
    uint v_n2 = get_half(v, n_len - 2);

    for (int j = m; j >= 0; --j) {
        uint u_jn = get_half(u, j + n_len);
        uint u_jn1 = get_half(u, j + n_len - 1);
        uint u_jn2 = get_half(u, j + n_len - 2);

        uint dividend = (u_jn << 16) | u_jn1;
        uint q_hat, r_hat;

        // Estimate the quotient digit from the top two halves.
        if (u_jn == v_n1) {
            q_hat = 0xFFFFu;
            r_hat = u_jn1 + v_n1;
        } else {
            q_hat = dividend / v_n1;
            r_hat = dividend % v_n1;
        }

        // Refine the estimate using the next half of the divisor.
        while (r_hat < 0x10000u && (q_hat * v_n2) > ((r_hat << 16) | u_jn2)) {
            q_hat--;
            r_hat += v_n1;
        }

        // Multiply-subtract: u[j .. j+n_len] -= q_hat * v.
        uint k = 0u;
        uint borrow = 0u;
        for (int i = 0; i < n_len; ++i) {
            uint p = q_hat * get_half(v, i) + k;
            k = p >> 16;
            uint p_lo = p & 0xFFFFu;

            uint u_ji = get_half(u, j + i);
            int diff = int(u_ji) - int(p_lo) - int(borrow);

            set_half(u, j + i, uint(diff) & 0xFFFFu);
            borrow = (diff < 0) ? 1u : 0u;
        }
        int final_diff = int(get_half(u, j + n_len)) - int(k) - int(borrow);
        set_half(u, j + n_len, uint(final_diff) & 0xFFFFu);
        // The estimate was one too large: add the divisor back once.
        if (final_diff < 0) {
            q_hat--;
            uint carry_add = 0u;
            for (int i = 0; i < n_len; ++i) {
                uint sum = get_half(u, j + i) + get_half(v, i) + carry_add;
                set_half(u, j + i, sum & 0xFFFFu);
                carry_add = sum >> 16;
            }
            uint sum_last = get_half(u, j + n_len) + carry_add;
            set_half(u, j + n_len, sum_last & 0xFFFFu);
        }
        set_half(q, j, q_hat);
    }
    return q;
}
376+
377+
378+
)";
379+
4380
inline std::string SRC_PICKER_FRAG = R"(#version 300 es
5381
precision highp float;
6382
precision highp int;

0 commit comments

Comments
 (0)