Skip to content

Commit 6b0fe7a

Browse files
authored
Add support for array sizes >=2**31 for random number generation (#414)
* Add support for array sizes >=2**31 for random number generation
* Remove unused imports
1 parent 3d83295 commit 6b0fe7a

2 files changed

Lines changed: 232 additions & 75 deletions

File tree

pycuda/curandom.py

Lines changed: 65 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -276,41 +276,41 @@ def get_curand_version():
276276
# {{{ Base class
277277

278278
gen_template = """
279-
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, const int n)
279+
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, const size_t n)
280280
{
281-
const int tidx = blockIdx.x*blockDim.x+threadIdx.x;
282-
const int delta = blockDim.x*gridDim.x;
283-
for (int idx = tidx; idx < n; idx += delta)
281+
const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x;
282+
const size_t delta = blockDim.x*gridDim.x;
283+
for (size_t idx = tidx; idx < n; idx += delta)
284284
d[idx] = curand%(suffix)s(&s[tidx]);
285285
}
286286
"""
287287

288288
gen_log_template = """
289-
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, %(in_type)s mean, %(in_type)s stddev, const int n)
289+
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, %(in_type)s mean, %(in_type)s stddev, const size_t n)
290290
{
291-
const int tidx = blockIdx.x*blockDim.x+threadIdx.x;
292-
const int delta = blockDim.x*gridDim.x;
293-
for (int idx = tidx; idx < n; idx += delta)
291+
const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x;
292+
const size_t delta = blockDim.x*gridDim.x;
293+
for (size_t idx = tidx; idx < n; idx += delta)
294294
d[idx] = curand_log%(suffix)s(&s[tidx], mean, stddev);
295295
}
296296
"""
297297

298298
gen_poisson_template = """
299-
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, double lambda, const int n)
299+
__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, double lambda, const size_t n)
300300
{
301-
const int tidx = blockIdx.x*blockDim.x+threadIdx.x;
302-
const int delta = blockDim.x*gridDim.x;
303-
for (int idx = tidx; idx < n; idx += delta)
301+
const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x;
302+
const size_t delta = blockDim.x*gridDim.x;
303+
for (size_t idx = tidx; idx < n; idx += delta)
304304
d[idx] = curand_poisson%(suffix)s(&s[tidx], lambda);
305305
}
306306
"""
307307

308308
gen_poisson_inplace_template = """
309-
__global__ void %(name)s(%(state_type)s *s, %(inout_type)s *d, const int n)
309+
__global__ void %(name)s(%(state_type)s *s, %(inout_type)s *d, const size_t n)
310310
{
311-
const int tidx = blockIdx.x*blockDim.x+threadIdx.x;
312-
const int delta = blockDim.x*gridDim.x;
313-
for (int idx = tidx; idx < n; idx += delta)
311+
const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x;
312+
const size_t delta = blockDim.x*gridDim.x;
313+
for (size_t idx = tidx; idx < n; idx += delta)
314314
d[idx] = (%(inout_type)s)(curand_poisson%(suffix)s(&s[tidx], double(d[idx])));
315315
}
316316
"""
@@ -330,16 +330,16 @@ def get_curand_version():
330330

331331
random_skip_ahead32_source = """
332332
extern "C" {
333-
__global__ void skip_ahead(%(state_type)s *s, const int n, const unsigned int skip)
333+
__global__ void skip_ahead(%(state_type)s *s, const size_t n, const unsigned int skip)
334334
{
335-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
335+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
336336
if (idx < n)
337337
skipahead(skip, &s[idx]);
338338
}
339339
340-
__global__ void skip_ahead_array(%(state_type)s *s, const int n, const unsigned int *skip)
340+
__global__ void skip_ahead_array(%(state_type)s *s, const size_t n, const unsigned int *skip)
341341
{
342-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
342+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
343343
if (idx < n)
344344
skipahead(skip[idx], &s[idx]);
345345
}
@@ -348,16 +348,16 @@ def get_curand_version():
348348

349349
random_skip_ahead64_source = """
350350
extern "C" {
351-
__global__ void skip_ahead(%(state_type)s *s, const int n, const unsigned long long skip)
351+
__global__ void skip_ahead(%(state_type)s *s, const size_t n, const unsigned long long skip)
352352
{
353-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
353+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
354354
if (idx < n)
355355
skipahead(skip, &s[idx]);
356356
}
357357
358-
__global__ void skip_ahead_array(%(state_type)s *s, const int n, const unsigned long long *skip)
358+
__global__ void skip_ahead_array(%(state_type)s *s, const size_t n, const unsigned long long *skip)
359359
{
360-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
360+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
361361
if (idx < n)
362362
skipahead(skip[idx], &s[idx]);
363363
}
@@ -517,24 +517,24 @@ def do_generate(out_type):
517517
self.generators = {}
518518
for name, out_type, suffix in my_generators:
519519
gen_func = module.get_function(name)
520-
gen_func.prepare("PPi")
520+
gen_func.prepare("PPn")
521521
self.generators[name] = gen_func
522522
if get_curand_version() >= (4, 0, 0):
523523
for name, in_type, out_type, suffix in my_log_generators:
524524
gen_func = module.get_function(name)
525525
if in_type == "float":
526-
gen_func.prepare("PPffi")
526+
gen_func.prepare("PPffn")
527527
if in_type == "double":
528-
gen_func.prepare("PPddi")
528+
gen_func.prepare("PPddn")
529529
self.generators[name] = gen_func
530530
if get_curand_version() >= (5, 0, 0):
531531
for name, out_type, suffix in my_poisson_generators:
532532
gen_func = module.get_function(name)
533-
gen_func.prepare("PPdi")
533+
gen_func.prepare("PPdn")
534534
self.generators[name] = gen_func
535535
for name, inout_type, suffix in my_poisson_inplace_generators:
536536
gen_func = module.get_function(name)
537-
gen_func.prepare("PPi")
537+
gen_func.prepare("PPn")
538538
self.generators[name] = gen_func
539539

540540
self.generator_bits = generator_bits
@@ -546,11 +546,11 @@ def do_generate(out_type):
546546
def _prepare_skipahead(self):
547547
self.skip_ahead = self.module.get_function("skip_ahead")
548548
if self.generator_bits == 32:
549-
self.skip_ahead.prepare("PiI")
549+
self.skip_ahead.prepare("PnI")
550550
if self.generator_bits == 64:
551-
self.skip_ahead.prepare("PiQ")
551+
self.skip_ahead.prepare("PnQ")
552552
self.skip_ahead_array = self.module.get_function("skip_ahead_array")
553-
self.skip_ahead_array.prepare("PiP")
553+
self.skip_ahead_array.prepare("PnP")
554554

555555
def _kernels(self):
556556
return list(self.generators.values()) + [
@@ -769,7 +769,7 @@ def __init__(
769769
raise TypeError("seed must be GPUArray of integers of right length")
770770

771771
p = self.module.get_function("prepare")
772-
p.prepare("PiPi")
772+
p.prepare("PnPn")
773773

774774
from pycuda.characterize import has_stack
775775

@@ -799,15 +799,15 @@ def __init__(
799799

800800
def _prepare_skipahead(self):
801801
self.skip_ahead = self.module.get_function("skip_ahead")
802-
self.skip_ahead.prepare("PiQ")
802+
self.skip_ahead.prepare("PnQ")
803803
self.skip_ahead_array = self.module.get_function("skip_ahead_array")
804-
self.skip_ahead_array.prepare("PiP")
804+
self.skip_ahead_array.prepare("PnP")
805805
self.skip_ahead_sequence = self.module.get_function("skip_ahead_sequence")
806-
self.skip_ahead_sequence.prepare("PiQ")
806+
self.skip_ahead_sequence.prepare("PnQ")
807807
self.skip_ahead_sequence_array = self.module.get_function(
808808
"skip_ahead_sequence_array"
809809
)
810-
self.skip_ahead_sequence_array.prepare("PiP")
810+
self.skip_ahead_sequence_array.prepare("PnP")
811811

812812
def call_skip_ahead_sequence(self, i, stream=None):
813813
self.skip_ahead_sequence.prepared_async_call(
@@ -855,10 +855,10 @@ def seed_getter_unique(n):
855855

856856
xorwow_random_source = """
857857
extern "C" {
858-
__global__ void prepare(%(state_type)s *s, const int n,
859-
%(vector_type)s *v, const unsigned int o)
858+
__global__ void prepare(%(state_type)s *s, const size_t n,
859+
%(vector_type)s *v, const size_t o)
860860
{
861-
const int id = blockIdx.x*blockDim.x+threadIdx.x;
861+
const size_t id = blockIdx.x*blockDim.x+threadIdx.x;
862862
if (id < n)
863863
curand_init(v[id], id, o, &s[id]);
864864
}
@@ -867,16 +867,16 @@ def seed_getter_unique(n):
867867

868868
xorwow_skip_ahead_sequence_source = """
869869
extern "C" {
870-
__global__ void skip_ahead_sequence(%(state_type)s *s, const int n, const unsigned long long skip)
870+
__global__ void skip_ahead_sequence(%(state_type)s *s, const size_t n, const unsigned long long skip)
871871
{
872-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
872+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
873873
if (idx < n)
874874
skipahead_sequence(skip, &s[idx]);
875875
}
876876
877-
__global__ void skip_ahead_sequence_array(%(state_type)s *s, const int n, const unsigned long long *skip)
877+
__global__ void skip_ahead_sequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip)
878878
{
879-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
879+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
880880
if (idx < n)
881881
skipahead_sequence(skip[idx], &s[idx]);
882882
}
@@ -912,10 +912,10 @@ def __init__(self, seed_getter=None, offset=0):
912912

913913
mrg32k3a_random_source = """
914914
extern "C" {
915-
__global__ void prepare(%(state_type)s *s, const int n,
916-
%(vector_type)s *v, const unsigned int o)
915+
__global__ void prepare(%(state_type)s *s, const size_t n,
916+
%(vector_type)s *v, const size_t o)
917917
{
918-
const int id = blockIdx.x*blockDim.x+threadIdx.x;
918+
const size_t id = blockIdx.x*blockDim.x+threadIdx.x;
919919
if (id < n)
920920
curand_init(v[id], id, o, &s[id]);
921921
}
@@ -924,30 +924,30 @@ def __init__(self, seed_getter=None, offset=0):
924924

925925
mrg32k3a_skip_ahead_sequence_source = """
926926
extern "C" {
927-
__global__ void skip_ahead_sequence(%(state_type)s *s, const int n, const unsigned long long skip)
927+
__global__ void skip_ahead_sequence(%(state_type)s *s, const size_t n, const unsigned long long skip)
928928
{
929-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
929+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
930930
if (idx < n)
931931
skipahead_sequence(skip, &s[idx]);
932932
}
933933
934-
__global__ void skip_ahead_sequence_array(%(state_type)s *s, const int n, const unsigned long long *skip)
934+
__global__ void skip_ahead_sequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip)
935935
{
936-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
936+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
937937
if (idx < n)
938938
skipahead_sequence(skip[idx], &s[idx]);
939939
}
940940
941-
__global__ void skip_ahead_subsequence(%(state_type)s *s, const int n, const unsigned long long skip)
941+
__global__ void skip_ahead_subsequence(%(state_type)s *s, const size_t n, const unsigned long long skip)
942942
{
943-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
943+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
944944
if (idx < n)
945945
skipahead_subsequence(skip, &s[idx]);
946946
}
947947
948-
__global__ void skip_ahead_subsequence_array(%(state_type)s *s, const int n, const unsigned long long *skip)
948+
__global__ void skip_ahead_subsequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip)
949949
{
950-
const int idx = blockIdx.x*blockDim.x+threadIdx.x;
950+
const size_t idx = blockIdx.x*blockDim.x+threadIdx.x;
951951
if (idx < n)
952952
skipahead_subsequence(skip[idx], &s[idx]);
953953
}
@@ -981,11 +981,11 @@ def _prepare_skipahead(self):
981981
self.skip_ahead_subsequence = self.module.get_function(
982982
"skip_ahead_subsequence"
983983
)
984-
self.skip_ahead_subsequence.prepare("PiQ")
984+
self.skip_ahead_subsequence.prepare("PnQ")
985985
self.skip_ahead_subsequence_array = self.module.get_function(
986986
"skip_ahead_subsequence_array"
987987
)
988-
self.skip_ahead_subsequence_array.prepare("PiP")
988+
self.skip_ahead_subsequence_array.prepare("PnP")
989989

990990
def call_skip_ahead_subsequence(self, i, stream=None):
991991
self.skip_ahead_subsequence.prepared_async_call(
@@ -1049,10 +1049,10 @@ def generate_scramble_constants64(count):
10491049

10501050
sobol_random_source = """
10511051
extern "C" {
1052-
__global__ void prepare(%(state_type)s *s, const int n,
1053-
%(vector_type)s *v, const unsigned int o)
1052+
__global__ void prepare(%(state_type)s *s, const size_t n,
1053+
%(vector_type)s *v, const size_t o)
10541054
{
1055-
const int id = blockIdx.x*blockDim.x+threadIdx.x;
1055+
const size_t id = blockIdx.x*blockDim.x+threadIdx.x;
10561056
if (id < n)
10571057
curand_init(v[id], o, &s[id]);
10581058
}
@@ -1099,7 +1099,7 @@ def __init__(
10991099
raise TypeError("seed must be GPUArray of integers of right length")
11001100

11011101
p = self.module.get_function("prepare")
1102-
p.prepare("PiPi")
1102+
p.prepare("PnPn")
11031103

11041104
from pycuda.characterize import has_stack
11051105

@@ -1135,10 +1135,10 @@ def _kernels(self):
11351135

11361136
scrambledsobol_random_source = """
11371137
extern "C" {
1138-
__global__ void prepare( %(state_type)s *s, const int n,
1139-
%(vector_type)s *v, %(scramble_type)s *scramble, const unsigned int o)
1138+
__global__ void prepare( %(state_type)s *s, const size_t n,
1139+
%(vector_type)s *v, %(scramble_type)s *scramble, const size_t o)
11401140
{
1141-
const int id = blockIdx.x*blockDim.x+threadIdx.x;
1141+
const size_t id = blockIdx.x*blockDim.x+threadIdx.x;
11421142
if (id < n)
11431143
curand_init(v[id], scramble[id], o, &s[id]);
11441144
}
@@ -1200,7 +1200,7 @@ def __init__(
12001200
raise TypeError("scramble must be GPUArray of integers of right length")
12011201

12021202
p = self.module.get_function("prepare")
1203-
p.prepare("PiPPi")
1203+
p.prepare("PnPPn")
12041204

12051205
from pycuda.characterize import has_stack
12061206

0 commit comments

Comments
 (0)