@@ -1585,13 +1585,40 @@ class at the server level, which is too granular for ModelRunner.
15851585 sampler_output = None
15861586 if not self .speculative_decoding :
15871587 sampler_output = self .sampler (logits , self .sampling_metadata )
1588+ if self .parallel_config .tensor_parallel_size > 1 :
1589+ paddle .distributed .broadcast (
1590+ sampler_output .sampled_token_ids ,
1591+ self .parallel_config .data_parallel_rank * self .parallel_config .tensor_parallel_size ,
1592+ group = self .parallel_config .tp_group ,
1593+ )
15881594 else :
15891595 sampler_output = self .sampler (
15901596 logits ,
15911597 self .sampling_metadata ,
15921598 self .model_config .max_model_len ,
15931599 self .share_inputs ,
15941600 )
1601+ if self .parallel_config .tensor_parallel_size > 1 :
1602+ paddle .distributed .broadcast (
1603+ self .share_inputs ["accept_tokens" ],
1604+ self .parallel_config .data_parallel_rank * self .parallel_config .tensor_parallel_size ,
1605+ group = self .parallel_config .tp_group ,
1606+ )
1607+ paddle .distributed .broadcast (
1608+ self .share_inputs ["accept_num" ],
1609+ self .parallel_config .data_parallel_rank * self .parallel_config .tensor_parallel_size ,
1610+ group = self .parallel_config .tp_group ,
1611+ )
1612+ paddle .distributed .broadcast (
1613+ self .share_inputs ["step_idx" ],
1614+ self .parallel_config .data_parallel_rank * self .parallel_config .tensor_parallel_size ,
1615+ group = self .parallel_config .tp_group ,
1616+ )
1617+ paddle .distributed .broadcast (
1618+ self .share_inputs ["stop_flags" ],
1619+ self .parallel_config .data_parallel_rank * self .parallel_config .tensor_parallel_size ,
1620+ group = self .parallel_config .tp_group ,
1621+ )
15951622
15961623 prompt_logprobs_list = None
15971624 if not self .speculative_decoding :
0 commit comments