Skip to content

Commit 740b7d3

Browse files
committed
Code format
1 parent 7aebec7 commit 740b7d3

3 files changed

Lines changed: 37 additions & 43 deletions

File tree

src/base/random_sample.h

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ class RandomSample : public Operator<RandomSample> {
3535
std::optional<Tensor> temperature, float temperature_val,
3636
std::optional<Tensor> top_k, int top_k_val,
3737
std::optional<Tensor> top_p, float top_p_val,
38-
std::optional<Tensor> min_p, float min_p_val,
39-
std::uint64_t seed, std::uint64_t offset,
40-
bool deterministic)
38+
std::optional<Tensor> min_p, float min_p_val, std::uint64_t seed,
39+
std::uint64_t offset, bool deterministic)
4140
: logits_dtype_{logits.dtype()},
4241
out_dtype_{out.dtype()},
4342
ndim_{logits.ndim()},
@@ -70,34 +69,29 @@ class RandomSample : public Operator<RandomSample> {
7069
// Simplified constructor: no filtering, default temperature.
7170
RandomSample(const Tensor logits, Tensor out, Tensor valid,
7271
std::uint64_t seed, std::uint64_t offset)
73-
: RandomSample{logits, out, valid,
74-
std::nullopt, 1.0f,
75-
std::nullopt, 0,
76-
std::nullopt, 1.0f,
77-
std::nullopt, 0.0f,
78-
seed, offset, false} {}
72+
: RandomSample{logits, out, valid, std::nullopt,
73+
1.0f, std::nullopt, 0, std::nullopt,
74+
1.0f, std::nullopt, 0.0f, seed,
75+
offset, false} {}
7976

8077
virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
8178
std::optional<Tensor> temperature,
82-
float temperature_val,
83-
std::optional<Tensor> top_k, int top_k_val,
84-
std::optional<Tensor> top_p, float top_p_val,
85-
std::optional<Tensor> min_p, float min_p_val,
86-
std::uint64_t seed, std::uint64_t offset,
87-
bool deterministic) const = 0;
79+
float temperature_val, std::optional<Tensor> top_k,
80+
int top_k_val, std::optional<Tensor> top_p,
81+
float top_p_val, std::optional<Tensor> min_p,
82+
float min_p_val, std::uint64_t seed,
83+
std::uint64_t offset, bool deterministic) const = 0;
8884

8985
virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
9086
std::uint64_t seed, std::uint64_t offset) const {
91-
return operator()(logits, out, valid,
92-
temperature_, temperature_val_,
93-
top_k_, top_k_val_,
94-
top_p_, top_p_val_,
95-
min_p_, min_p_val_,
96-
seed, offset, deterministic_);
87+
return operator()(logits, out, valid, temperature_, temperature_val_,
88+
top_k_, top_k_val_, top_p_, top_p_val_, min_p_,
89+
min_p_val_, seed, offset, deterministic_);
9790
}
9891

9992
protected:
100-
static void ValidateIntParam(std::optional<Tensor> t, Tensor::Size batch_size) {
93+
static void ValidateIntParam(std::optional<Tensor> t,
94+
Tensor::Size batch_size) {
10195
if (!t.has_value()) return;
10296
const auto& tensor = *t;
10397
assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
@@ -107,7 +101,8 @@ class RandomSample : public Operator<RandomSample> {
107101
"per-batch int param must be int32 or int64");
108102
}
109103

110-
static void ValidateFloatParam(std::optional<Tensor> t, Tensor::Size batch_size) {
104+
static void ValidateFloatParam(std::optional<Tensor> t,
105+
Tensor::Size batch_size) {
111106
if (!t.has_value()) return;
112107
const auto& tensor = *t;
113108
assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
@@ -119,8 +114,9 @@ class RandomSample : public Operator<RandomSample> {
119114
"per-batch float param must be float16/bfloat16/float32/float64");
120115
}
121116

122-
void ValidateParams(std::optional<Tensor> temperature, std::optional<Tensor> top_k,
123-
std::optional<Tensor> top_p, std::optional<Tensor> min_p) const {
117+
void ValidateParams(std::optional<Tensor> temperature,
118+
std::optional<Tensor> top_k, std::optional<Tensor> top_p,
119+
std::optional<Tensor> min_p) const {
124120
ValidateFloatParam(temperature, batch_size_);
125121
ValidateIntParam(top_k, batch_size_);
126122
ValidateFloatParam(top_p, batch_size_);

src/cpu/random_sample/random_sample.h

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
namespace infini::ops {
2121

2222
template <>
23-
class Operator<RandomSample, Device::Type::kCpu>
24-
: public RandomSample,
25-
Caster<Device::Type::kCpu> {
23+
class Operator<RandomSample, Device::Type::kCpu> : public RandomSample,
24+
Caster<Device::Type::kCpu> {
2625
public:
2726
using RandomSample::RandomSample;
2827

@@ -33,8 +32,7 @@ class Operator<RandomSample, Device::Type::kCpu>
3332
std::optional<Tensor> min_p, float min_p_val,
3433
std::uint64_t seed, std::uint64_t offset,
3534
bool deterministic) const override {
36-
DispatchFunc<Device::Type::kCpu,
37-
ConcatType<FloatTypes, ReducedFloatTypes>,
35+
DispatchFunc<Device::Type::kCpu, ConcatType<FloatTypes, ReducedFloatTypes>,
3836
List<DataType::kInt32, DataType::kInt64>>(
3937
{logits_dtype_, out_dtype_},
4038
[&](auto tag1, auto tag2) {
@@ -49,12 +47,9 @@ class Operator<RandomSample, Device::Type::kCpu>
4947

5048
void operator()(const Tensor logits, Tensor out, Tensor valid,
5149
std::uint64_t seed, std::uint64_t offset) const override {
52-
return operator()(logits, out, valid,
53-
std::nullopt, temperature_val_,
54-
std::nullopt, top_k_val_,
55-
std::nullopt, top_p_val_,
56-
std::nullopt, min_p_val_,
57-
seed, offset, deterministic_);
50+
return operator()(logits, out, valid, std::nullopt, temperature_val_,
51+
std::nullopt, top_k_val_, std::nullopt, top_p_val_,
52+
std::nullopt, min_p_val_, seed, offset, deterministic_);
5853
}
5954

6055
private:
@@ -112,9 +107,8 @@ class Operator<RandomSample, Device::Type::kCpu>
112107
std::optional<Tensor> temperature, float temperature_val,
113108
std::optional<Tensor> top_k, int top_k_val,
114109
std::optional<Tensor> top_p, float top_p_val,
115-
std::optional<Tensor> min_p, float min_p_val,
116-
std::uint64_t seed, std::uint64_t offset,
117-
bool deterministic) const {
110+
std::optional<Tensor> min_p, float min_p_val, std::uint64_t seed,
111+
std::uint64_t offset, bool deterministic) const {
118112
assert(valid.dtype() == DataType::kUInt8 &&
119113
"`RandomSample` requires uint8 valid tensor");
120114

@@ -143,8 +137,8 @@ class Operator<RandomSample, Device::Type::kCpu>
143137

144138
float sum = 0.f;
145139
for (Tensor::Size j = 0; j < vocab_size; ++j) {
146-
float v = std::exp(
147-
Cast<float>(logits_row[j * col_stride]) * inv_temp - max_val);
140+
float v = std::exp(Cast<float>(logits_row[j * col_stride]) * inv_temp -
141+
max_val);
148142
probs[j] = v;
149143
sum += v;
150144
}

tests/test_random_sample.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def test_greedy_topk1(batch_size, vocab_size, dtype, device):
6464
_random_sample(logits, out, valid, top_k_val=1, seed=42)
6565

6666
expected = _torch_argmax_sample(logits)
67-
assert torch.equal(out, expected), f"top_k=1 should give argmax, got {out}, expected {expected}"
67+
assert torch.equal(out, expected), (
68+
f"top_k=1 should give argmax, got {out}, expected {expected}"
69+
)
6870
assert valid.all(), "all samples should be valid"
6971

7072

@@ -166,7 +168,9 @@ def test_seed_offset_reproducibility(dtype, device):
166168

167169
# Different offset → must be different (different RNG state)
168170
_random_sample(logits, out3, valid3, seed=1, offset=999999)
169-
assert not torch.equal(out1, out3), "different offset should produce different results"
171+
assert not torch.equal(out1, out3), (
172+
"different offset should produce different results"
173+
)
170174

171175

172176
@pytest.mark.parametrize("dtype", (torch.float32,))

0 commit comments

Comments
 (0)