@@ -1057,8 +1057,8 @@ def _sample_and_accept_draft_tokens_base(
10571057 device = logits .device )
10581058
10591059 # Sample tokens using per-request sampling parameters
1060- target_tokens = self ._sample_tokens_for_batch (
1061- logits , spec_metadata , num_contexts , batch_size )
1060+ target_tokens = self ._sample_tokens_for_batch (logits , spec_metadata ,
1061+ num_contexts , batch_size )
10621062
10631063 # Context requests: only accept the sampled token (no draft tokens yet)
10641064 accepted_tokens [:num_contexts , 0 ] = target_tokens [:num_contexts ]
@@ -1072,8 +1072,7 @@ def _sample_and_accept_draft_tokens_base(
10721072 # Compare draft tokens with target tokens using cumulative product
10731073 # Counts consecutive matches from the start
10741074 num_accepted_tokens [num_contexts :] += torch .cumprod (
1075- (draft_tokens
1076- == gen_target_tokens [:, :runtime_draft_len ]).int (),
1075+ (draft_tokens == gen_target_tokens [:, :runtime_draft_len ]).int (),
10771076 dim = - 1 ).sum (1 )
10781077
10791078 # Apply force override if set
@@ -1182,8 +1181,9 @@ def _sample_and_accept_draft_tokens_rejection(
11821181
11831182 target_probs_flat = compute_probs_from_logits (
11841183 gen_logits , temperatures , top_ks , top_ps )
1185- target_probs = target_probs_flat .reshape (
1186- num_gens , runtime_draft_len + 1 , vocab_size )
1184+ target_probs = target_probs_flat .reshape (num_gens ,
1185+ runtime_draft_len + 1 ,
1186+ vocab_size )
11871187
11881188 draft_vocab_size = draft_probs .shape [- 1 ]
11891189 assert draft_probs .shape [0 ] == num_gens , (
@@ -1200,8 +1200,7 @@ def _sample_and_accept_draft_tokens_rejection(
12001200 # configured, e.g. when use_rejection_sampling was off at
12011201 # prepare() time.
12021202 if spec_metadata .full_draft_probs is not None :
1203- full_draft_probs = spec_metadata .full_draft_probs [:
1204- num_gens ]
1203+ full_draft_probs = spec_metadata .full_draft_probs [:num_gens ]
12051204 else :
12061205 full_draft_probs = torch .zeros (
12071206 (num_gens , runtime_draft_len , vocab_size ),
0 commit comments