|
| 1 | +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# See LICENSE for license information. |
| 4 | + |
| 5 | +"""JAX: Dense GEMMs with TransformerEngine. |
| 6 | +
|
| 7 | +Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` / |
| 8 | +``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``. |
| 9 | +
|
| 10 | +Run as a script to exercise the example end-to-end: |
| 11 | +
|
| 12 | + python docs/examples/jax/dense.py |
| 13 | +
|
| 14 | +Pytest tests live in ``test_dense.py``; the multi-GPU section auto-skips when |
| 15 | +fewer than 4 GPUs are visible. |
| 16 | +""" |
| 17 | + |
| 18 | +# DENSE_IMPORTS_START |
| 19 | +import jax |
| 20 | +import jax.numpy as jnp |
| 21 | +from flax import linen as nn |
| 22 | + |
| 23 | +import quickstart_jax_utils as utils |
| 24 | + |
| 25 | +# DENSE_IMPORTS_END |
| 26 | + |
| 27 | + |
| 28 | +# DENSE_BASELINE_MODEL_START |
| 29 | +class FlaxDenseBlock(nn.Module): |
| 30 | + """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl.""" |
| 31 | + |
| 32 | + features: int |
| 33 | + dtype: jnp.dtype = jnp.bfloat16 |
| 34 | + dot_general_cls: callable = lambda: None |
| 35 | + |
| 36 | + @nn.compact |
| 37 | + def __call__(self, x): |
| 38 | + return nn.Dense( |
| 39 | + features=self.features, |
| 40 | + use_bias=False, |
| 41 | + dtype=self.dtype, |
| 42 | + dot_general=self.dot_general_cls(), |
| 43 | + )(x) |
| 44 | + |
| 45 | + |
| 46 | +# DENSE_BASELINE_MODEL_END |
| 47 | + |
| 48 | + |
| 49 | +# DENSE_INPUTS_SETUP_START |
| 50 | +batch, seq, hidden, out_features = 8, 2048, 8192, 32768 |
| 51 | +dtype = jnp.bfloat16 |
| 52 | + |
| 53 | +key = jax.random.PRNGKey(0) |
| 54 | +k_init, k_x, k_dy = jax.random.split(key, 3) |
| 55 | +x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype) |
| 56 | +dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype) |
| 57 | + |
| 58 | +baseline = FlaxDenseBlock(features=out_features) |
| 59 | +baseline_vars = baseline.init(k_init, x) |
| 60 | +# DENSE_INPUTS_SETUP_END |
| 61 | + |
| 62 | + |
| 63 | +# DENSE_TE_SETUP_START |
| 64 | +from transformer_engine.jax import flax as te_flax |
| 65 | +from transformer_engine.common.recipe import MXFP8BlockScaling |
| 66 | + |
| 67 | +recipe = MXFP8BlockScaling() |
| 68 | +te_dot_general_cls = te_flax.make_dot_general_cls(recipe) |
| 69 | + |
| 70 | +te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls) |
| 71 | +te_vars = te_model.init(k_init, x) |
| 72 | + |
| 73 | +print("Variable collections:", list(te_vars.keys())) |
| 74 | +print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars)) |
| 75 | +# DENSE_TE_SETUP_END |
| 76 | + |
| 77 | + |
| 78 | +# DENSE_SINGLE_GPU_BENCH_START |
| 79 | +def run_single_gpu_bench(): |
| 80 | + print("bf16 baseline:") |
| 81 | + utils.speedometer( |
| 82 | + model_apply_fn=baseline.apply, |
| 83 | + variables=baseline_vars, |
| 84 | + input=x, |
| 85 | + output_grad=dy, |
| 86 | + ) |
| 87 | + |
| 88 | + print(f"\nTE {type(recipe).__name__}:") |
| 89 | + utils.speedometer( |
| 90 | + model_apply_fn=te_model.apply, |
| 91 | + variables=te_vars, |
| 92 | + input=x, |
| 93 | + output_grad=dy, |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +# DENSE_SINGLE_GPU_BENCH_END |
| 98 | + |
| 99 | + |
| 100 | +# DENSE_MULTI_GPU_MESH_SETUP_START |
| 101 | +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P |
| 102 | +from jax.experimental import mesh_utils |
| 103 | +from transformer_engine.jax.sharding import MeshResource, global_shard_guard |
| 104 | + |
| 105 | + |
| 106 | +def build_dp_tp_mesh(): |
| 107 | + # 2x2 mesh: DP on one axis, TP on the other. |
| 108 | + devices = mesh_utils.create_device_mesh((2, 2)) |
| 109 | + mesh = Mesh(devices, axis_names=("dp", "tp")) |
| 110 | + |
| 111 | + # Tell TE which mesh axis is which. This is a *global* setting, established |
| 112 | + # outside JIT, so TE's GEMM primitives can plan comms accordingly. |
| 113 | + mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp") |
| 114 | + return mesh, mesh_resource |
| 115 | + |
| 116 | + |
| 117 | +# DENSE_MULTI_GPU_MESH_SETUP_END |
| 118 | + |
| 119 | + |
| 120 | +# DENSE_MULTI_GPU_SHARD_SETUP_START |
| 121 | +def shard_variables(mesh, variables_dict): |
| 122 | + kernel_sharding = NamedSharding(mesh, P(None, "tp")) |
| 123 | + |
| 124 | + def _shard(variables): |
| 125 | + params = variables["params"] |
| 126 | + sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding) |
| 127 | + return { |
| 128 | + **variables, |
| 129 | + "params": { |
| 130 | + **params, |
| 131 | + "Dense_0": {**params["Dense_0"], "kernel": sharded}, |
| 132 | + }, |
| 133 | + } |
| 134 | + |
| 135 | + input_sharding = NamedSharding(mesh, P("dp", None, None)) |
| 136 | + output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp")) |
| 137 | + |
| 138 | + return { |
| 139 | + "x": jax.device_put(x, input_sharding), |
| 140 | + "dy": jax.device_put(dy, output_grad_sharding), |
| 141 | + **{name: _shard(vars_) for name, vars_ in variables_dict.items()}, |
| 142 | + } |
| 143 | + |
| 144 | + |
| 145 | +# DENSE_MULTI_GPU_SHARD_SETUP_END |
| 146 | + |
| 147 | + |
| 148 | +# DENSE_MULTI_GPU_BENCH_START |
| 149 | +def run_multi_gpu_bench(): |
| 150 | + mesh, mesh_resource = build_dp_tp_mesh() |
| 151 | + sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars}) |
| 152 | + |
| 153 | + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): |
| 154 | + print("bf16 DP=2/TP=2:") |
| 155 | + utils.speedometer( |
| 156 | + model_apply_fn=baseline.apply, |
| 157 | + variables=sharded["baseline"], |
| 158 | + input=sharded["x"], |
| 159 | + output_grad=sharded["dy"], |
| 160 | + ) |
| 161 | + |
| 162 | + print(f"\nTE {type(recipe).__name__} DP=2/TP=2:") |
| 163 | + utils.speedometer( |
| 164 | + model_apply_fn=te_model.apply, |
| 165 | + variables=sharded["te"], |
| 166 | + input=sharded["x"], |
| 167 | + output_grad=sharded["dy"], |
| 168 | + ) |
| 169 | + |
| 170 | + |
| 171 | +# DENSE_MULTI_GPU_BENCH_END |
| 172 | + |
| 173 | + |
| 174 | +if __name__ == "__main__": |
| 175 | + run_single_gpu_bench() |
| 176 | + if len(jax.devices()) >= 4: |
| 177 | + print() |
| 178 | + run_multi_gpu_bench() |
| 179 | + else: |
| 180 | + print("\n[skipped multi-GPU section: <4 devices visible]") |
0 commit comments