Skip to content

Commit ca65f69

Browse files
committed
feat: implement nnx-based pipeline
1 parent 37ded59 commit ca65f69

6 files changed

Lines changed: 1551 additions & 822 deletions

File tree

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,14 @@ def __init__(
533533
elif self.is_qwen3_next:
534534
self.query_norm = Qwen3NextRMSNorm(
535535
num_features=self.config.head_dim,
536-
eps=self.config.normalization_layer_epsilon,
536+
epsilon=self.config.normalization_layer_epsilon,
537537
dtype=self.config.dtype,
538538
weight_dtype=self.config.weight_dtype,
539539
rngs=self.rngs,
540540
)
541541
self.key_norm = Qwen3NextRMSNorm(
542542
num_features=self.config.head_dim,
543-
eps=self.config.normalization_layer_epsilon,
543+
epsilon=self.config.normalization_layer_epsilon,
544544
dtype=self.config.dtype,
545545
weight_dtype=self.config.weight_dtype,
546546
rngs=self.rngs,

0 commit comments

Comments
 (0)