Skip to content

Commit 71f6897

Browse files
authored
Merge pull request #33 from InfiniTensor/develop-dropout
Add `dropout` operator
2 parents 3ac48fd + 61f6589 commit 71f6897

3 files changed

Lines changed: 82 additions & 0 deletions

File tree

src/ntops/kernels/dropout.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
def application(input, p, seed, output):
11+
output = ntl.where(ntl.rand(seed, input.offsets()) > p, input / (1 - p), 0) # noqa: F841
12+
13+
14+
@functools.cache
15+
def make(ndim):
16+
tensors = (Tensor(ndim), Tensor(0), Tensor(0), Tensor(ndim))
17+
18+
return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import random
2+
13
import torch
24

35
import ntops.kernels.abs
@@ -10,6 +12,7 @@
1012
import ntops.kernels.clamp
1113
import ntops.kernels.cos
1214
import ntops.kernels.div
15+
import ntops.kernels.dropout
1316
import ntops.kernels.eq
1417
import ntops.kernels.exp
1518
import ntops.kernels.ge
@@ -147,6 +150,27 @@ def div(input, other, *, rounding_mode=None, out=None):
147150
return out
148151

149152

153+
def dropout(input, p=0.5, training=True, inplace=False):
154+
if not training or p == 0:
155+
if inplace:
156+
return input
157+
else:
158+
return input.clone()
159+
160+
seed = random.randrange(0, 2**31)
161+
162+
if inplace:
163+
output = input
164+
else:
165+
output = torch.empty_like(input)
166+
167+
kernel = ntops.kernels.dropout.make(input.ndim)
168+
169+
kernel(input, p, seed, output)
170+
171+
return output
172+
173+
150174
def exp(input, *, out=None):
151175
if out is None:
152176
out = torch.empty_like(input)

tests/test_dropout.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import random
2+
3+
import pytest
4+
import torch
5+
import torch.nn.functional as F
6+
7+
import ntops.torch
8+
from tests.skippers import skip_if_cuda_not_available
9+
from tests.utils import generate_arguments
10+
11+
12+
@skip_if_cuda_not_available
13+
@pytest.mark.parametrize(*generate_arguments())
14+
def test_cuda(shape, dtype, atol, rtol):
15+
device = "cuda"
16+
17+
input = torch.randn(shape, dtype=dtype, device=device)
18+
p = random.uniform(0, 1)
19+
20+
# TODO: Add `training` and `inplace` tests later.
21+
ninetoothed_output = ntops.torch.dropout(input, p=p)
22+
reference_output = F.dropout(input, p=p)
23+
24+
assert ninetoothed_output.shape == reference_output.shape
25+
26+
ninetoothed_non_zero_ratio = (
27+
ninetoothed_output.nonzero().numel() / ninetoothed_output.ndim / input.numel()
28+
)
29+
reference_non_zero_ratio = (
30+
reference_output.nonzero().numel() / reference_output.ndim / input.numel()
31+
)
32+
33+
print(abs(ninetoothed_non_zero_ratio - reference_non_zero_ratio))
34+
35+
assert abs(ninetoothed_non_zero_ratio - reference_non_zero_ratio) < 0.1
36+
37+
assert torch.allclose(
38+
ninetoothed_output[ninetoothed_output != 0],
39+
input[ninetoothed_output != 0] / (1 - p),
40+
)

0 commit comments

Comments
 (0)