Skip to content

Commit d8e8ba4

Browse files
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine into flash_attn_pad_bw_seqs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com> # Conflicts: # tests/pytorch/attention/test_attention_with_cp.py
2 parents 77941e0 + 86ade9e commit d8e8ba4

51 files changed

Lines changed: 2366 additions & 1038 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
# A workflow to automatically label the contributions as community/org
6+
name: Label community contributions
7+
8+
on:
9+
pull_request_target:
10+
types: [opened, reopened, ready_for_review, synchronize]
11+
12+
permissions:
13+
contents: read
14+
issues: write
15+
pull-requests: write
16+
17+
jobs:
18+
label:
19+
runs-on: ubuntu-latest
20+
steps:
21+
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3
22+
with:
23+
script: |
24+
const pr = context.payload.pull_request;
25+
const user = pr.user.login;
26+
const association = pr.author_association;
27+
28+
const communityLabel = "community-contribution";
29+
const orgLabel = "org-contribution";
30+
31+
let targetLabel = null;
32+
33+
const isOrgMember =
34+
association === "MEMBER" || association === "OWNER";
35+
36+
let permission = "none";
37+
38+
try {
39+
const res = await github.rest.repos.getCollaboratorPermissionLevel({
40+
owner: context.repo.owner,
41+
repo: context.repo.repo,
42+
username: user,
43+
});
44+
permission = res.data.permission;
45+
} catch (e) {
46+
if (e.status !== 404) throw e;
47+
}
48+
49+
const isCore = permission === "write" || permission === "admin";
50+
if (!isOrgMember) {
51+
targetLabel = communityLabel;
52+
} else {
53+
targetLabel = orgLabel;
54+
}
55+
56+
if (!isCore) {
57+
await github.rest.issues.addLabels({
58+
owner: context.repo.owner,
59+
repo: context.repo.repo,
60+
issue_number: pr.number,
61+
labels: [targetLabel],
62+
});
63+
}

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 301 files

build_tools/VERSION.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.16.0.dev0
1+
2.17.0.dev0

docs/examples/jax/attention.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
..
2+
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
4+
See LICENSE for license information.
5+
6+
JAX: Attention with TransformerEngine
7+
=====================================
8+
9+
**TODO — Coming soon.**
10+
11+
`← Back to the JAX integration overview <../te_jax_integration.html>`_
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
..
2+
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
4+
See LICENSE for license information.
5+
6+
JAX: Collective GEMMs with TransformerEngine
7+
=============================================
8+
9+
**TODO — Coming soon.**
10+
11+
`← Back to the JAX integration overview <../te_jax_integration.html>`_

docs/examples/jax/dense.out

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Numbers below are illustrative (captured on a GB200). Regenerate with:
2+
# python3 docs/examples/jax/dense.py > dense.out
3+
4+
# SINGLE_GPU_OUTPUT_START
5+
Variable collections: ['params']
6+
{'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}}
7+
8+
bf16 baseline:
9+
Mean time: 18.056 ms
10+
11+
TE MXFP8BlockScaling:
12+
Mean time: 11.260 ms
13+
# SINGLE_GPU_OUTPUT_END
14+
15+
# MULTI_GPU_OUTPUT_START
16+
bf16 DP=2/TP=2:
17+
Mean time: 5.516 ms
18+
19+
TE MXFP8BlockScaling DP=2/TP=2:
20+
Mean time: 3.712 ms
21+
# MULTI_GPU_OUTPUT_END

docs/examples/jax/dense.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""JAX: Dense GEMMs with TransformerEngine.
6+
7+
Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` /
8+
``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``.
9+
10+
Run as a script to exercise the example end-to-end:
11+
12+
python docs/examples/jax/dense.py
13+
14+
Pytest tests live in ``test_dense.py``; the multi-GPU section auto-skips when
15+
fewer than 4 GPUs are visible.
16+
"""
17+
18+
# DENSE_IMPORTS_START
19+
import jax
20+
import jax.numpy as jnp
21+
from flax import linen as nn
22+
23+
import quickstart_jax_utils as utils
24+
25+
# DENSE_IMPORTS_END
26+
27+
28+
# DENSE_BASELINE_MODEL_START
29+
class FlaxDenseBlock(nn.Module):
30+
"""One linear layer. ``dot_general_cls`` lets us swap the GEMM impl."""
31+
32+
features: int
33+
dtype: jnp.dtype = jnp.bfloat16
34+
dot_general_cls: callable = lambda: None
35+
36+
@nn.compact
37+
def __call__(self, x):
38+
return nn.Dense(
39+
features=self.features,
40+
use_bias=False,
41+
dtype=self.dtype,
42+
dot_general=self.dot_general_cls(),
43+
)(x)
44+
45+
46+
# DENSE_BASELINE_MODEL_END
47+
48+
49+
# DENSE_INPUTS_SETUP_START
50+
batch, seq, hidden, out_features = 8, 2048, 8192, 32768
51+
dtype = jnp.bfloat16
52+
53+
key = jax.random.PRNGKey(0)
54+
k_init, k_x, k_dy = jax.random.split(key, 3)
55+
x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)
56+
dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)
57+
58+
baseline = FlaxDenseBlock(features=out_features)
59+
baseline_vars = baseline.init(k_init, x)
60+
# DENSE_INPUTS_SETUP_END
61+
62+
63+
# DENSE_TE_SETUP_START
64+
from transformer_engine.jax import flax as te_flax
65+
from transformer_engine.common.recipe import MXFP8BlockScaling
66+
67+
recipe = MXFP8BlockScaling()
68+
te_dot_general_cls = te_flax.make_dot_general_cls(recipe)
69+
70+
te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)
71+
te_vars = te_model.init(k_init, x)
72+
73+
print("Variable collections:", list(te_vars.keys()))
74+
print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))
75+
# DENSE_TE_SETUP_END
76+
77+
78+
# DENSE_SINGLE_GPU_BENCH_START
79+
def run_single_gpu_bench():
80+
print("bf16 baseline:")
81+
utils.speedometer(
82+
model_apply_fn=baseline.apply,
83+
variables=baseline_vars,
84+
input=x,
85+
output_grad=dy,
86+
)
87+
88+
print(f"\nTE {type(recipe).__name__}:")
89+
utils.speedometer(
90+
model_apply_fn=te_model.apply,
91+
variables=te_vars,
92+
input=x,
93+
output_grad=dy,
94+
)
95+
96+
97+
# DENSE_SINGLE_GPU_BENCH_END
98+
99+
100+
# DENSE_MULTI_GPU_MESH_SETUP_START
101+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
102+
from jax.experimental import mesh_utils
103+
from transformer_engine.jax.sharding import MeshResource, global_shard_guard
104+
105+
106+
def build_dp_tp_mesh():
107+
# 2x2 mesh: DP on one axis, TP on the other.
108+
devices = mesh_utils.create_device_mesh((2, 2))
109+
mesh = Mesh(devices, axis_names=("dp", "tp"))
110+
111+
# Tell TE which mesh axis is which. This is a *global* setting, established
112+
# outside JIT, so TE's GEMM primitives can plan comms accordingly.
113+
mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp")
114+
return mesh, mesh_resource
115+
116+
117+
# DENSE_MULTI_GPU_MESH_SETUP_END
118+
119+
120+
# DENSE_MULTI_GPU_SHARD_SETUP_START
121+
def shard_variables(mesh, variables_dict):
122+
kernel_sharding = NamedSharding(mesh, P(None, "tp"))
123+
124+
def _shard(variables):
125+
params = variables["params"]
126+
sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding)
127+
return {
128+
**variables,
129+
"params": {
130+
**params,
131+
"Dense_0": {**params["Dense_0"], "kernel": sharded},
132+
},
133+
}
134+
135+
input_sharding = NamedSharding(mesh, P("dp", None, None))
136+
output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp"))
137+
138+
return {
139+
"x": jax.device_put(x, input_sharding),
140+
"dy": jax.device_put(dy, output_grad_sharding),
141+
**{name: _shard(vars_) for name, vars_ in variables_dict.items()},
142+
}
143+
144+
145+
# DENSE_MULTI_GPU_SHARD_SETUP_END
146+
147+
148+
# DENSE_MULTI_GPU_BENCH_START
149+
def run_multi_gpu_bench():
150+
mesh, mesh_resource = build_dp_tp_mesh()
151+
sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars})
152+
153+
with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
154+
print("bf16 DP=2/TP=2:")
155+
utils.speedometer(
156+
model_apply_fn=baseline.apply,
157+
variables=sharded["baseline"],
158+
input=sharded["x"],
159+
output_grad=sharded["dy"],
160+
)
161+
162+
print(f"\nTE {type(recipe).__name__} DP=2/TP=2:")
163+
utils.speedometer(
164+
model_apply_fn=te_model.apply,
165+
variables=sharded["te"],
166+
input=sharded["x"],
167+
output_grad=sharded["dy"],
168+
)
169+
170+
171+
# DENSE_MULTI_GPU_BENCH_END
172+
173+
174+
if __name__ == "__main__":
175+
run_single_gpu_bench()
176+
if len(jax.devices()) >= 4:
177+
print()
178+
run_multi_gpu_bench()
179+
else:
180+
print("\n[skipped multi-GPU section: <4 devices visible]")

0 commit comments

Comments
 (0)