
Commit e0e3d12: Dropout with 8-bit RNG (NVIDIA#2014)

Authored by vasunvidia, timmoon10, and pre-commit-ci[bot]

* Add dropout kernel with 8-bit RNG
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix license
* Avoid ambiguous types
* Do not enforce dropout prob is representable in 8 bits
* Expand error message
* Fix small statistical bug from using less-equal instead of less-than. Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
* Fix linter warning
* Remove unnecessary helper function in PyTorch extensions

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
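The idea behind the kernel, as described by the commit message: draw one random byte per element instead of a full 32-bit value, compare it against an 8-bit threshold, and store the keep/drop mask as bytes. The NumPy sketch below illustrates that scheme and why the less-than vs. less-equal fix matters; it is not the actual CUDA kernel (which generates its bytes with cuRAND on the GPU), and the nearest-integer rounding of the threshold is an assumption.

```python
import numpy as np

def dropout_8bit_rng(x, prob, rng):
    """Illustrative dropout driven by one random byte per element.

    Hypothetical NumPy stand-in for the CUDA kernel; the threshold
    rounding mode is an assumption, not taken from the source.
    """
    threshold = round(prob * 256)  # e.g. 0.0625 -> 16; prob need not be exactly representable
    r = rng.integers(0, 256, size=x.shape, dtype=np.uint8)
    # Strict less-than: P(r < threshold) == threshold / 256 exactly.
    # A less-equal comparison would drop with probability (threshold + 1) / 256,
    # the "small statistical bug" this commit fixes.
    keep = r >= threshold
    # Mask is one byte per element; survivors are rescaled by 1 / (1 - prob)
    # so the expected value matches the input.
    return np.where(keep, x / (1.0 - prob), 0.0), keep.astype(np.uint8)

rng = np.random.default_rng(0)
y, mask = dropout_8bit_rng(np.ones((128, 128), np.float32), prob=0.0625, rng=rng)
print(1.0 - mask.mean())  # observed drop fraction, close to 0.0625
```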
1 parent 607fcc4
9 files changed: 639 additions and 33 deletions


tests/pytorch/test_fusible_ops.py

Lines changed: 40 additions & 16 deletions
```diff
@@ -1749,43 +1749,65 @@ def test_constant_scale(
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
-    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
     @pytest.mark.parametrize("is_training", (True, False))
-    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
     @pytest.mark.parametrize("dtype", _dtypes)
     def test_dropout(
         self,
         *,
         prob: float,
         is_training: bool,
+        quantization: Optional[str],
         shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device = "cuda",
     ):
 
+        # Skip invalid configurations
+        quantized_input = quantization is not None
+        maybe_skip_quantization(quantization, dims=shape, device=device)
+
         # Random data
-        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        x_test = x_ref.clone().requires_grad_()
-        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        dy_test = dy_ref.clone()
+        # Note: Shift values to make sure inputs are non-zero
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            test_is_quantized=quantized_input,
+        )
+        with torch.no_grad():
+            x_test += 1
+            x_ref.copy_(x_test)
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
 
         # Apply dropout
         op = te_ops.Dropout(prob)
         if is_training:
             op.train()
         else:
             op.eval()
-        y = op(x_test)
-        y.backward(dy_test)
+        y_test = op(x_test)
+        y_test.backward(dy_test)
 
         # Check values
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         if is_training:
-            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
-            torch.testing.assert_close(y, x_ref * mask)
-            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+            tols = dtype_tols(dtype)
+            mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y_test, x_ref * mask, **tols)
+            torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
         else:
-            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
-            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+            torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)
 
         # Hypothesis testing for number of zeros
         # Note: A Bernoulli random variable with probability p has
@@ -1797,9 +1819,11 @@ def test_dropout(
         # p-value is less than 1% and we assume that the dropout
         # distribution is incorrect.
         if is_training:
-            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
-            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
-            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
+            prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
+            assert (
+                abs(z_score) < 2.5758
+            ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
 
 
 class TestFusedOps:
```
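For context on the hypothesis test above: a Bernoulli variable with probability p has mean p and variance p(1 - p), so over N elements the observed drop fraction is approximately normal with mean p and standard deviation sqrt(p(1 - p) / N). A quick worked example with hypothetical numbers (the 2.5758 cutoff is the two-sided 99% critical value of the standard normal):

```python
import math

# Hypothetical run: 128x128 tensor, dropout prob 0.0625, 1000 zeroed elements
prob, numel, num_zeros = 0.0625, 128 * 128, 1000

prob_observed = num_zeros / numel  # about 0.0610
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / numel)
print(f"{prob_observed=:.4f} {z_score=:.2f}")  # z is about -0.77, inside the ±2.5758 bound
```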

transformer_engine/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -69,6 +69,7 @@ list(APPEND transformer_engine_SOURCES
      transpose/quantize_transpose_vector_blockwise.cu
      transpose/swap_first_dims.cu
      activation/gelu.cu
+     dropout/dropout.cu
      fused_attn/flash_attn.cu
      fused_attn/context_parallel.cu
      fused_attn/kv_cache.cu
```

transformer_engine/common/__init__.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -294,6 +294,38 @@ def _load_nvrtc():
     return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
 
 
+@functools.lru_cache(maxsize=None)
+def _load_curand():
+    """Load cuRAND shared library."""
+    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
+    libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
+    libs = list(filter(lambda x: not ("stub" in x), libs))
+    libs.sort(reverse=True, key=os.path.basename)
+    if libs:
+        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+
+    # Attempt to locate cuRAND in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("curand")
+    if found:
+        return handle
+
+    # Attempt to locate cuRAND via ldconfig
+    libs = subprocess.check_output(
+        f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True
+    )
+    libs = libs.decode("utf-8").split("\n")
+    sos = []
+    for lib in libs:
+        if "libcurand" in lib and "=>" in lib:
+            sos.append(lib.split(">")[1].strip())
+    if sos:
+        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+
+    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
+    return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+
+
 @functools.lru_cache(maxsize=None)
 def _load_core_library():
     """Load shared library with Transformer Engine C extensions"""
@@ -303,6 +335,7 @@ def _load_core_library():
     if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         _CUDNN_LIB_CTYPES = _load_cudnn()
         _NVRTC_LIB_CTYPES = _load_nvrtc()
+        _CURAND_LIB_CTYPES = _load_curand()
         _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
         _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
         _TE_LIB_CTYPES = _load_core_library()
```
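A note on the pattern: each loader is wrapped in functools.lru_cache so the filesystem and ldconfig searches run at most once per process, and RTLD_GLOBAL makes the library's symbols visible to the Transformer Engine core library loaded afterwards. A minimal sanity check of a loaded handle might look like the sketch below (illustrative only; curandGetVersion is part of the public cuRAND C API, but this snippet is not from the source):

```python
import ctypes

# Assumes _load_curand() from above has found libcurand; repeated calls
# return the same cached handle thanks to functools.lru_cache.
curand = _load_curand()

# curandGetVersion(int *version) reports the linked library version
version = ctypes.c_int()
status = curand.curandGetVersion(ctypes.byref(version))
assert status == 0  # CURAND_STATUS_SUCCESS
print(f"cuRAND version: {version.value}")
```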
