@@ -254,6 +254,55 @@ TEST(RejectionSamplerTest, LogProbs) {
254254 EXPECT_TRUE (torch::equal (output.top_tokens , top_k_indices));
255255}
256256
257+ TEST (RejectionSamplerTest, ConstructorDoesNotMutateDoSampleShape) {
258+ const auto device = get_test_device ();
259+ auto do_sample = torch::tensor ({false , true }, torch::device (device));
260+
261+ ASSERT_EQ (do_sample.dim (), 1 );
262+ ASSERT_EQ (do_sample.sizes (), torch::IntArrayRef ({2 }));
263+
264+ RejectionSampler sampler (do_sample,
265+ do_sample.all ().item <bool >(),
266+ !do_sample.any ().item <bool >(),
267+ /* logprobs=*/ false ,
268+ /* max_top_logprobs=*/ 0 );
269+
270+ EXPECT_EQ (do_sample.dim (), 1 );
271+ EXPECT_EQ (do_sample.sizes (), torch::IntArrayRef ({2 }));
272+ }
273+
274+ TEST (RejectionSamplerTest,
275+ ReusingDoSampleAfterRejectionSamplerKeepsSamplerOutput1D) {
276+ const auto options = get_test_options (torch::kFloat32 );
277+ const auto device = get_test_device ();
278+ auto do_sample = torch::tensor ({false , true }, torch::device (device));
279+
280+ RejectionSampler rejection_sampler (do_sample,
281+ do_sample.all ().item <bool >(),
282+ !do_sample.any ().item <bool >(),
283+ /* logprobs=*/ false ,
284+ /* max_top_logprobs=*/ 0 );
285+ (void )rejection_sampler;
286+
287+ SamplingParameters params;
288+ params.selected_token_idxes =
289+ torch::tensor ({0 , 1 }, torch::dtype (torch::kInt64 ).device (device));
290+ params.sample_idxes =
291+ torch::tensor ({0 , 1 }, torch::dtype (torch::kInt64 ).device (device));
292+ params.do_sample = do_sample;
293+ params.all_random_sample = false ;
294+ params.all_greedy_sample = false ;
295+
296+ auto logits =
297+ torch::tensor ({{3 .0f , 1 .0f , 0 .5f }, {0 .1f , 0 .2f , 4 .0f }}, options);
298+ auto output = Sampler ().forward (logits, params);
299+
300+ EXPECT_EQ (output.probs .dim (), 2 );
301+ EXPECT_EQ (output.probs .size (0 ), 2 );
302+ EXPECT_EQ (output.next_tokens .dim (), 1 );
303+ EXPECT_EQ (output.next_tokens .size (0 ), 2 );
304+ }
305+
257306TEST (RejectionSamplerTest, Random) {
258307 const auto options = get_test_options (torch::kFloat32 );
259308
0 commit comments