Skip to content

Support Qwix quantization on NNX#4040

Draft
hsuan-lun-chiang wants to merge 2 commits into
mainfrom
feat/Support-Qwix-quantization-on-NNX
Draft

Support Qwix quantization on NNX#4040
hsuan-lun-chiang wants to merge 2 commits into
mainfrom
feat/Support-Qwix-quantization-on-NNX

Conversation

@hsuan-lun-chiang
Copy link
Copy Markdown
Collaborator

@hsuan-lun-chiang hsuan-lun-chiang commented Jun 2, 2026

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

You can also provide a comma-separated list. If you don't want to close a bug but
simply to reference it, use BUGS, e.g.:
BUGS: b/123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 91.30435% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/quantizations.py 86.66% 3 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Support-Qwix-quantization-on-NNX branch 3 times, most recently from 239e912 to f20b375 Compare June 3, 2026 09:04
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Support-Qwix-quantization-on-NNX branch from 135d9bc to 2993617 Compare June 3, 2026 10:44
@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions Bot Jun 3, 2026
Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request introduces support for Qwix quantization on NNX models, primarily focusing on FP8 support. It includes changes to the NNX bridge wrappers, quantization providers, and the training loop to handle special variable types for FP8 stats.

🔍 General Feedback

  • The move to NNX for quantization is a positive step, as it simplifies the state management and removes previous workarounds for jax.lax.scan.
  • However, the dynamic reconstruction of module structures in ToLinen is risky and could lead to runtime errors with complex model architectures (especially those involving lists or sequences).
  • The training loop logic for OverwriteWithGradient variables relies on custom gradient behaviors that should be clearly documented to avoid confusion with standard parameter updates.
  • Test coverage for pure NNX quantization is a good addition.

dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
model = qwix.quantize_model(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Dummy inputs for `qwix.quantize_model` are hardcoded to `jnp.ones`. This might not be suitable for all models or all input types. Consider making this more flexible or deriving it from the model configuration if possible.


if is_nnx:
attr_name = f"_qwix_fp8_gpu_{op_id}"
if not hasattr(parent, attr_name):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Hardcoded RNG seed `nnx.Rngs(0)`. While likely acceptable for initialization of quantization stats, it's generally safer to fork an existing RNG or use a properly seeded one to avoid potential collisions if multiple modules are initialized this way.

(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
curr_params, overwrite_vars, rest, config, data
)
nnx.update(state.model, new_rest)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The logic for handling `overwrite_grads` and `new_rest` seems partially redundant. - You are extracting `new_overwrite_vars` in the `diff_wrapper` aux but then ignoring it in the caller: `(loss, (aux, _, new_rest))`. - Instead, you are using `overwrite_grads` to update the model. - If `overwrite_vars` (variables of type `OverwriteWithGradient`) use the "custom gradient trick" to return their updated values as gradients, then `overwrite_grads` and `new_overwrite_vars` should be identical. - However, if any variable of this type does *not* use this trick, its gradient will be 0, and `nnx.update(state.model, overwrite_grads)` will incorrectly zero it out, whereas `new_overwrite_vars` would have preserved the updated (or original) value.

Is there a reason to prefer overwrite_grads over new_overwrite_vars? If the goal is to support distributed all-reduce of these stats, that makes sense, but it should be documented.

Suggested change
nnx.update(state.model, new_rest)
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
curr_params, overwrite_vars, rest, config, data
)
nnx.update(state.model, new_rest)
nnx.update(state.model, overwrite_grads)

if not hasattr(curr, p):
setattr(curr, p, nnx.Module())
curr = getattr(curr, p)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 The dynamic reconstruction of modules here is fragile and will likely fail in several scenarios: 1. **Integer paths**: If any part of `path[:-1]` is an integer (e.g., when traversing a list of layers or an `nnx.Sequential`), `setattr(curr, p, ...)` will raise a `TypeError` because attribute names must be strings. 2. **Container Types**: It always uses `nnx.Module()`. If the path expects a list or another container type, the structure will be incorrect. 3. **Callable Wrappers**: For Qwix quantization, these attributes are often intended to be `ToNNX` wrappers. If they are pre-created as plain `nnx.Module` instances, they will fail when called during the forward pass (specifically in `_apply_linen_module_in_nnx` which expects the attribute to be callable).

Consider a more robust way to ensure the model structure matches the incoming state, or ensure that the model is fully initialized (including quantization wrappers) before nnx.update is called.

Suggested change
# Dynamically reconstruct the unknown variables
curr = module
for p in path[:-1]:
if not hasattr(curr, p):
setattr(curr, p, nnx.Module())
curr = getattr(curr, p)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants