Skip to content

Commit 6f4dadb

Browse files
committed
[BIONEMO-2334] Patch TE to fix Evo2 stop and go training
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent b764770 commit 6f4dadb

3 files changed

Lines changed: 45 additions & 7 deletions

File tree

Dockerfile

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,21 @@ rm -rf /tmp/* /var/tmp/*
5858
EOF
5959

6060

61-
## BUMP TE as a solution to the issue https://github.com/NVIDIA/bionemo-framework/issues/422. Drop this when pytorch images ship the fixed commit.
62-
ARG TE_TAG=9d4e11eaa508383e35b510dc338e58b09c30be73
63-
RUN PIP_CONSTRAINT= NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
64-
pip --disable-pip-version-check --no-cache-dir install \
65-
git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG}
61+
## BUMP and patch TE as a solution to the issues:
62+
## 1. https://github.com/NVIDIA/bionemo-framework/issues/422
63+
## 2. https://github.com/NVIDIA/bionemo-framework/issues/973
64+
## Drop this (and patches/te.patch) when pytorch images ship a TE version that includes both fixes.
65+
66+
# Pin the TransformerEngine commit that is checked out, patched and built below
67+
ARG TE_TAG=9d4e11eaa508383e35b510dc338e58b09c30be73
68+
69+
COPY ./patches/te.patch /tmp/te.patch
70+
RUN git clone --recurse-submodules https://github.com/NVIDIA/TransformerEngine.git /tmp/TransformerEngine && \
71+
cd /tmp/TransformerEngine && \
72+
git checkout --recurse-submodules ${TE_TAG} && \
73+
patch -p1 < /tmp/te.patch && \
74+
PIP_CONSTRAINT= NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
75+
pip --disable-pip-version-check --no-cache-dir install .
6676

6777
# Install AWS CLI based on architecture
6878
RUN if [ "$TARGETARCH" = "arm64" ]; then \

patches/te.patch

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
2+
index 2c3ccff1..a00d46cc 100644
3+
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
4+
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
5+
@@ -10,6 +10,7 @@
6+
#include "extensions.h"
7+
#include "pybind.h"
8+
#include "transformer_engine/transformer_engine.h"
9+
+#include "util.h"
10+
11+
namespace transformer_engine::pytorch {
12+
13+
@@ -33,6 +34,16 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
14+
DType fake_te_type = GetTransformerEngineDType(fake_tensor_type);
15+
std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type);
16+
} else {
17+
+ if (my_quantizer->columnwise_usage && !non_tn_fp8_gemm_supported()) {
18+
+ bool transpose_exists = !output.attr("_transpose_invalid").cast<bool>() && !output.attr("_transpose").is_none();
19+
+ if (!transpose_exists) {
20+
+ DType fake_te_type = GetTransformerEngineDType(fake_tensor_type);
21+
+ py::object new_out;
22+
+ std::tie(std::ignore, new_out) = my_quantizer->create_tensor(input_shape, fake_te_type);
23+
+ output.attr("_transpose_invalid") = py::bool_(false);
24+
+ output.attr("_transpose") = new_out.attr("_transpose");
25+
+ }
26+
+ }
27+
out = output;
28+
te_output = makeTransformerEngineTensor(output, quantizer);
29+
}

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def small_training_cmd(path, max_steps, val_check, devices: int = 1, additional_
5656
"--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
5757
"--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
5858
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
59-
f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
59+
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
6060
)
6161
return cmd
6262

@@ -158,7 +158,6 @@ def test_train_evo2_stops(tmp_path):
158158
"--fp8",
159159
marks=[
160160
pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
161-
pytest.mark.xfail(reason="FP8 test currently broken - TODO: fix"),
162161
],
163162
id="fp8",
164163
),

0 commit comments

Comments
 (0)