Skip to content

Commit d3d98d2

Browse files
authored
Random operator rework (#6100)
Random number generators now take a `_rng_state` argument, which, when provided, overrides the seed and makes the upcoming iteration totally deterministic and repeatable. To achieve that, the generators use a counter-based Philox RNG and the random state consists of the key and counter. This PR reworks the RNGBase, which no longer needs to maintain a state array (as states can be simply regenerated each time). ----- Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
1 parent 2767a97 commit d3d98d2

38 files changed

Lines changed: 610 additions & 824 deletions

dali/operators/generic/roi_random_crop.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
1515
#include <random>
1616
#include "dali/pipeline/operator/operator.h"
1717
#include "dali/operators/random/rng_base_cpu.h"
18-
#include "dali/pipeline/util/batch_rng.h"
1918
#include "dali/pipeline/operator/arg_helper.h"
2019

2120
namespace dali {
@@ -70,7 +69,7 @@ bounds of the input.
7069
.NumInput(0, 1)
7170
.NumOutput(1);
7271

73-
class ROIRandomCropCPU : public rng::OperatorWithRng<CPUBackend> {
72+
class ROIRandomCropCPU : public rng::OperatorWithRng<Operator<CPUBackend>> {
7473
public:
7574
explicit ROIRandomCropCPU(const OpSpec &spec);
7675
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
@@ -89,7 +88,7 @@ class ROIRandomCropCPU : public rng::OperatorWithRng<CPUBackend> {
8988
};
9089

9190
ROIRandomCropCPU::ROIRandomCropCPU(const OpSpec &spec)
92-
: rng::OperatorWithRng<CPUBackend>(spec),
91+
: OperatorWithRng<Operator<CPUBackend>>(spec),
9392
roi_start_("roi_start", spec),
9493
roi_end_("roi_end", spec),
9594
roi_shape_("roi_shape", spec),
@@ -181,6 +180,7 @@ void ROIRandomCropCPU::RunImpl(Workspace &ws) {
181180
int ndim = crop_start[0].shape[0];
182181

183182
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
183+
auto rng = GetSampleRNG(sample_idx);
184184
int64_t* sample_sh = nullptr;
185185
if (!in_shape_.empty())
186186
sample_sh = in_shape_.tensor_shape_span(sample_idx).data();
@@ -208,7 +208,7 @@ void ROIRandomCropCPU::RunImpl(Workspace &ws) {
208208
}
209209

210210
auto dist = std::uniform_int_distribution<int64_t>(start_range[0], start_range[1]);
211-
crop_start[sample_idx].data[d] = dist(rng_[sample_idx]);
211+
crop_start[sample_idx].data[d] = dist(rng);
212212
}
213213
}
214214
}

dali/operators/image/crop/bbox_crop.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "dali/core/geom/box.h"
2626
#include "dali/core/static_switch.h"
2727
#include "dali/pipeline/data/views.h"
28-
#include "dali/pipeline/util/batch_rng.h"
2928
#include "dali/pipeline/util/bounding_box_utils.h"
3029

3130
namespace dali {
@@ -387,16 +386,18 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> {
387386
/**
388387
* @param spec Pointer to a persistent OpSpec object,
389388
* which is guaranteed to be alive for the entire lifetime of this object
389+
* @param rng_op Pointer to the (base class) of the enclosing operator, needed for
390+
* accessing the random number generator infrastructure.
390391
*/
391-
RandomBBoxCropImpl(const OpSpec *spec, BatchRNG<std::mt19937_64> &rng)
392+
RandomBBoxCropImpl(const OpSpec *spec, rng::OperatorWithRng<Operator<CPUBackend>> *rng_op)
392393
: spec_(*spec),
394+
rng_op_(rng_op),
393395
num_attempts_{spec_.GetArgument<int>("num_attempts")},
394396
has_labels_(spec_.NumRegularInput() > 1),
395397
has_crop_shape_(spec_.ArgumentDefined("crop_shape")),
396398
has_input_shape_(spec_.ArgumentDefined("input_shape")),
397399
all_boxes_above_threshold_(spec_.GetArgument<bool>("all_boxes_above_threshold")),
398-
output_bbox_indices_(spec_.GetArgument<bool>("output_bbox_indices")),
399-
rngs_(rng) {
400+
output_bbox_indices_(spec_.GetArgument<bool>("output_bbox_indices")) {
400401
has_bbox_layout_ = spec_.TryGetArgument(bbox_layout_, "bbox_layout");
401402
has_shape_layout_ = spec_.TryGetArgument(shape_layout_, "shape_layout");
402403

@@ -717,8 +718,8 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> {
717718
float best_metric = -1.0;
718719

719720
crop.clear();
721+
auto rng = rng_op_->GetSampleRNG(sample);
720722
while (!crop.success && (total_num_attempts_ < 0 || count < total_num_attempts_)) {
721-
auto &rng = rngs_[sample];
722723
std::uniform_int_distribution<> idx_dist(0, sample_options_.size() - 1);
723724
SampleOption option = sample_options_[idx_dist(rng)];
724725
bool absolute_crop_dims = has_crop_shape_;
@@ -911,6 +912,7 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> {
911912

912913
private:
913914
const OpSpec &spec_;
915+
rng::OperatorWithRng<Operator<CPUBackend>> *rng_op_;
914916
int num_attempts_;
915917
int total_num_attempts_;
916918
bool has_labels_;
@@ -928,8 +930,6 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> {
928930
bool output_bbox_indices_ = false;
929931
float bbox_prune_threshold_ = 0.0f;
930932

931-
BatchRNG<std::mt19937_64> &rngs_;
932-
933933
std::vector<SampleOption> sample_options_;
934934

935935
std::vector<i64vec<ndim>> crop_shape_;
@@ -951,7 +951,7 @@ RandomBBoxCrop<CPUBackend>::~RandomBBoxCrop() = default;
951951

952952
template <>
953953
RandomBBoxCrop<CPUBackend>::RandomBBoxCrop(const OpSpec &spec)
954-
: OperatorWithRng<CPUBackend>(spec) {}
954+
: OperatorWithRng<Operator<CPUBackend>>(spec) {}
955955

956956
template <>
957957
bool RandomBBoxCrop<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
@@ -977,7 +977,7 @@ bool RandomBBoxCrop<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
977977

978978
if (impl_ == nullptr || impl_ndim_ != num_dims) {
979979
VALUE_SWITCH(num_dims, ndim, (2, 3),
980-
(impl_ = std::make_unique<RandomBBoxCropImpl<ndim>>(&spec_, rng_);),
980+
(impl_ = std::make_unique<RandomBBoxCropImpl<ndim>>(&spec_, this);),
981981
(DALI_FAIL(make_string("Not supported number of dimensions", num_dims));));
982982
impl_ndim_ = num_dims;
983983
}

dali/operators/image/crop/bbox_crop.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
2525
namespace dali {
2626

2727
template <typename Backend>
28-
class RandomBBoxCrop : public rng::OperatorWithRng<Backend> {
28+
class RandomBBoxCrop : public rng::OperatorWithRng<Operator<Backend>> {
2929
public:
3030
explicit inline RandomBBoxCrop(const OpSpec &spec);
3131
~RandomBBoxCrop() override;

dali/operators/image/remap/displacement_filter.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -40,8 +40,7 @@ class DisplacementIdentity {
4040
explicit DisplacementIdentity(const OpSpec& spec) {}
4141

4242
DALI_HOST_DEV
43-
ivec2 operator()(const int h, const int w, const int c,
44-
const int H, const int W, const int C) {
43+
ivec2 operator()(int sample_idx, int h, int w, int c, int H, int W, int C) {
4544
// identity
4645
return { w, h };
4746
}

dali/operators/image/remap/displacement_filter_impl_cpu.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ namespace dali {
3333
template <DALIInterpType interp_type, bool per_channel,
3434
typename Out, typename In, typename Displacement, typename Border>
3535
void Warp(
36+
int sample_idx,
3637
const kernels::OutTensorCPU<Out, 3> &out,
3738
const kernels::InTensorCPU<In, 3> &in,
3839
Displacement &displacement,
@@ -51,11 +52,11 @@ void Warp(
5152
for (int x = 0; x < outW; x++) {
5253
if (per_channel) {
5354
for (int c = 0; c < C; c++) {
54-
auto p = displacement(y, x, c, inH, inW, C);
55+
auto p = displacement(sample_idx, y, x, c, inH, inW, C);
5556
sampler(&out_row[C*x], p, c, border);
5657
}
5758
} else {
58-
auto p = displacement(y, x, 0, inH, inW, C);
59+
auto p = displacement(sample_idx, y, x, 0, inH, inW, C);
5960
sampler(&out_row[C*x], p, border);
6061
}
6162
}
@@ -91,7 +92,11 @@ class DisplacementFilter<CPUBackend, Displacement, per_channel_transform>
9192
}
9293

9394
template <typename Out, typename In, DALIInterpType interp>
94-
void RunWarp(SampleView<CPUBackend> output, ConstSampleView<CPUBackend> input, int thread_idx) {
95+
void RunWarp(
96+
int sample_idx,
97+
SampleView<CPUBackend> output,
98+
ConstSampleView<CPUBackend> input,
99+
int thread_idx) {
95100
auto &displace = displace_[thread_idx];
96101
In fill[1024];
97102
auto in = view<const Out, 3>(input);
@@ -101,7 +106,7 @@ class DisplacementFilter<CPUBackend, Displacement, per_channel_transform>
101106
fill[i] = fill_value_;
102107
}
103108

104-
Warp<interp, per_channel_transform>(out, in, displace, fill);
109+
Warp<interp, per_channel_transform>(sample_idx, out, in, displace, fill);
105110
}
106111

107112
void RunSample(Workspace &ws, int sample_idx, int thread_idx) {
@@ -117,18 +122,22 @@ class DisplacementFilter<CPUBackend, Displacement, per_channel_transform>
117122
switch (interp_type_) {
118123
case DALI_INTERP_NN:
119124
if (IsType<float>(input.type())) {
120-
RunWarp<float, float, DALI_INTERP_NN>(out_tensor, in_tensor, thread_idx);
125+
RunWarp<float, float, DALI_INTERP_NN>(
126+
sample_idx, out_tensor, in_tensor, thread_idx);
121127
} else if (IsType<uint8_t>(input.type())) {
122-
RunWarp<uint8_t, uint8_t, DALI_INTERP_NN>(out_tensor, in_tensor, thread_idx);
128+
RunWarp<uint8_t, uint8_t, DALI_INTERP_NN>(
129+
sample_idx, out_tensor, in_tensor, thread_idx);
123130
} else {
124131
DALI_FAIL(make_string("Unexpected input type ", input.type()));
125132
}
126133
break;
127134
case DALI_INTERP_LINEAR:
128135
if (IsType<float>(input.type())) {
129-
RunWarp<float, float, DALI_INTERP_LINEAR>(out_tensor, in_tensor, thread_idx);
136+
RunWarp<float, float, DALI_INTERP_LINEAR>(
137+
sample_idx, out_tensor, in_tensor, thread_idx);
130138
} else if (IsType<uint8_t>(input.type())) {
131-
RunWarp<uint8_t, uint8_t, DALI_INTERP_LINEAR>(out_tensor, in_tensor, thread_idx);
139+
RunWarp<uint8_t, uint8_t, DALI_INTERP_LINEAR>(
140+
sample_idx, out_tensor, in_tensor, thread_idx);
132141
} else {
133142
DALI_FAIL(make_string("Unexpected input type ", input.type()));
134143
}

dali/operators/image/remap/displacement_filter_impl_gpu.cuh

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -36,24 +36,26 @@ struct DisplacementSampleDesc {
3636
};
3737

3838
template <typename T, class Displacement, DALIInterpType interp_type>
39-
__device__ inline T GetPixelValueSingleC(int h, int w, int c,
39+
__device__ inline T GetPixelValueSingleC(int sample_idx,
40+
int h, int w, int c,
4041
int H, int W, int C,
4142
const T * input,
4243
Displacement& displace, const T fill_value) {
4344
kernels::Surface2D<const T> in_surface = { input, W, H, C, C, C*W, 1 };
4445
auto sampler = kernels::make_sampler<interp_type>(in_surface);
45-
auto p = displace(h, w, c, H, W, C);
46+
auto p = displace(sample_idx, h, w, c, H, W, C);
4647
return sampler.template at<T>(p, c, fill_value);
4748
}
4849

4950
template <typename T, class Displacement, DALIInterpType interp_type>
50-
__device__ inline void GetPixelValueMultiC(int h, int w,
51+
__device__ inline void GetPixelValueMultiC(int sample_idx,
52+
int h, int w,
5153
int H, int W, int C,
5254
const T * input, T * output,
5355
Displacement& displace, const T fill_value) {
5456
kernels::Surface2D<const T> in_surface = { input, W, H, C, C, C*W, 1 };
5557
auto sampler = kernels::make_sampler<interp_type>(in_surface);
56-
auto p = displace(h, w, 0, H, W, C);
58+
auto p = displace(sample_idx, h, w, 0, H, W, C);
5759
sampler(output, p, fill_value);
5860
}
5961

@@ -75,7 +77,8 @@ __global__ void DisplacementKernel(const DisplacementSampleDesc *samples,
7577
const kernels::BlockDesc<1> *blocks, const T fill_value,
7678
Displacement displace) {
7779
const auto &block = blocks[blockIdx.x];
78-
const auto &sample = samples[block.sample_idx];
80+
const int sample_idx = block.sample_idx;
81+
const auto &sample = samples[sample_idx];
7982

8083
auto *image_out = static_cast<T *>(sample.output);
8184
const auto *image_in = static_cast<const T *>(sample.input);
@@ -100,7 +103,7 @@ __global__ void DisplacementKernel(const DisplacementSampleDesc *samples,
100103
const int h = idx;
101104

102105
image_out[out_idx] = GetPixelValueSingleC<T, Displacement, interp_type>(
103-
h, w, c, H, W, C, image_in, displace, fill_value);
106+
sample_idx, h, w, c, H, W, C, image_in, displace, fill_value);
104107
} else {
105108
image_out[out_idx] = image_in[out_idx];
106109
}
@@ -111,15 +114,16 @@ template <typename T, int C, bool per_channel_transform,
111114
int nThreads, class Displacement, DALIInterpType interp_type>
112115
__global__
113116
void DisplacementKernel_aligned32bit(
114-
const DisplacementSampleDesc *samples,
115-
const kernels::BlockDesc<1> *blocks,
116-
const T fill_value,
117-
Displacement displace) {
117+
const DisplacementSampleDesc *samples,
118+
const kernels::BlockDesc<1> *blocks,
119+
const T fill_value,
120+
Displacement displace) {
118121
constexpr int nPixelsPerThread = sizeof(uint32_t)/sizeof(T);
119122
__shared__ T scratch[nThreads * C * nPixelsPerThread];
120123

121124
const auto &block = blocks[blockIdx.x];
122-
const auto &sample = samples[block.sample_idx];
125+
const int sample_idx = block.sample_idx;
126+
const auto &sample = samples[sample_idx];
123127

124128
auto *image_out = reinterpret_cast<uint32_t *>(sample.output);
125129
const auto *image_in = reinterpret_cast<const T *>(sample.input);
@@ -153,7 +157,7 @@ void DisplacementKernel_aligned32bit(
153157
#pragma unroll
154158
for (int c = 0; c < C; ++c) {
155159
my_scratch[j * C + c] = GetPixelValueSingleC<T, Displacement, interp_type>(
156-
h, w, c, H, W, C, image_in, displace, fill_value);
160+
sample_idx, h, w, c, H, W, C, image_in, displace, fill_value);
157161
}
158162
}
159163
} else {
@@ -163,7 +167,7 @@ void DisplacementKernel_aligned32bit(
163167
const int w = hw % W;
164168
const int h = hw / W;
165169
GetPixelValueMultiC<T, Displacement, interp_type>(
166-
h, w, H, W, C, image_in, my_scratch + j * C, displace, fill_value);
170+
sample_idx, h, w, H, W, C, image_in, my_scratch + j * C, displace, fill_value);
167171
}
168172
}
169173
__syncthreads();
@@ -188,8 +192,8 @@ void DisplacementKernel_aligned32bit(
188192
#pragma unroll
189193
for (int c = 0; c < C; ++c) {
190194
out[h * W * C + w * C + c] =
191-
GetPixelValueSingleC<T, Displacement, interp_type>(h, w, c, H, W, C, image_in,
192-
displace, fill_value);
195+
GetPixelValueSingleC<T, Displacement, interp_type>(
196+
sample_idx, h, w, c, H, W, C, image_in, displace, fill_value);
193197
}
194198
}
195199
} else {
@@ -199,7 +203,7 @@ void DisplacementKernel_aligned32bit(
199203
const int w = hw % W;
200204
const int h = hw / W;
201205
GetPixelValueMultiC<T, Displacement, interp_type>(
202-
h, w, H, W, C, image_in, out + h * W * C + w * C, displace, fill_value);
206+
sample_idx, h, w, H, W, C, image_in, out + h * W * C + w * C, displace, fill_value);
203207
}
204208
}
205209
}

0 commit comments

Comments
 (0)