Skip to content

Commit c702a96

Browse files
Add entire pallas kernel for testing
1 parent e25d9bc commit c702a96

2 files changed

Lines changed: 482 additions & 4 deletions

File tree

src/maxtext/models/qwen3.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
48 48
from maxtext.utils import max_utils
49 49
from maxtext.inference import page_manager, kvcache
50 50

51-
from maxtext.scratch_code import gdn_pallas
51+
from maxtext.scratch_code import gdn_pallas, gdn_pallas2
52 52

53 53
# -----------------------------------------
54 54
# Qwen3-Next Layer Implementations
@@ -282,6 +282,7 @@ def to_chunk_scalar(x):
282 282
else:
283 283
h_init = initial_state.astype(compute_dtype)
284 284

285+
kernel_to_use = gdn_pallas.gdn_pallas_layer
285 286
# Invoke Kernel
286 287
if mesh is not None:
287 288
# Mesh Partitioning
@@ -298,17 +299,17 @@ def to_chunk_scalar(x):
298 299
state_spec = P(batch_spec, head_spec, None, None)
299 300

300 301
sharded_gdn = shard_map(
301-
gdn_pallas.gdn_pallas_layer,
302+
kernel_to_use,
302 303
mesh=mesh,
303 304
in_specs=(in_specs, in_specs, in_specs, in_specs, in_specs, scalar_specs, scalar_specs, state_spec),
304-
out_specs=(in_specs, state_spec), # Returns (out, final_state)
305+
out_specs=(in_specs, state_spec),
305 306
check_rep=False
306 307
)
307 308

308 309
o_pallas, h_final = sharded_gdn(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
309 310
else:
310 311
# Single Device
311-
o_pallas, h_final = gdn_pallas.gdn_pallas_layer(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
312+
o_pallas, h_final = kernel_to_use(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
312 313

313 314
o_chunks = o_pallas.transpose(0, 2, 1, 3, 4)
314 315

0 commit comments

Comments
 (0)