@@ -273,25 +273,35 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
273273 kernel_shape = self .in_features_shape + self .out_features_shape
274274 kernel = jnp .zeros (kernel_shape , dtype = self .dtype )
275275 else :
276- kernel = self .kernel [...]
277- kernel = jnp .asarray (kernel , self .dtype )
276+ kernel_val = self .kernel .value
277+ if kernel_val is not None :
278+ kernel = kernel_val
279+ kernel = jnp .asarray (kernel , self .dtype )
280+ else :
281+ kernel = None
282+
283+ if kernel is not None :
284+ contract_ind = tuple (range (0 , len (norm_axis )))
285+ output_sharding = (
286+ create_sharding (self .mesh , ("activation_batch_no_exp_moe" , "activation_length_no_exp_moe" , None ))
287+ if self .shard_mode == ShardMode .EXPLICIT
288+ else None
289+ )
290+ output = linears ._compute_dot_general_nnx (
291+ inputs ,
292+ kernel ,
293+ norm_axis ,
294+ contract_ind ,
295+ self .matmul_precision ,
296+ self .quant_dot_general ,
297+ _initializing ,
298+ out_sharding = output_sharding ,
299+ )
300+ else :
301+ # If kernel is missing (e.g. masked in pipeline), return zeros.
302+ out_shape = inputs .shape [:- 1 ] + self .out_features_shape
303+ output = jnp .zeros (out_shape , dtype = self .dtype )
278304
279- contract_ind = tuple (range (0 , len (norm_axis )))
280- output_sharding = (
281- create_sharding (self .mesh , ("activation_batch_no_exp_moe" , "activation_length_moe" , None ))
282- if self .shard_mode == ShardMode .EXPLICIT
283- else None
284- )
285- output = linears ._compute_dot_general_nnx (
286- inputs ,
287- kernel ,
288- norm_axis ,
289- contract_ind ,
290- self .matmul_precision ,
291- self .quant_dot_general ,
292- _initializing ,
293- out_sharding = output_sharding ,
294- )
295305 pre_bias_logits = None
296306
297307 if self .score_func :
@@ -300,8 +310,10 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
300310 pre_bias_logits = output
301311
302312 if self .use_bias :
303- bias = jnp .asarray (self .bias [...], self .dtype )
304- output += bias
313+ bias_val = self .bias .value
314+ if bias_val is not None :
315+ bias = jnp .asarray (bias_val , self .dtype )
316+ output += bias
305317 return output , pre_bias_logits
306318
307319
@@ -2024,9 +2036,10 @@ def __call__(
20242036 routing_inputs = inputs if gate_inputs is None else gate_inputs .astype (gate_dtype )
20252037 gate_logits , pre_bias_logits = self .gate (routing_inputs )
20262038
2027- w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2028- w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
2029- wo_kernel = jnp .asarray (self .wo [...], self .dtype )
2039+ # NOTE(review): only wi_0.value is checked — assumes wi_0/wi_1/wo are always
2039+ # masked together. Also confirm the per_expert_scale / bias lines below (old
2039+ # 2031-2037, unchanged context) are moved under this guard too; as written
2039+ # they reference wo_kernel, which is unbound when wi_0.value is None.
2039+ if self .wi_0 .value is not None :
2040+ w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2041+ w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
2042+ wo_kernel = jnp .asarray (self .wo [...], self .dtype )
20302043
20312044 if self .per_expert_scale is not None :
20322045 wo_kernel = wo_kernel * jnp .asarray (self .per_expert_scale [...], self .dtype )[:, None , None ]
@@ -2038,26 +2051,32 @@ def __call__(
20382051 else :
20392052 w0_bias , w1_bias , wo_bias = None , None , None
20402053
2041- if cfg .sparse_matmul :
2042- if quantizations .in_serve_mode (self .quant ):
2043- w0_kernel , w1_kernel , wo_kernel = self .retrieve_quantized_weight (
2044- inputs ,
2045- gate_logits ,
2046- pre_bias_logits ,
2047- w0_kernel ,
2048- w1_kernel ,
2049- wo_kernel ,
2050- w0_bias ,
2051- w1_bias ,
2052- wo_bias ,
2054+ if cfg .sparse_matmul :
2055+ if quantizations .in_serve_mode (self .quant ):
2056+ w0_kernel , w1_kernel , wo_kernel = self .retrieve_quantized_weight (
2057+ inputs ,
2058+ gate_logits ,
2059+ pre_bias_logits ,
2060+ w0_kernel ,
2061+ w1_kernel ,
2062+ wo_kernel ,
2063+ w0_bias ,
2064+ w1_bias ,
2065+ wo_bias ,
2066+ )
2067+ output , lb_loss , bias_updates = self .sparse_matmul (
2068+ inputs , gate_logits , pre_bias_logits , w0_kernel , w1_kernel , wo_kernel , w0_bias , w1_bias , wo_bias
2069+ )
2070+ else :
2071+ output , lb_loss , bias_updates = self .dense_matmul (
2072+ inputs , gate_logits , pre_bias_logits , w0_kernel , w1_kernel , wo_kernel , w0_bias , w1_bias , wo_bias
20532073 )
2054- output , lb_loss , bias_updates = self .sparse_matmul (
2055- inputs , gate_logits , pre_bias_logits , w0_kernel , w1_kernel , wo_kernel , w0_bias , w1_bias , wo_bias
2056- )
20572074 else :
2058- output , lb_loss , bias_updates = self .dense_matmul (
2059- inputs , gate_logits , pre_bias_logits , w0_kernel , w1_kernel , wo_kernel , w0_bias , w1_bias , wo_bias
2060- )
2075+ # If kernels are missing (e.g. masked in pipeline), return zeros.
2076+ output = jnp .zeros_like (inputs )
2077+ lb_loss = None
2078+ bias_updates = None
2079+
20612080 return output , lb_loss , bias_updates
20622081
20632082
0 commit comments