Skip to content

Commit 9b970de

Browse files
authored
[XPU] Add TP broadcast after sampling in XPU model runner to ensure consistent results across ranks. (#7096)
1 parent 3749457 commit 9b970de

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

fastdeploy/worker/xpu_model_runner.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1585,13 +1585,40 @@ class at the server level, which is too granular for ModelRunner.
15851585
sampler_output = None
15861586
if not self.speculative_decoding:
15871587
sampler_output = self.sampler(logits, self.sampling_metadata)
1588+
if self.parallel_config.tensor_parallel_size > 1:
1589+
paddle.distributed.broadcast(
1590+
sampler_output.sampled_token_ids,
1591+
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
1592+
group=self.parallel_config.tp_group,
1593+
)
15881594
else:
15891595
sampler_output = self.sampler(
15901596
logits,
15911597
self.sampling_metadata,
15921598
self.model_config.max_model_len,
15931599
self.share_inputs,
15941600
)
1601+
if self.parallel_config.tensor_parallel_size > 1:
1602+
paddle.distributed.broadcast(
1603+
self.share_inputs["accept_tokens"],
1604+
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
1605+
group=self.parallel_config.tp_group,
1606+
)
1607+
paddle.distributed.broadcast(
1608+
self.share_inputs["accept_num"],
1609+
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
1610+
group=self.parallel_config.tp_group,
1611+
)
1612+
paddle.distributed.broadcast(
1613+
self.share_inputs["step_idx"],
1614+
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
1615+
group=self.parallel_config.tp_group,
1616+
)
1617+
paddle.distributed.broadcast(
1618+
self.share_inputs["stop_flags"],
1619+
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
1620+
group=self.parallel_config.tp_group,
1621+
)
15951622

15961623
prompt_logprobs_list = None
15971624
if not self.speculative_decoding:

0 commit comments

Comments
 (0)