Support Qwix quantization on NNX#4040
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
239e912 to
f20b375
Compare
135d9bc to
2993617
Compare
There was a problem hiding this comment.
This Pull Request introduces support for Qwix quantization on NNX models, primarily focusing on FP8 support. It includes changes to the NNX bridge wrappers, quantization providers, and the training loop to handle special variable types for FP8 stats.
🔍 General Feedback
- The move to NNX for quantization is a positive step, as it simplifies the state management and removes previous workarounds for
jax.lax.scan. - However, the dynamic reconstruction of module structures in
ToLinenis risky and could lead to runtime errors with complex model architectures (especially those involving lists or sequences). - The training loop logic for
OverwriteWithGradientvariables relies on custom gradient behaviors that should be clearly documented to avoid confusion with standard parameter updates. - Test coverage for pure NNX quantization is a good addition.
| dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) | ||
| dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) | ||
| dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) | ||
| model = qwix.quantize_model( |
There was a problem hiding this comment.
|
|
||
| if is_nnx: | ||
| attr_name = f"_qwix_fp8_gpu_{op_id}" | ||
| if not hasattr(parent, attr_name): |
There was a problem hiding this comment.
| (loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func( | ||
| curr_params, overwrite_vars, rest, config, data | ||
| ) | ||
| nnx.update(state.model, new_rest) |
There was a problem hiding this comment.
Is there a reason to prefer overwrite_grads over new_overwrite_vars? If the goal is to support distributed all-reduce of these stats, that makes sense, but it should be documented.
| nnx.update(state.model, new_rest) | |
| grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True) | |
| (loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func( | |
| curr_params, overwrite_vars, rest, config, data | |
| ) | |
| nnx.update(state.model, new_rest) | |
| nnx.update(state.model, overwrite_grads) |
| if not hasattr(curr, p): | ||
| setattr(curr, p, nnx.Module()) | ||
| curr = getattr(curr, p) | ||
|
|
There was a problem hiding this comment.
Consider a more robust way to ensure the model structure matches the incoming state, or ensure that the model is fully initialized (including quantization wrappers) before nnx.update is called.
| # Dynamically reconstruct the unknown variables | |
| curr = module | |
| for p in path[:-1]: | |
| if not hasattr(curr, p): | |
| setattr(curr, p, nnx.Module()) | |
| curr = getattr(curr, p) |
Description
Start with a short description of what the PR does and how this is a change from
the past.
The rest of the description includes relevant details and context, examples:
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456
You can also provide a comma-separated list. If you don't want to close a bug but
simply to reference it, use BUGS, e.g.:
BUGS: b/123456
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.