Commit 8f01149

metax666, duqimeng, and StareAtYou authored
[Metax] Fix add flags (#227) (#2527)
* Set MXCC_OVERRIDE_OPTIONS in compile script

  Add MXCC_OVERRIDE_OPTIONS for metax GPU compilation.

* Add MXCC_OVERRIDE_OPTIONS for Metax GPU
* Update flash_attn_grad_kernel.cu
* Update compile.sh
* [Metax][feat] add top_p_sampling.patch. (#225)
* [Metax] Fix add flags

---------

Co-authored-by: duqimeng <77875733+duqimeng@users.noreply.github.com>
Co-authored-by: MingkunZhang <39252862+StareAtYou@users.noreply.github.com>
1 parent 8f3743f commit 8f01149

2 files changed

Lines changed: 30 additions & 1 deletion

backends/metax_gpu/common/flags_declare.cc

Lines changed: 29 additions & 0 deletions
@@ -116,6 +116,35 @@ PHI_DEFINE_EXPORTED_bool(use_fast_math,
                          false,
                          "Whether to use fast math GPU functions.");
 
+/**
+ * GPU RNG related FLAG
+ * Name: FLAGS_deterministic_rng
+ * Since Version: 3.4
+ * Value Range: bool, default=false
+ * Example: paddle.set_flags({'FLAGS_deterministic_rng': True})
+ * Note: Fix RNG kernel launch config so same seed gives same results
+ *       across GPU types.
+ */
+PHI_DEFINE_EXPORTED_bool(
+    deterministic_rng,
+    false,
+    "Enable cross-device RNG consistency by fixing GPU kernel launch "
+    "configuration. When true, RNG kernels use a fixed grid/block size "
+    "so that the same seed produces identical results across GPU types.");
+/**
+ * GPU RNG related FLAG
+ * Name: FLAGS_deterministic_rng_grid
+ * Since Version: 3.4
+ * Value Range: int32, default=1024
+ * Example: paddle.set_flags({'FLAGS_deterministic_rng_grid': 4096})
+ * Note: Grid size cap used when FLAGS_deterministic_rng is enabled.
+ *       Cross-device consistency requires the same value on all devices.
+ */
+PHI_DEFINE_EXPORTED_int32(
+    deterministic_rng_grid,
+    1024,
+    "Grid size cap when FLAGS_deterministic_rng is enabled.");
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 /**
  * FlashAttention related FLAG
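
Note on the two RNG flags added in this hunk: the flag comments say that fixing the launch configuration makes the same seed give the same results across GPU types, and that the grid cap must match on all devices. The pure-Python sketch below is one way to see why that could be the case. It is not the backend's implementation: the launch heuristic (`launch_grid`, `num_sms * 4`), the per-thread stream seeding, and the names `fill_uniform` and `sample` are all hypothetical, and Python's `random.Random` stands in for whatever per-thread generator the real kernels use.

```python
import random


def launch_grid(n, block, num_sms, deterministic, grid_cap=1024):
    """Pick a grid size. Hypothetical heuristic, assumed for illustration only."""
    blocks_needed = (n + block - 1) // block
    if deterministic:
        # Fixed cap: independent of the device, so every GPU type agrees.
        return min(blocks_needed, grid_cap)
    # Device-dependent cap: scales with the (assumed) number of SMs.
    return min(blocks_needed, num_sms * 4)


def fill_uniform(n, seed, grid, block):
    """Simulate a grid-stride RNG kernel in which each thread owns one stream."""
    total_threads = grid * block
    out = [0.0] * n
    for tid in range(total_threads):
        # Stand-in for a per-thread random subsequence keyed by (seed, tid).
        stream = random.Random(seed * 1_000_003 + tid)
        # Grid-stride loop: thread tid writes elements tid, tid + T, tid + 2T, ...
        for i in range(tid, n, total_threads):
            out[i] = stream.random()
    return out


def sample(num_sms, deterministic, grid_cap=1024, n=1000, seed=42, block=32):
    """Generate n values as a simulated device with num_sms SMs would."""
    grid = launch_grid(n, block, num_sms, deterministic, grid_cap)
    return fill_uniform(n, seed, grid, block)


big_gpu = sample(num_sms=4, deterministic=True)
small_gpu = sample(num_sms=2, deterministic=True)
same_with_flag = big_gpu == small_gpu

same_without_flag = (
    sample(num_sms=4, deterministic=False) == sample(num_sms=2, deterministic=False)
)
```

In this model the element-to-thread mapping depends on the total thread count, so a device-dependent grid changes which stream (and which draw from it) lands in each output slot. Under the same assumptions, using different `grid_cap` values on two devices changes the mapping again, which matches the note that the cap must be the same everywhere.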
