Skip to content

Commit 0de5aa1

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Fix random-inl for the deprecated HWY_SCALAR target
AESRound is not supported there. Also lint fix for size_t and random_device. PiperOrigin-RevId: 907659857
1 parent 9bbb7fb commit 0de5aa1

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

hwy/contrib/random/random-inl.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
2222
#endif
2323

24+
#include <stddef.h>
25+
2426
#include <array>
2527
#include <cstdint>
2628
#include <limits>
@@ -201,7 +203,7 @@ class VectorXoshiro {
201203

202204
HWY_INLINE VU64 operator()() noexcept { return Next(); }
203205

204-
AlignedVector<std::uint64_t> operator()(const std::size_t n) {
206+
AlignedVector<std::uint64_t> operator()(const size_t n) {
205207
AlignedVector<std::uint64_t> result(n);
206208
const ScalableTag<std::uint64_t> tag{};
207209
auto s0 = Load(tag, state_[{0}].data());
@@ -254,7 +256,7 @@ class VectorXoshiro {
254256
return Mul(real, MUL_VALUE);
255257
}
256258

257-
AlignedVector<double> Uniform(const std::size_t n) {
259+
AlignedVector<double> Uniform(const size_t n) {
258260
AlignedVector<double> result(n);
259261
const ScalableTag<std::uint64_t> tag{};
260262
const ScalableTag<double> real_tag{};
@@ -371,7 +373,7 @@ class CachedXoshiro {
371373
private:
372374
VectorXoshiro generator_;
373375
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
374-
std::size_t index_;
376+
size_t index_;
375377

376378
static_assert((size & (size - 1)) == 0 && size != 0,
377379
"only power of 2 are supported");
@@ -420,6 +422,7 @@ class alignas(16) AesCtrEngine {
420422
// users generally call once at a time, this requires buffering, which is not
421423
// worth the complexity in this application.
422424
uint64_t operator()(uint64_t stream, uint64_t counter) const {
425+
#if HWY_TARGET != HWY_SCALAR
423426
using D = Full128<uint8_t>; // 128 bits for AES
424427
using V = Vec<D>;
425428
const Repartition<uint64_t, D> d64;
@@ -441,6 +444,12 @@ class alignas(16) AesCtrEngine {
441444

442445
// Return lower 64 bits of the u8 vector.
443446
return GetLane(BitCast(d64, state));
447+
#else
448+
HWY_DASSERT(0); // Not supported.
449+
(void)stream;
450+
(void)counter;
451+
return 0;
452+
#endif // HWY_TARGET != HWY_SCALAR
444453
}
445454

446455
private:

hwy/contrib/random/random_test.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void TestSeeding() {
7373

7474
void TestMultiThreadSeeding() {
7575
const std::uint64_t seed = GetSeed();
76-
const std::uint64_t threadId = std::random_device()() % 1000;
76+
const std::uint64_t threadId = GetSeed() % 1000;
7777
VectorXoshiro generator{seed, threadId};
7878
internal::Xoshiro reference{seed};
7979

@@ -293,6 +293,8 @@ void TestUniformCachedXorshiro() {
293293

294294
// ----- AesCtrEngine / RngStream / RandomNormalizedFloat tests -----
295295

296+
#if HWY_TARGET != HWY_SCALAR
297+
296298
void TestAesCtrDeterministic() {
297299
const AesCtrEngine engine1(/*deterministic=*/true);
298300
const AesCtrEngine engine2(/*deterministic=*/true);
@@ -406,6 +408,22 @@ void TestRandomNormalizedFloat() {
406408
HWY_ASSERT(-0.01 < mean && mean < 0.01);
407409
}
408410

411+
#else
412+
413+
void TestAesCtrDeterministic() {}
414+
415+
void TestAesCtrSeeded() {}
416+
417+
void TestAesCtrStreamsDiffer() {}
418+
419+
void TestAesCtrBitDistribution() {}
420+
421+
void TestAesCtrChiSquared() {}
422+
423+
void TestRandomNormalizedFloat() {}
424+
425+
#endif // HWY_TARGET != HWY_SCALAR
426+
409427
} // namespace
410428
// NOLINTNEXTLINE(google-readability-namespace-comments)
411429
} // namespace HWY_NAMESPACE

0 commit comments

Comments
 (0)