Skip to content

Commit f0985b3

Browse files
authored
Increase VpMult batch size (#511)
Number of divmod operation which is the bottleneck depends on batch size. batch_size*BASE**2 doesn't need to be less than 2**63-1, it needs to be less than 2**64. We can increase batch size to 16.
1 parent 32fb1de commit f0985b3

2 files changed

Lines changed: 16 additions & 13 deletions

File tree

ext/bigdecimal/bigdecimal.c

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535

3636
#define BIGDECIMAL_VERSION "4.1.0"
3737

38-
#define NTT_MULTIPLICATION_THRESHOLD 350
38+
/* Make sure VPMULT_BATCH_SIZE*BASE*BASE does not overflow DECDIG_DBL */
39+
#define VPMULT_BATCH_SIZE 16
40+
#define NTT_MULTIPLICATION_THRESHOLD 450
3941
#define NEWTON_RAPHSON_DIVISION_THRESHOLD 100
4042
#define SIGNED_VALUE_MAX INTPTR_MAX
4143
#define SIGNED_VALUE_MIN INTPTR_MIN
@@ -4842,7 +4844,7 @@ VP_EXPORT size_t
48424844
VpMult(Real *c, Real *a, Real *b)
48434845
{
48444846
ssize_t a_batch_max, b_batch_max;
4845-
DECDIG_DBL batch[15];
4847+
DECDIG_DBL batch[VPMULT_BATCH_SIZE * 2 - 1];
48464848

48474849
if (!VpIsDefOP(c, a, b, OP_SW_MULT)) return 0; /* No significant digit */
48484850

@@ -4882,27 +4884,28 @@ VpMult(Real *c, Real *a, Real *b)
48824884
c->Prec = a->Prec + b->Prec; /* set precision */
48834885
memset(c->frac, 0, c->Prec * sizeof(DECDIG)); /* Initialize c */
48844886

4885-
// Process 8 decdigits at a time to reduce the number of carry operations.
4886-
a_batch_max = (a->Prec - 1) / 8;
4887-
b_batch_max = (b->Prec - 1) / 8;
4887+
// Process VPMULT_BATCH_SIZE decdigits at a time to reduce the number of carry operations.
4888+
a_batch_max = (a->Prec - 1) / VPMULT_BATCH_SIZE;
4889+
b_batch_max = (b->Prec - 1) / VPMULT_BATCH_SIZE;
48884890
for (ssize_t ibatch = a_batch_max; ibatch >= 0; ibatch--) {
4889-
int isize = ibatch == a_batch_max ? (a->Prec - 1) % 8 + 1 : 8;
4891+
int isize = ibatch == a_batch_max ? (a->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE;
48904892
for (ssize_t jbatch = b_batch_max; jbatch >= 0; jbatch--) {
4891-
int jsize = jbatch == b_batch_max ? (b->Prec - 1) % 8 + 1 : 8;
4893+
int jsize = jbatch == b_batch_max ? (b->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE;
48924894
memset(batch, 0, (isize + jsize - 1) * sizeof(DECDIG_DBL));
48934895

48944896
// Perform multiplication without carry calculation.
4895-
// 999999999 * 999999999 * 8 < 2**63 - 1, so DECDIG_DBL can hold the intermediate sum without overflow.
4897+
// BASE * BASE * VPMULT_BATCH_SIZE < 2**64 should be satisfied so that
4898+
// DECDIG_DBL can hold the intermediate sum without overflow.
48964899
for (int i = 0; i < isize; i++) {
48974900
for (int j = 0; j < jsize; j++) {
4898-
batch[i + j] += (DECDIG_DBL)a->frac[ibatch * 8 + i] * b->frac[jbatch * 8 + j];
4901+
batch[i + j] += (DECDIG_DBL)a->frac[ibatch * VPMULT_BATCH_SIZE + i] * b->frac[jbatch * VPMULT_BATCH_SIZE + j];
48994902
}
49004903
}
49014904

49024905
// Add the batch result to c with carry calculation.
49034906
DECDIG_DBL carry = 0;
49044907
for (int k = isize + jsize - 2; k >= 0; k--) {
4905-
size_t l = (ibatch + jbatch) * 8 + k + 1;
4908+
size_t l = (ibatch + jbatch) * VPMULT_BATCH_SIZE + k + 1;
49064909
DECDIG_DBL s = c->frac[l] + batch[k] + carry;
49074910
c->frac[l] = (DECDIG)(s % BASE);
49084911
carry = (DECDIG_DBL)(s / BASE);
@@ -4911,7 +4914,7 @@ VpMult(Real *c, Real *a, Real *b)
49114914
// Adding carry may exceed BASE, but it won't cause overflow of DECDIG.
49124915
// Exceeded value will be resolved in the carry operation of next (ibatch + jbatch - 1) batch.
49134916
// WARNING: This safety strongly relies on the current nested loop execution order.
4914-
c->frac[(ibatch + jbatch) * 8] += (DECDIG)carry;
4917+
c->frac[(ibatch + jbatch) * VPMULT_BATCH_SIZE] += (DECDIG)carry;
49154918
}
49164919
}
49174920

test/bigdecimal/test_vp_operation.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def setup
1515

1616
def test_vpmult
1717
# Max carry case
18-
[*32...40].repeated_permutation(2) do |n, m|
18+
[*32...48].repeated_permutation(2) do |n, m|
1919
x = BigDecimal('9' * BASE_FIG * n)
2020
y = BigDecimal('9' * BASE_FIG * m)
2121
assert_equal(x.to_i * y.to_i, x.vpmult(y))
@@ -30,7 +30,7 @@ def test_vpmult
3030

3131
def test_nttmult
3232
# Max carry case
33-
[*32...40].repeated_permutation(2) do |n, m|
33+
[*32...48].repeated_permutation(2) do |n, m|
3434
x = BigDecimal('9' * BASE_FIG * n)
3535
y = BigDecimal('9' * BASE_FIG * m)
3636
assert_equal(x.to_i * y.to_i, x.nttmult(y))

0 commit comments

Comments
 (0)