
Commit 8bd18e5

mesakhcienet authored and ecnal-cienet committed
feat: implement nnx-based pipeline
1 parent 2961f8d commit 8bd18e5

8 files changed

Lines changed: 1935 additions & 892 deletions


src/maxtext/layers/embeddings.py

Lines changed: 5 additions & 4 deletions
@@ -151,10 +151,11 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
     if not jnp.issubdtype(inputs.dtype, jnp.integer):
       raise ValueError("Input type must be an integer or unsigned integer.")

-    embedding = jnp.asarray(
-        _maybe_move_embedding_to_device(self.embedding.value, self.config),
-        self.dtype,
-    )
+    embedding_val = _maybe_move_embedding_to_device(self.embedding.value, self.config)
+    if isinstance(embedding_val, jax.ShapeDtypeStruct):
+      embedding = embedding_val
+    else:
+      embedding = jnp.asarray(embedding_val, self.dtype)

     output_axis_names = (
         (
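
Note on the new branch: under abstract evaluation of a pipeline stage (e.g. jax.eval_shape over an nnx module), a parameter's value can be a jax.ShapeDtypeStruct, which carries only shape and dtype and cannot be cast with jnp.asarray; the guard passes it through untouched. A minimal sketch of the same pattern, with a hypothetical materialize helper:

import jax
import jax.numpy as jnp

def materialize(value, dtype):
  # Abstract placeholders carry shape/dtype only: pass them through so
  # shape inference still works, and cast only concrete arrays.
  if isinstance(value, jax.ShapeDtypeStruct):
    return value
  return jnp.asarray(value, dtype)

materialize(jnp.ones((2, 3)), jnp.bfloat16)                           # concrete: cast
materialize(jax.ShapeDtypeStruct((2, 3), jnp.float32), jnp.bfloat16)  # abstract: unchanged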

src/maxtext/layers/linears.py

Lines changed: 36 additions & 20 deletions
@@ -220,32 +220,48 @@ def __call__(self, inputs: Array, _initializing: bool = False, out_sharding: Nam
       kernel_shape = self.in_features_shape + self.out_features_shape
       kernel = jnp.zeros(kernel_shape, dtype=self.dtype)
     else:
-      kernel = self.kernel[...]
-      # Move logit_dense kernel to device if parameter offloading is enabled
-      if self.parameter_memory_host_offload:
-        max_logging.log("linear.py: Moving parameter logits_dense kernel to device")
-        kernel = jax.device_put(kernel, max_utils.device_space())
-      kernel = jnp.asarray(kernel, self.dtype)
+      kernel_val = self.kernel.value
+      if kernel_val is not None:
+        if isinstance(kernel_val, jax.ShapeDtypeStruct):
+          # Bypass concrete indexing for abstract tracers
+          kernel = kernel_val
+        else:
+          kernel = self.kernel[...]
+          # Move logit_dense kernel to device if parameter offloading is enabled
+          if self.parameter_memory_host_offload:
+            max_logging.log("linear.py: Moving parameter logits_dense kernel to device")
+            kernel = jax.device_put(kernel, max_utils.device_space())
+          kernel = jnp.asarray(kernel, self.dtype)
+      else:
+        kernel = None

     # out_sharding should be None for auto mesh axis
     if self.shard_mode != ShardMode.EXPLICIT:
       out_sharding = None

-    contract_ind = tuple(range(0, len(self.axis)))
-    output = _compute_dot_general_nnx(
-        inputs,
-        kernel,
-        norm_axis,
-        contract_ind,
-        self.matmul_precision,
-        self.quant_dot_general,
-        _initializing,
-        out_sharding,
-    )
+    if kernel is not None:
+      contract_ind = tuple(range(0, len(self.axis)))
+      output = _compute_dot_general_nnx(
+          inputs,
+          kernel,
+          norm_axis,
+          contract_ind,
+          self.matmul_precision,
+          self.quant_dot_general,
+          _initializing,
+          out_sharding,
+      )
+
+      if self.bias is not None:
+        bias_val = self.bias.value
+        if bias_val is not None:
+          bias = jnp.asarray(self.bias[...], self.dtype)
+          output += bias
+    else:
+      # If kernel is missing (e.g. masked in pipeline), return zeros.
+      out_shape = inputs.shape[: -len(self.axis)] + self.out_features_shape
+      output = jnp.zeros(out_shape, dtype=self.dtype)

-    if self.bias is not None:
-      bias = jnp.asarray(self.bias[...], self.dtype)
-      output += bias
     return output
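
Note on the fallback branch: when the kernel is masked out of the current pipeline stage, the layer emits zeros with the shape the contraction would have produced, so downstream shapes stay consistent. A self-contained sketch of that shape logic (dense_apply and its defaults are illustrative, not MaxText API):

import jax.numpy as jnp
from jax import lax

def dense_apply(inputs, kernel, out_features_shape, axis=(-1,), dtype=jnp.float32):
  if kernel is None:
    # Kernel masked out of this stage: zeros of the would-be output shape.
    out_shape = inputs.shape[: -len(axis)] + out_features_shape
    return jnp.zeros(out_shape, dtype=dtype)
  # Contract the trailing `axis` dims of inputs with the leading dims of kernel.
  lhs_contract = tuple(range(inputs.ndim - len(axis), inputs.ndim))
  rhs_contract = tuple(range(len(axis)))
  out = lax.dot_general(inputs, kernel, ((lhs_contract, rhs_contract), ((), ())))
  return out.astype(dtype)

x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
dense_apply(x, w, out_features_shape=(16,)).shape     # (4, 16)
dense_apply(x, None, out_features_shape=(16,)).shape  # (4, 16), all zeros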

src/maxtext/layers/moe.py

Lines changed: 60 additions & 41 deletions
@@ -273,25 +273,35 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
       kernel_shape = self.in_features_shape + self.out_features_shape
       kernel = jnp.zeros(kernel_shape, dtype=self.dtype)
     else:
-      kernel = self.kernel[...]
-      kernel = jnp.asarray(kernel, self.dtype)
+      kernel_val = self.kernel.value
+      if kernel_val is not None:
+        kernel = self.kernel[...]
+        kernel = jnp.asarray(kernel, self.dtype)
+      else:
+        kernel = None
+
+    if kernel is not None:
+      contract_ind = tuple(range(0, len(norm_axis)))
+      output_sharding = (
+          create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
+          if self.shard_mode == ShardMode.EXPLICIT
+          else None
+      )
+      output = linears._compute_dot_general_nnx(
+          inputs,
+          kernel,
+          norm_axis,
+          contract_ind,
+          self.matmul_precision,
+          self.quant_dot_general,
+          _initializing,
+          out_sharding=output_sharding,
+      )
+    else:
+      # If kernel is missing (e.g. masked in pipeline), return zeros.
+      out_shape = inputs.shape[:-1] + self.out_features_shape
+      output = jnp.zeros(out_shape, dtype=self.dtype)

-    contract_ind = tuple(range(0, len(norm_axis)))
-    output_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None))
-        if self.shard_mode == ShardMode.EXPLICIT
-        else None
-    )
-    output = linears._compute_dot_general_nnx(
-        inputs,
-        kernel,
-        norm_axis,
-        contract_ind,
-        self.matmul_precision,
-        self.quant_dot_general,
-        _initializing,
-        out_sharding=output_sharding,
-    )
     pre_bias_logits = None
@@ -300,8 +310,10 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
       pre_bias_logits = output

     if self.use_bias:
-      bias = jnp.asarray(self.bias[...], self.dtype)
-      output += bias
+      bias_val = self.bias.value
+      if bias_val is not None:
+        bias = jnp.asarray(self.bias[...], self.dtype)
+        output += bias
     return output, pre_bias_logits
@@ -2024,9 +2036,10 @@ def __call__(
     routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
     gate_logits, pre_bias_logits = self.gate(routing_inputs)

-    w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
-    w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
-    wo_kernel = jnp.asarray(self.wo[...], self.dtype)
+    if self.wi_0.value is not None:
+      w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
+      w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
+      wo_kernel = jnp.asarray(self.wo[...], self.dtype)

     if self.per_expert_scale is not None:
       wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]

@@ -2038,26 +2051,32 @@ def __call__(
     else:
       w0_bias, w1_bias, wo_bias = None, None, None

-    if cfg.sparse_matmul:
-      if quantizations.in_serve_mode(self.quant):
-        w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
-            inputs,
-            gate_logits,
-            pre_bias_logits,
-            w0_kernel,
-            w1_kernel,
-            wo_kernel,
-            w0_bias,
-            w1_bias,
-            wo_bias,
+      if cfg.sparse_matmul:
+        if quantizations.in_serve_mode(self.quant):
+          w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
+              inputs,
+              gate_logits,
+              pre_bias_logits,
+              w0_kernel,
+              w1_kernel,
+              wo_kernel,
+              w0_bias,
+              w1_bias,
+              wo_bias,
+          )
+        output, lb_loss, bias_updates = self.sparse_matmul(
+            inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
+        )
+      else:
+        output, lb_loss, bias_updates = self.dense_matmul(
+            inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
         )
-      output, lb_loss, bias_updates = self.sparse_matmul(
-          inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
-      )
     else:
-      output, lb_loss, bias_updates = self.dense_matmul(
-          inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
-      )
+      # If kernels are missing (e.g. masked in pipeline), return zeros.
+      output = jnp.zeros_like(inputs)
+      lb_loss = None
+      bias_updates = None
+
     return output, lb_loss, bias_updates
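
Read together with the earlier hunk that guards self.wi_0.value, the routed-expert block now has one fallback path: a stage holding no expert weights contributes zeros and produces no load-balancing loss or router-bias updates. A condensed sketch of the resulting control flow (the function arguments stand in for sparse_matmul/dense_matmul; not the real signatures):

import jax.numpy as jnp

def routed_experts(inputs, w0, w1, wo, use_sparse, sparse_matmul, dense_matmul):
  if w0 is not None:
    matmul = sparse_matmul if use_sparse else dense_matmul
    output, lb_loss, bias_updates = matmul(inputs, w0, w1, wo)
  else:
    # No expert weights on this pipeline stage: zero contribution, no aux outputs.
    output, lb_loss, bias_updates = jnp.zeros_like(inputs), None, None
  return output, lb_loss, bias_updates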
