@@ -35,9 +35,8 @@ class RandomSample : public Operator<RandomSample> {
3535 std::optional<Tensor> temperature, float temperature_val,
3636 std::optional<Tensor> top_k, int top_k_val,
3737 std::optional<Tensor> top_p, float top_p_val,
38- std::optional<Tensor> min_p, float min_p_val,
39- std::uint64_t seed, std::uint64_t offset,
40- bool deterministic)
38+ std::optional<Tensor> min_p, float min_p_val, std::uint64_t seed,
39+ std::uint64_t offset, bool deterministic)
4140 : logits_dtype_{logits.dtype ()},
4241 out_dtype_{out.dtype ()},
4342 ndim_{logits.ndim ()},
@@ -70,34 +69,29 @@ class RandomSample : public Operator<RandomSample> {
7069 // Simplified constructor: no filtering, default temperature.
7170 RandomSample (const Tensor logits, Tensor out, Tensor valid,
7271 std::uint64_t seed, std::uint64_t offset)
73- : RandomSample{logits, out, valid,
74- std::nullopt , 1 .0f ,
75- std::nullopt , 0 ,
76- std::nullopt , 1 .0f ,
77- std::nullopt , 0 .0f ,
78- seed, offset, false } {}
72+ : RandomSample{logits, out, valid, std::nullopt ,
73+ 1 .0f , std::nullopt , 0 , std::nullopt ,
74+ 1 .0f , std::nullopt , 0 .0f , seed,
75+ offset, false } {}
7976
8077 virtual void operator ()(const Tensor logits, Tensor out, Tensor valid,
8178 std::optional<Tensor> temperature,
82- float temperature_val,
83- std::optional<Tensor> top_k, int top_k_val,
84- std::optional<Tensor> top_p, float top_p_val,
85- std::optional<Tensor> min_p, float min_p_val,
86- std::uint64_t seed, std::uint64_t offset,
87- bool deterministic) const = 0;
79+ float temperature_val, std::optional<Tensor> top_k,
80+ int top_k_val, std::optional<Tensor> top_p,
81+ float top_p_val, std::optional<Tensor> min_p,
82+ float min_p_val, std::uint64_t seed,
83+ std::uint64_t offset, bool deterministic) const = 0;
8884
8985 virtual void operator ()(const Tensor logits, Tensor out, Tensor valid,
9086 std::uint64_t seed, std::uint64_t offset) const {
91- return operator ()(logits, out, valid,
92- temperature_, temperature_val_,
93- top_k_, top_k_val_,
94- top_p_, top_p_val_,
95- min_p_, min_p_val_,
96- seed, offset, deterministic_);
87+ return operator ()(logits, out, valid, temperature_, temperature_val_,
88+ top_k_, top_k_val_, top_p_, top_p_val_, min_p_,
89+ min_p_val_, seed, offset, deterministic_);
9790 }
9891
9992 protected:
100- static void ValidateIntParam (std::optional<Tensor> t, Tensor::Size batch_size) {
93+ static void ValidateIntParam (std::optional<Tensor> t,
94+ Tensor::Size batch_size) {
10195 if (!t.has_value ()) return ;
10296 const auto & tensor = *t;
10397 assert (tensor.ndim () == 1 && tensor.size (0 ) == batch_size &&
@@ -107,7 +101,8 @@ class RandomSample : public Operator<RandomSample> {
107101 " per-batch int param must be int32 or int64" );
108102 }
109103
110- static void ValidateFloatParam (std::optional<Tensor> t, Tensor::Size batch_size) {
104+ static void ValidateFloatParam (std::optional<Tensor> t,
105+ Tensor::Size batch_size) {
111106 if (!t.has_value ()) return ;
112107 const auto & tensor = *t;
113108 assert (tensor.ndim () == 1 && tensor.size (0 ) == batch_size &&
@@ -119,8 +114,9 @@ class RandomSample : public Operator<RandomSample> {
119114 " per-batch float param must be float16/bfloat16/float32/float64" );
120115 }
121116
122- void ValidateParams (std::optional<Tensor> temperature, std::optional<Tensor> top_k,
123- std::optional<Tensor> top_p, std::optional<Tensor> min_p) const {
117+ void ValidateParams (std::optional<Tensor> temperature,
118+ std::optional<Tensor> top_k, std::optional<Tensor> top_p,
119+ std::optional<Tensor> min_p) const {
124120 ValidateFloatParam (temperature, batch_size_);
125121 ValidateIntParam (top_k, batch_size_);
126122 ValidateFloatParam (top_p, batch_size_);
0 commit comments