4848from maxtext .utils import max_utils
4949from maxtext .inference import page_manager , kvcache
5050
51- from maxtext .scratch_code import gdn_pallas
51+ from maxtext .scratch_code import gdn_pallas , gdn_pallas2
5252
5353# -----------------------------------------
5454# Qwen3-Next Layer Implementations
@@ -282,6 +282,7 @@ def to_chunk_scalar(x):
282282 else :
283283 h_init = initial_state .astype (compute_dtype )
284284
285+ kernel_to_use = gdn_pallas .gdn_pallas_layer
285286 # Invoke Kernel
286287 if mesh is not None :
287288 # Mesh Partitioning
@@ -298,17 +299,17 @@ def to_chunk_scalar(x):
298299 state_spec = P (batch_spec , head_spec , None , None )
299300
300301 sharded_gdn = shard_map (
301- gdn_pallas . gdn_pallas_layer ,
302+ kernel_to_use ,
302303 mesh = mesh ,
303304 in_specs = (in_specs , in_specs , in_specs , in_specs , in_specs , scalar_specs , scalar_specs , state_spec ),
304- out_specs = (in_specs , state_spec ), # Returns (out, final_state)
305+ out_specs = (in_specs , state_spec ),
305306 check_rep = False
306307 )
307308
308309 o_pallas , h_final = sharded_gdn (w_p , u_p , q_p , k_p , v_p , g_p , beta_p , h_init )
309310 else :
310311 # Single Device
311- o_pallas , h_final = gdn_pallas . gdn_pallas_layer (w_p , u_p , q_p , k_p , v_p , g_p , beta_p , h_init )
312+ o_pallas , h_final = kernel_to_use (w_p , u_p , q_p , k_p , v_p , g_p , beta_p , h_init )
312313
313314 o_chunks = o_pallas .transpose (0 , 2 , 1 , 3 , 4 )
314315
0 commit comments