Skip to content

Commit 77941e0

Browse files
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash_attn_pad_bw_seqs
2 parents 9389309 + eca05d3 commit 77941e0

81 files changed

Lines changed: 5378 additions & 2261 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/lint.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ concurrency:
1111
# Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
1212
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
1313
cancel-in-progress: true
14+
permissions:
15+
contents: read
1416
jobs:
1517
pytorch_cpplint:
1618
name: 'PyTorch C++'

CODEOWNERS

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# IMPORTANT:
2+
# This file is ONLY used to subscribe for notifications for PRs
3+
# related to a specific file path. Approvals from people in this
4+
# file are not required for merges.
5+
6+
# C API
7+
/transformer_engine/common/include/ @ptrendx
8+
9+
# TE/JAX
10+
/transformer_engine/jax/ @jberchtold-nvidia
11+
12+
# TE/PyTorch
13+
/transformer_engine/pytorch/ @ksivaman
14+
15+
# te.ops API
16+
/transformer_engine/pytorch/ops/ @timmoon10
17+
18+
# Quantization kernels
19+
/transformer_engine/common/cast/ @Oleg-Goncharov
20+
21+
# Attention
22+
/transformer_engine/pytorch/attention/ @cyanguwa
23+
/transformer_engine/common/fused_attn/ @cyanguwa
24+
/transformer_engine/jax/cpp_extensions/attention.py @KshitijLakhani

build_tools/wheel_utils/Dockerfile.aarch

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ ENV CUDA_MAJOR=${CUDA_MAJOR}
2323

2424
# Cuda toolkit, cudnn, driver.
2525
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
26-
RUN dnf -y install epel-release
2726
RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
2827
cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
2928
cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64

build_tools/wheel_utils/Dockerfile.x86

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ ENV CUDA_MAJOR=${CUDA_MAJOR}
2323

2424
# Cuda toolkit, cudnn, driver.
2525
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
26-
RUN dnf -y install epel-release
2726
RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
2827
cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
2928
cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64
@@ -44,4 +43,4 @@ ENV CUDA_PATH=/usr/local/cuda
4443
ENV CUDADIR=/usr/local/cuda
4544
ENV NVTE_RELEASE_BUILD=1
4645

47-
CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]
46+
CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]

docs/Doxyfile

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,6 @@ ALLOW_UNICODE_NAMES = NO
9393

9494
OUTPUT_LANGUAGE = English
9595

96-
# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all
97-
# documentation generated by doxygen is written. Doxygen will use this
98-
# information to generate all generated output in the proper direction.
99-
# Possible values are: None, LTR, RTL and Context.
100-
# The default value is: None.
101-
102-
OUTPUT_TEXT_DIRECTION = None
103-
10496
# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
10597
# descriptions after the members that are listed in the file and class
10698
# documentation (similar to Javadoc). Set to NO to disable this.
@@ -263,12 +255,6 @@ TAB_SIZE = 2
263255

264256
ALIASES =
265257

266-
# This tag can be used to specify a number of word-keyword mappings (TCL only).
267-
# A mapping has the form "name=value". For example adding "class=itcl::class"
268-
# will allow you to use the command class in the itcl::class meaning.
269-
270-
TCL_SUBST =
271-
272258
# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
273259
# only. Doxygen will then generate output that is more tailored for C. For
274260
# instance, some of the names that are used will be different. The list of all
@@ -1156,13 +1142,6 @@ CLANG_DATABASE_PATH =
11561142

11571143
ALPHABETICAL_INDEX = YES
11581144

1159-
# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
1160-
# which the alphabetical index list will be split.
1161-
# Minimum value: 1, maximum value: 20, default value: 5.
1162-
# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
1163-
1164-
COLS_IN_ALPHA_INDEX = 5
1165-
11661145
# In case all classes in a project start with a common prefix, all classes will
11671146
# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
11681147
# can be used to specify a prefix (or a list of prefixes) that should be ignored
@@ -1290,15 +1269,6 @@ HTML_COLORSTYLE_SAT = 100
12901269

12911270
HTML_COLORSTYLE_GAMMA = 80
12921271

1293-
# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
1294-
# page will contain the date and time when the page was generated. Setting this
1295-
# to YES can help to show when doxygen was last run and thus if the
1296-
# documentation is up to date.
1297-
# The default value is: NO.
1298-
# This tag requires that the tag GENERATE_HTML is set to YES.
1299-
1300-
HTML_TIMESTAMP = NO
1301-
13021272
# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
13031273
# documentation will contain a main index with vertical navigation menus that
13041274
# are dynamically created via JavaScript. If disabled, the navigation index will
@@ -1580,17 +1550,6 @@ EXT_LINKS_IN_WINDOW = NO
15801550

15811551
FORMULA_FONTSIZE = 10
15821552

1583-
# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
1584-
# generated for formulas are transparent PNGs. Transparent PNGs are not
1585-
# supported properly for IE 6.0, but are supported on all modern browsers.
1586-
#
1587-
# Note that when changing this option you need to delete any form_*.png files in
1588-
# the HTML output directory before the changes have effect.
1589-
# The default value is: YES.
1590-
# This tag requires that the tag GENERATE_HTML is set to YES.
1591-
1592-
FORMULA_TRANSPARENT = YES
1593-
15941553
# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
15951554
# to create new LaTeX commands to be used in formulas as building blocks. See
15961555
# the section "Including formulas" for details.
@@ -1889,16 +1848,6 @@ LATEX_BATCHMODE = NO
18891848

18901849
LATEX_HIDE_INDICES = NO
18911850

1892-
# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source
1893-
# code with syntax highlighting in the LaTeX output.
1894-
#
1895-
# Note that which sources are shown also depends on other settings such as
1896-
# SOURCE_BROWSER.
1897-
# The default value is: NO.
1898-
# This tag requires that the tag GENERATE_LATEX is set to YES.
1899-
1900-
LATEX_SOURCE_CODE = NO
1901-
19021851
# The LATEX_BIB_STYLE tag can be used to specify the style to use for the
19031852
# bibliography, e.g. plainnat, or ieeetr. See
19041853
# https://en.wikipedia.org/wiki/BibTeX and \cite for more info.
@@ -1907,14 +1856,6 @@ LATEX_SOURCE_CODE = NO
19071856

19081857
LATEX_BIB_STYLE = plain
19091858

1910-
# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated
1911-
# page will contain the date and time when the page was generated. Setting this
1912-
# to NO can help when comparing the output of multiple runs.
1913-
# The default value is: NO.
1914-
# This tag requires that the tag GENERATE_LATEX is set to YES.
1915-
1916-
LATEX_TIMESTAMP = NO
1917-
19181859
# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute)
19191860
# path from which the emoji images will be read. If a relative path is entered,
19201861
# it will be relative to the LATEX_OUTPUT directory. If left blank the
@@ -1979,16 +1920,6 @@ RTF_STYLESHEET_FILE =
19791920

19801921
RTF_EXTENSIONS_FILE =
19811922

1982-
# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
1983-
# with syntax highlighting in the RTF output.
1984-
#
1985-
# Note that which sources are shown also depends on other settings such as
1986-
# SOURCE_BROWSER.
1987-
# The default value is: NO.
1988-
# This tag requires that the tag GENERATE_RTF is set to YES.
1989-
1990-
RTF_SOURCE_CODE = NO
1991-
19921923
#---------------------------------------------------------------------------
19931924
# Configuration options related to the man page output
19941925
#---------------------------------------------------------------------------
@@ -2085,15 +2016,6 @@ GENERATE_DOCBOOK = NO
20852016

20862017
DOCBOOK_OUTPUT = docbook
20872018

2088-
# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the
2089-
# program listings (including syntax highlighting and cross-referencing
2090-
# information) to the DOCBOOK output. Note that enabling this will significantly
2091-
# increase the size of the DOCBOOK output.
2092-
# The default value is: NO.
2093-
# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
2094-
2095-
DOCBOOK_PROGRAMLISTING = NO
2096-
20972019
#---------------------------------------------------------------------------
20982020
# Configuration options for the AutoGen Definitions output
20992021
#---------------------------------------------------------------------------

docs/api/pytorch.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ PyTorch
3838
:members: reset, get_states, set_states, add, fork
3939

4040

41-
.. autoapifunction:: transformer_engine.pytorch.autocast
41+
.. autoapiclass:: transformer_engine.pytorch.autocast(enabled=True, calibrating=False, recipe=None, amax_reduction_group=None)
4242

4343
.. autoapifunction:: transformer_engine.pytorch.quantized_model_init
4444

examples/jax/collective_gemm/run_test_cgemm.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,6 @@ wait
143143

144144
# Final cleanup (trap will also call cleanup on exit)
145145
cleanup
146+
wait
146147

147148
exit $HAS_FAILURE

examples/jax/encoder/run_test_multiprocessing_encoder.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,6 @@ wait
9898

9999
# Final cleanup (trap will also call cleanup on exit)
100100
cleanup
101+
wait
101102

102103
exit $HAS_FAILURE

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
2626

2727
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
2828
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
29+
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py"
2930
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
3031
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
3132
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"

tests/cpp/operator/test_act.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ void performTest(const size_t N, const size_t H) {
124124
fillUniform(&input);
125125
fillUniform(&ograd);
126126
setRandomScale(&output);
127+
const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
127128

128129
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
129130
std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(N*H);
@@ -132,7 +133,7 @@ void performTest(const size_t N, const size_t H) {
132133

133134
float ref_amax;
134135
compute_ref_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
135-
output.scale(), &ref_amax, N, H);
136+
ref_scale, &ref_amax, N, H);
136137

137138
cudaDeviceSynchronize();
138139
auto err = cudaGetLastError();
@@ -179,6 +180,7 @@ void performTestGLU(const size_t N, const size_t H) {
179180
fillUniform(&input);
180181
fillUniform(&ograd);
181182
setRandomScale(&output);
183+
const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
182184

183185
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
184186
std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(2 * N * H);
@@ -187,7 +189,7 @@ void performTestGLU(const size_t N, const size_t H) {
187189

188190
float ref_amax;
189191
compute_ref_glu_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
190-
output.scale(), &ref_amax, N, H);
192+
ref_scale, &ref_amax, N, H);
191193

192194
cudaDeviceSynchronize();
193195
auto err = cudaGetLastError();
@@ -197,8 +199,8 @@ void performTestGLU(const size_t N, const size_t H) {
197199
auto [atol, rtol] = getTolerances(DType::kFloat32);
198200
compareResults("amax", output.amax(), ref_amax, atol, rtol);
199201
if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
200-
const float ref_scale = 1.f / output.scale();
201-
compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
202+
const float ref_scale_inv = 1.f / ref_scale;
203+
compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale_inv, atol, rtol);
202204
}
203205
}
204206
auto [atol, rtol] = getTolerances(otype);

0 commit comments

Comments
 (0)