diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index a988c6c9df..383da52537 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from flax import linen as nn from flax import nnx -from flax.nnx import wrappers as nnx_wrappers +from maxtext.layers import nnx_wrappers from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh @@ -543,14 +543,8 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - # Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint - # re-traces and hits UnexpectedTracerError. Skip remat for FP8. - uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") - if uses_linen_fp8_mutable_state: - out, new_state = pure_layer_fn(state, y) - else: - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -599,7 +593,6 @@ def _extract_matching_state(template, full): use_kv = kv_caches_stacked is not None def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer if use_kv: current_params, current_state, kv_cache_layer = scanned_vars @@ -671,22 +664,7 @@ def layer_fn(carry, scanned_vars): params = nnx_ensure_scan_leading_axis(params, length) state = nnx_ensure_scan_leading_axis(state, length) - # Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan - # leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop - # for FP8 instead. - uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") - if uses_linen_fp8_mutable_state: - carry = x_in - per_layer_states = [] - for i in range(length): - current_params = jax.tree.map(lambda x, i=i: x[i], params) - current_state = jax.tree.map(lambda x, i=i: x[i], state) - carry, new_state_i = layer_fn(carry, (current_params, current_state)) - per_layer_states.append(new_state_i) - final_carry = carry - scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) - else: - final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: @@ -982,7 +960,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding[...] + embedding_table = shared_embedding.embedding.value else: embedding_table = shared_embedding.variables["params"]["embedding"] if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): @@ -1246,7 +1224,6 @@ def __call__( # Hoisted function to preserve XLA cache ID def pure_layer_fn(graphdef, state_in, y_in, kv_in): - if cfg.parameter_memory_host_offload: state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in) diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index d29edd6e8e..3668313f55 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -161,6 +161,33 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): return fn +def isolate_linen_stacks(fn: tp.Callable[..., tp.Any], *args, **kwargs): + """Temporarily shadows/clears the thread-local Linen and Bridge module stacks during fn execution.""" + # Retrieve the thread-local stack lists + linen_stack = getattr(linen.module._context, "module_stack", None) # pylint: disable=protected-access + bridge_stack = getattr(bdg_module.MODULE_CONTEXT, "module_stack", None) + + # Performance optimization: Filter out None or already-empty stacks + # to avoid redundant copying and clearing operations. + active_stacks = [s for s in (linen_stack, bridge_stack) if s] + + if active_stacks: + # Use fastest list-slicing s[:] copy instead of list(s) + saved_stacks = [s[:] for s in active_stacks] + for s in active_stacks: + s.clear() + else: + saved_stacks = None + + try: + return fn(*args, **kwargs) + finally: + if saved_stacks: + for s, saved in zip(active_stacks, saved_stacks): + s.clear() + s.extend(saved) + + def current_linen_module() -> linen.Module | None: """Get the current Linen module from the Linen context.""" if linen.module._context.module_stack: # pylint: disable=W0212 @@ -273,7 +300,7 @@ def __call__( if "params" not in _rngs and "default" in _rngs: _rngs["params"] = _rngs.pop("default") if self._pytree__state.initializing: - out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs) + out, updates = isolate_linen_stacks(self.to_nnx__module.init_with_output, _rngs, *args, method=method, **kwargs) else: nnx_attrs = { k: v @@ -285,16 +312,22 @@ def __call__( # Get `mutable` from top level bridge.Module context if any if mutable is not None: pass - elif (m := bdg_module.current_module()) is not None: + elif getattr(bdg_module.MODULE_CONTEXT, "module_stack", None) and (m := bdg_module.current_module()) is not None: assert m.scope is not None mutable = m.scope.mutable elif (m := current_linen_module()) is not None: assert m.scope is not None mutable = m.scope.mutable else: - mutable = False - - out = self.to_nnx__module.apply(variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs) + # Safe fallback mutability: when running functionally isolated inside standard JAX transforms, + # we determine which collections (such as "stats" or "amax_history") are present and mark them mutable. + mutable = [k for k in variables.keys() if k != "params"] + if not mutable: + mutable = False + + out = isolate_linen_stacks( + self.to_nnx__module.apply, variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs + ) # Split out the updates if `mutable` is passed into the Flax module if mutable is not False: @@ -499,7 +532,31 @@ def maybe_unbox(x): for path, _ in unknown_state_flat.items(): paths_str += f"\n - {'/'.join(map(str, path))}" - warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") + # Dynamically reconstruct the unknown variables + curr = module + for p in path[:-1]: + if isinstance(curr, dict): + if p not in curr: + curr[p] = nnx.Module() + curr = curr[p] + elif isinstance(curr, list): + if not isinstance(p, int): + raise TypeError(f"Expected int index for list, got {type(p)}: {p}") + while len(curr) <= p: + curr.append(nnx.Module()) + curr = curr[p] + elif isinstance(curr, tuple): + raise ValueError(f"Cannot dynamically reconstruct elements within a tuple at path {path}.") + else: + if not isinstance(p, str): + p = str(p) + if not hasattr(curr, p): + setattr(curr, p, nnx.Module()) + curr = getattr(curr, p) + + warnings.warn( + f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed." + ) nnx.update(module, new_state) _refresh_variable_trace_state(module) diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index d4688abb80..9cb09c8118 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -38,9 +38,13 @@ from flax.linen import fp8_ops from flax.linen import initializers as flax_initializers import flax.linen as nn +from flax import nnx + +from qwix._src import flax_util from maxtext.common.common_types import DType, Config from maxtext.inference.kvcache import KVQuant +from maxtext.layers import nnx_wrappers # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -245,6 +249,7 @@ def __call__( *, out_sharding=None, ) -> jax.Array: + return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config) @@ -263,6 +268,7 @@ def __call__( _dot_general: Callable[..., jax.Array] | None = None, out_sharding=None, ) -> jax.Array: + def custom_dot_general(*args, **kwargs): return dot_general_qt.dot_general_qt(*args[:3], self.config) @@ -507,14 +513,9 @@ def _get_aqt_fp8_default_config(config): constant_bound_config = None if len(config.constant_bound_config) == 6: - ( - fwd_lhs_bound, - fwd_rhs_bound, - dlhs_lhs_bound, - dlhs_rhs_bound, - drhs_lhs_bound, - drhs_rhs_bound, - ) = config.constant_bound_config + fwd_lhs_bound, fwd_rhs_bound, dlhs_lhs_bound, dlhs_rhs_bound, drhs_lhs_bound, drhs_rhs_bound = ( + config.constant_bound_config + ) constant_bound_config = ConstantBoundConfig( fwd_lhs_bound=fwd_lhs_bound, fwd_rhs_bound=fwd_rhs_bound, @@ -710,6 +711,32 @@ def configure_kv_quant(config): return None if not config.quantize_kvcache else KVQuant(config) +def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs): + """Applies a Linen module within an NNX context.""" + try: + parent = flax_util.get_current_module() + is_nnx = isinstance(parent, nnx.Module) + except Exception: # pylint: disable=broad-exception-caught + is_nnx = False + + if is_nnx: + attr_name = f"_qwix_fp8_gpu_{op_id}" + if not hasattr(parent, attr_name): + rngs = getattr(parent, "qwix_rngs", None) + if rngs is None: + parent_rngs = getattr(parent, "rngs", None) + if parent_rngs is not None and hasattr(parent_rngs, "fork"): + rngs = parent_rngs.fork() + else: + rngs = nnx.Rngs(0) + wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs) + wrapper.lazy_init(*args, **kwargs) + setattr(parent, attr_name, wrapper) + return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs) + else: + return linen_module_cls(name=op_id)(*args, **kwargs) + + class NvidaFp8Provider(qwix.QtProvider): """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface.""" @@ -718,13 +745,13 @@ def dot_general(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("dot_general") if rule is None: return jax.lax.dot_general(*args, **kwargs) - return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs) def einsum(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("einsum") if rule is None: return jnp.einsum(*args, **kwargs) - return nn.Fp8Einsum(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs) class NANOOFp8Provider(qwix.QtProvider): @@ -734,7 +761,7 @@ def dot_general(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("dot_general") if rule is None: return jax.lax.dot_general(*args, **kwargs) - return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs) def get_fp8_full_qwix_rule_w_sparsity(config: Config): @@ -815,7 +842,21 @@ def maybe_quantize_model(model, config): if config.use_qwix_quantization and not config.use_batch_split_schedule: quantization_provider = get_qt_provider(config) if quantization_provider: - model = qwix.quantize_model(model, quantization_provider) + if config.pure_nnx: + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) + dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) + dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) + model = qwix.quantize_model( + model, + quantization_provider, + dummy_tokens, + dummy_positions, + dummy_segment_ids, + enable_dropout=False, + ) + else: + model = qwix.quantize_model(model, quantization_provider) return model @@ -842,14 +883,13 @@ def _get_max_min(target_dtype): return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16) -def manual_quantize(tensor: jax.Array, dtype: jax.typing.DTypeLike, calibration_method: str) -> qwix.QArray: +def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn): """Manually quantizes a tensor based on a fixed calibration method. Args: tensor: The tensor to quantize. - dtype: The logical type of the quantized value, e.g. jnp.float8_e4m3fn calibration_method: A string specifying the calibration method. Expected - format is "fixed,{scale},{max_val}". e.g., "fixed,-224,224" + format is "fixed,{scale},{max_val}". Returns: A qwix.QArray containing the quantized value and the scale. @@ -857,13 +897,12 @@ def manual_quantize(tensor: jax.Array, dtype: jax.typing.DTypeLike, calibration_ Raises: ValueError: If calibration_method is None or has an unexpected format. """ - # validate calibration method and parse calib_method = calibration_method if calib_method is None: raise ValueError("calibration_method cannot be None for manual quantization") if not calib_method.startswith("fixed"): - # we can use static scale for weight/activation, but grad usually needs dynamic - raise ValueError("Only static scale quantization is supported, but got" f" {calib_method}") + raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}") + parts = calib_method.split(",") if len(parts) != 3: raise ValueError(f"Unexpected format for weight calibration method: {calib_method}") diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index b0af64d9fc..fff430e646 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -22,6 +22,7 @@ from aqt.jax.v2 import aqt_tensor from aqt.jax.v2.flax import aqt_flax from flax import nnx +from flax.nnx import traversals import jax from jax import lax from jax import numpy as jnp @@ -48,7 +49,7 @@ def __init__( self, quantization: quantizations.AqtQuantization, data_type: Any, - rngs: nnx.Rngs, + rngs: nnx.Rngs, # pylint: disable=unused-argument ): self.quantization = quantization self.identity = jnp.identity(2, dtype=data_type) @@ -387,17 +388,116 @@ def compare_fn(path, x, y): jax.tree_util.tree_map_with_path(compare_fn, a, b) - def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1): + def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1, **kwargs): """Run forward pass and backward pass for quantized model and compare with base model.""" # pylint: disable=protected-access - cfg = self.init_pyconfig(quantization=quant) - qt_model = model_creation_utils.create_model(cfg, self.mesh) - + cfg = self.init_pyconfig(quantization=quant, **kwargs) ids, decoder_segment_ids, decoder_positions = self.get_data() - if not hasattr(self.__class__, "_cached_base_results"): - model = model_creation_utils.create_model(self.cfg, self.mesh) - var = model.init( + if cfg.pure_nnx: + qt_model = model_creation_utils.create_model(cfg, self.mesh, rngs=nnx.Rngs(0)) + if getattr(self.__class__, "_cached_base_results_nnx", None) is None: + base_cfg = self.init_pyconfig(quantization="", **kwargs) + base_model = model_creation_utils.create_model(base_cfg, self.mesh, rngs=nnx.Rngs(0)) + + def loss_base(model): + logits = model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + return jnp.mean((logits) ** 2) + + grads_base = nnx.grad(loss_base)(base_model) + logits_base = base_model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + self.__class__._cached_base_results_nnx = (grads_base, logits_base) + + grads_base, logits = self.__class__._cached_base_results_nnx + + def loss_quant(model): + logits_q = model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + return jnp.mean((logits_q) ** 2) + + grads_quant = nnx.grad(loss_quant)(qt_model) + quant_logits = qt_model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + + print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") + assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance + + # nnx.grad returns a State object which is a mapping of paths to gradients. + # Flatten them to check for tolerance. + grads_base_flat = traversals.flatten_mapping(grads_base) + grads_quant_flat = traversals.flatten_mapping(grads_quant) + + # Filter for param collections to compare only parameters and not stats/buffers if any + # Note: NNX grads structure might contain variables like 'kernel', 'bias'. + # For simplicity we compare all matching keys. + def flatten_and_filter(grads_flat): + return {k: v for k, v in grads_flat.items() if hasattr(v, "shape") and "quant_stats" not in str(k)} + + gb_f = flatten_and_filter(grads_base_flat) + gq_f = flatten_and_filter(grads_quant_flat) + + for k in gb_f: + if k in gq_f: + diff = jnp.abs(gb_f[k] - gq_f[k]).mean() / (jnp.abs(gb_f[k]).mean() + 1e-8) + if diff > grad_tolerance: + print(f"Gradient mismatch for {k}: rel_error = {diff}") + assert diff <= grad_tolerance + else: + qt_model = model_creation_utils.create_model(cfg, self.mesh) + if not hasattr(self.__class__, "_cached_base_results"): + model = model_creation_utils.create_model(self.cfg, self.mesh) + var = model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + mutable=True, + ) + + def loss_base_linen(all_vars, inputs): + logits_b, _ = model.apply( + all_vars, + *inputs, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + return jnp.mean((logits_b) ** 2) + + grads_base_linen = jax.grad(loss_base_linen)(var, (ids, decoder_positions, decoder_segment_ids)) + logits_b, _ = model.apply( + var, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + self.__class__._cached_base_results = (grads_base_linen, logits_b) + + grads_base_linen, logits = self.__class__._cached_base_results + + quantized_vars = qt_model.init( {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, decoder_positions, @@ -406,19 +506,20 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1) mutable=True, ) - def loss_base(all_vars, inputs): - logits, _ = model.apply( + def loss_quant_linen(all_vars, inputs): + logits_q, _ = qt_model.apply( all_vars, *inputs, enable_dropout=False, rngs={"params": self.rng}, mutable=True, ) - return jnp.mean((logits) ** 2) + return jnp.mean((logits_q) ** 2) + + grads_quant_linen = jax.grad(loss_quant_linen)(quantized_vars, (ids, decoder_positions, decoder_segment_ids)) - grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids)) - logits, _ = model.apply( - var, + quant_logits, _ = qt_model.apply( + quantized_vars, ids, decoder_positions, decoder_segment_ids, @@ -426,74 +527,61 @@ def loss_base(all_vars, inputs): rngs={"params": self.rng}, mutable=True, ) - self.__class__._cached_base_results = (grads_base, logits) - - grads_base, logits = self.__class__._cached_base_results - - quantized_vars = qt_model.init( - {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - mutable=True, - ) - - def loss_quant(all_vars, inputs): - logits, _ = qt_model.apply( - all_vars, - *inputs, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, + print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") + assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance + self.print_grad_diff(grads_base_linen["params"], grads_quant_linen["params"]) + self.assertTrue( + self.pytree_allclose( + grads_base_linen["params"], + grads_quant_linen["params"], + tolerance=grad_tolerance, + ) ) - return jnp.mean((logits) ** 2) - - # Compute gradients w.r.t. both models - grads_quant = jax.grad(loss_quant)(quantized_vars, (ids, decoder_positions, decoder_segment_ids)) - - quant_logits, _ = qt_model.apply( - quantized_vars, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, - ) - print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") - assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance - self.print_grad_diff(grads_base["params"], grads_quant["params"]) - self.assertTrue( - self.pytree_allclose( - grads_base["params"], - grads_quant["params"], - tolerance=grad_tolerance, - ) - ) @pytest.mark.tpu_only def test_int8_quantization(self): self.quantization_config("int8") + @pytest.mark.tpu_only + def test_int8_quantization_nnx(self): + self.quantization_config("int8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.tpu_only def test_fp8_quantization(self): self.quantization_config("fp8") + @pytest.mark.tpu_only + def test_fp8_quantization_nnx(self): + self.quantization_config("fp8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.tpu_only def test_fp8_full_quantization(self): self.quantization_config("fp8_full") + @pytest.mark.tpu_only + def test_fp8_full_quantization_nnx(self): + self.quantization_config("fp8_full", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_gpu_quantization(self): self.quantization_config("fp8_gpu", grad_tolerance=1.5) + @pytest.mark.gpu_only + @pytest.mark.external_serving + def test_fp8_gpu_quantization_nnx(self): + self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_nanoo_quantization(self): self.quantization_config("fp8_nanoo", grad_tolerance=1.5) + @pytest.mark.gpu_only + @pytest.mark.external_serving + def test_fp8_nanoo_quantization_nnx(self): + self.quantization_config("fp8_nanoo", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") @pytest.mark.gpu_only def test_fp8_te_fp8_delayedscaling_quantization(self):