-
Notifications
You must be signed in to change notification settings - Fork 527
Support Qwix quantization on NNX #4040
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,9 +38,13 @@ | |
| from flax.linen import fp8_ops | ||
| from flax.linen import initializers as flax_initializers | ||
| import flax.linen as nn | ||
| from flax import nnx | ||
|
|
||
| from qwix._src import flax_util | ||
|
|
||
| from maxtext.common.common_types import DType, Config | ||
| from maxtext.inference.kvcache import KVQuant | ||
| from maxtext.layers import nnx_wrappers | ||
|
|
||
| # Params used to define mixed precision quantization configs | ||
| DEFAULT = "__default__" # default config | ||
|
|
@@ -707,6 +711,28 @@ def configure_kv_quant(config): | |
| return None if not config.quantize_kvcache else KVQuant(config) | ||
|
|
||
|
|
||
| def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs): | ||
| """Applies a Linen module within an NNX context.""" | ||
| try: | ||
| parent = flax_util.get_current_module() | ||
| is_nnx = isinstance(parent, nnx.Module) | ||
| except Exception: # pylint: disable=broad-exception-caught | ||
| is_nnx = False | ||
|
|
||
| if is_nnx: | ||
| attr_name = f"_qwix_fp8_gpu_{op_id}" | ||
| if not hasattr(parent, attr_name): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟡 Hardcoded RNG seed `nnx.Rngs(0)`. While likely acceptable for initialization of quantization stats, it's generally safer to fork an existing RNG or use a properly seeded one to avoid potential collisions if multiple modules are initialized this way.
|
||
| rngs = getattr(parent, "qwix_rngs", None) | ||
| if rngs is None: | ||
| rngs = nnx.Rngs(0) | ||
| wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs) | ||
| wrapper.lazy_init(*args, **kwargs) | ||
| setattr(parent, attr_name, wrapper) | ||
| return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs) | ||
| else: | ||
| return linen_module_cls(name=op_id)(*args, **kwargs) | ||
|
|
||
|
|
||
| class NvidaFp8Provider(qwix.QtProvider): | ||
| """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface.""" | ||
|
|
||
|
|
@@ -715,13 +741,13 @@ def dot_general(self, *args, **kwargs): | |
| rule, op_id = self._get_current_rule_and_op_id("dot_general") | ||
| if rule is None: | ||
| return jax.lax.dot_general(*args, **kwargs) | ||
| return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs) | ||
| return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs) | ||
|
|
||
| def einsum(self, *args, **kwargs): | ||
| rule, op_id = self._get_current_rule_and_op_id("einsum") | ||
| if rule is None: | ||
| return jnp.einsum(*args, **kwargs) | ||
| return nn.Fp8Einsum(name=op_id)(*args, **kwargs) | ||
| return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs) | ||
|
|
||
|
|
||
| class NANOOFp8Provider(qwix.QtProvider): | ||
|
|
@@ -731,7 +757,7 @@ def dot_general(self, *args, **kwargs): | |
| rule, op_id = self._get_current_rule_and_op_id("dot_general") | ||
| if rule is None: | ||
| return jax.lax.dot_general(*args, **kwargs) | ||
| return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs) | ||
| return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs) | ||
|
|
||
|
|
||
| def get_fp8_full_qwix_rule_w_sparsity(config: Config): | ||
|
|
@@ -812,7 +838,21 @@ def maybe_quantize_model(model, config): | |
| if config.use_qwix_quantization and not config.use_batch_split_schedule: | ||
| quantization_provider = get_qt_provider(config) | ||
| if quantization_provider: | ||
| model = qwix.quantize_model(model, quantization_provider) | ||
| if config.pure_nnx: | ||
| input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) | ||
| dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) | ||
| dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) | ||
| dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) | ||
| model = qwix.quantize_model( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟡 Dummy inputs for `qwix.quantize_model` are hardcoded to `jnp.ones`. This might not be suitable for all models or all input types. Consider making this more flexible or deriving it from the model configuration if possible.
|
||
| model, | ||
| quantization_provider, | ||
| dummy_tokens, | ||
| dummy_positions, | ||
| dummy_segment_ids, | ||
| enable_dropout=False, | ||
| ) | ||
| else: | ||
| model = qwix.quantize_model(model, quantization_provider) | ||
| return model | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -349,7 +349,10 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat | |||||||||||||||
| is_train=True, | ||||||||||||||||
| ) | ||||||||||||||||
| else: | ||||||||||||||||
| model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) | ||||||||||||||||
| OverwriteWithGradient = nnx.variablelib.variable_type_from_name( | ||||||||||||||||
| maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True | ||||||||||||||||
| ) | ||||||||||||||||
| model_graphdef, curr_params, overwrite_vars, rest = nnx.split(state.model, nnx.Param, OverwriteWithGradient, ...) | ||||||||||||||||
| if config.parameter_memory_host_offload: | ||||||||||||||||
| # Params are kept on host (pinned_host) in in_shardings. Move only Param | ||||||||||||||||
| # variables to device before the forward/backward pass so that all dot_general | ||||||||||||||||
|
|
@@ -371,15 +374,18 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat | |||||||||||||||
| ) | ||||||||||||||||
| nnx.update(state.model, curr_params) | ||||||||||||||||
|
|
||||||||||||||||
| def diff_wrapper(param, rest, config, data): | ||||||||||||||||
| local_model = nnx.merge(model_graphdef, param, rest, copy=True) | ||||||||||||||||
| def diff_wrapper(param, overwrite_vars, rest, config, data): | ||||||||||||||||
| local_model = nnx.merge(model_graphdef, param, overwrite_vars, rest, copy=True) | ||||||||||||||||
| loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) | ||||||||||||||||
| _, _, new_rest = nnx.split(local_model, nnx.Param, ...) | ||||||||||||||||
| return loss, (aux, new_rest) | ||||||||||||||||
| _, _, new_overwrite_vars, new_rest = nnx.split(local_model, nnx.Param, OverwriteWithGradient, ...) | ||||||||||||||||
| return loss, (aux, new_overwrite_vars, new_rest) | ||||||||||||||||
|
|
||||||||||||||||
| grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) | ||||||||||||||||
| (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) | ||||||||||||||||
| grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True) | ||||||||||||||||
| (loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func( | ||||||||||||||||
| curr_params, overwrite_vars, rest, config, data | ||||||||||||||||
| ) | ||||||||||||||||
| nnx.update(state.model, new_rest) | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟡 The logic for handling `overwrite_grads` and `new_rest` seems partially redundant.
- You are extracting `new_overwrite_vars` in the `diff_wrapper` aux but then ignoring it in the caller: `(loss, (aux, _, new_rest))`.
- Instead, you are using `overwrite_grads` to update the model.
- If `overwrite_vars` (variables of type `OverwriteWithGradient`) use the "custom gradient trick" to return their updated values as gradients, then `overwrite_grads` and `new_overwrite_vars` should be identical.
- However, if any variable of this type does *not* use this trick, its gradient will be 0, and `nnx.update(state.model, overwrite_grads)` will incorrectly zero it out, whereas `new_overwrite_vars` would have preserved the updated (or original) value.
Is there a reason to prefer
Suggested change
|
||||||||||||||||
| nnx.update(state.model, overwrite_grads) | ||||||||||||||||
|
|
||||||||||||||||
| raw_grads = jax.tree_util.tree_map( | ||||||||||||||||
| lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider a more robust way to ensure the model structure matches the incoming state, or ensure that the model is fully initialized (including quantization wrappers) before
nnx.updateis called.