Commit 2c0bcf8 (1 parent: dc253c9)

Author: zhangyue

test(ascend): add comprehensive tests for all Ascend operators

Add new tests: Cast, Cat, E2E Layer, FlashAttention, Linear, Matmul, Mul, PagedAttention, ReshapeAndCache, RotaryEmbedding, SiluAndMul. Update existing tests with NPU stream handling and Ascend-specific parametrization.

16 files changed: 2,281 additions & 12 deletions

tests/test_add.py

Lines changed: 11 additions & 2 deletions

@@ -2,7 +2,13 @@
 import pytest
 import torch

-from tests.utils import Payload, empty_strided, randint_strided, randn_strided
+from tests.utils import (
+    Payload,
+    empty_strided,
+    get_npu_stream,
+    randint_strided,
+    randn_strided,
+)

 _INT_DTYPES = (torch.int16, torch.int32, torch.int64)

@@ -63,7 +69,10 @@ def test_add(


 def _add(input, other, out):
-    infini.ops.add(input, other, out)
+    if input.device.type == "npu":
+        infini.ops.add(input, other, out, stream=get_npu_stream(input))
+    else:
+        infini.ops.add(input, other, out)

     return out
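Note: the stream=get_npu_stream(...) pattern above recurs in every test this commit touches. The helper itself comes from tests/utils.py, which is not part of this diff; a minimal sketch of what it plausibly does, assuming torch_npu's CUDA-style stream API (current_stream() plus a raw npu_stream handle), would be:

# Hypothetical sketch -- the real helper in tests/utils.py is not shown
# in this commit. Assumes torch_npu mirrors torch.cuda's stream API.
import torch
import torch_npu  # registers the torch.npu backend

def get_npu_stream(tensor: torch.Tensor) -> int:
    # Return the raw stream handle for the tensor's device so the
    # infini.ops kernels launch on the same stream PyTorch uses.
    return torch.npu.current_stream(tensor.device).npu_stream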

tests/test_add_rms_norm.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, strides",
+    (
+        ((1, 64), None),
+        ((2, 128), None),
+        ((4, 48, 64), None),
+        ((2, 4, 2048), None),
+        ((1, 64), (64, 1)),
+        ((4, 48, 64), (3072, 64, 1)),
+    ),
+)
+@pytest.mark.parametrize("eps", (1e-6, 1e-5))
+@pytest.mark.parametrize("implementation_index", (0, 1))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-4, 1e-4),
+        (torch.float16, 1e-2, 1e-2),
+        (torch.bfloat16, 2e-2, 1e-2),
+    ),
+)
+def test_add_rms_norm(
+    shape,
+    strides,
+    eps,
+    implementation_index,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    active_indices = infini.ops.AddRmsNorm.active_implementation_indices(device)
+
+    if implementation_index not in active_indices:
+        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
+
+    weight_shape = (shape[-1],)
+    x1 = randn_strided(shape, strides, dtype=dtype, device=device)
+    x2 = randn_strided(shape, strides, dtype=dtype, device=device)
+    gamma = randn_strided(weight_shape, None, dtype=dtype, device=device)
+    y_out = empty_strided(shape, strides, dtype=dtype, device=device)
+    x_out = empty_strided(shape, strides, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args, **kwargs: _add_rms_norm(
+            *args, **kwargs, implementation_index=implementation_index
+        ),
+        _torch_add_rms_norm,
+        (x1, x2, gamma),
+        {"eps": eps, "y_out": y_out, "x_out": x_out},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _add_rms_norm(x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None,
+                  implementation_index=0):
+    if x1.device.type == "npu":
+        infini.ops.add_rms_norm(
+            x1, x2, gamma, eps, y_out, x_out,
+            implementation_index=implementation_index,
+            stream=get_npu_stream(x1),
+        )
+    else:
+        infini.ops.add_rms_norm(
+            x1, x2, gamma, eps, y_out, x_out,
+            implementation_index=implementation_index,
+        )
+
+    # Concatenate both outputs into a single flat tensor for allclose comparison.
+    return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()])
+
+
+def _torch_add_rms_norm(x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None):
+    x_sum = x1 + x2
+
+    if x_out is not None:
+        x_out.copy_(x_sum)
+
+    rms = torch.sqrt(torch.mean(x_sum.float() * x_sum.float(), dim=-1,
+                                keepdim=True) + eps)
+    y = (x_sum.float() / rms * gamma.float()).to(x1.dtype)
+
+    if y_out is not None:
+        y_out.copy_(y)
+
+    return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()])
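Every test in this commit builds its inputs with randn_strided and empty_strided, which also live in tests/utils.py outside the diff. A plausible sketch, assuming strides=None means a plain contiguous layout (consistent with the None entries in the parametrizations):

# Hypothetical sketch of the tests/utils helpers used throughout this
# commit; the real implementations are not shown in the diff.
import torch

def empty_strided(shape, strides, *, dtype, device):
    # strides=None falls back to a contiguous tensor, matching the
    # `None` entries in the test parametrizations.
    if strides is None:
        return torch.empty(shape, dtype=dtype, device=device)
    return torch.empty_strided(shape, strides, dtype=dtype, device=device)

def randn_strided(shape, strides, *, dtype, device):
    out = empty_strided(shape, strides, dtype=dtype, device=device)
    # copy_ writes through the strided layout, so non-contiguous
    # targets still receive the random values element for element.
    out.copy_(torch.randn(shape, dtype=dtype, device=device))
    return out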

tests/test_cast.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, input_strides, out_strides",
+    (
+        ((13, 4), None, None),
+        ((13, 4), (10, 1), (10, 1)),
+        ((13, 4, 4), None, None),
+        ((16, 5632), None, None),
+        ((4, 4, 5632), None, None),
+    ),
+)
+@pytest.mark.parametrize(
+    ("input_dtype", "out_dtype", "rtol", "atol"),
+    (
+        (torch.float16, torch.float32, 1e-3, 1e-3),
+        (torch.float32, torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, torch.float32, 1e-2, 5e-3),
+        (torch.float32, torch.bfloat16, 1e-2, 5e-3),
+        (torch.float16, torch.bfloat16, 1e-2, 5e-3),
+        (torch.bfloat16, torch.float16, 1e-2, 5e-3),
+    ),
+)
+def test_cast(
+    shape,
+    input_strides,
+    out_strides,
+    input_dtype,
+    out_dtype,
+    device,
+    rtol,
+    atol,
+):
+    input = randn_strided(shape, input_strides, dtype=input_dtype, device=device)
+    out = empty_strided(shape, out_strides, dtype=out_dtype, device=device)
+
+    return Payload(
+        _cast,
+        _torch_cast,
+        (input, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _cast(input, out):
+    if input.device.type == "npu":
+        infini.ops.cast(input, out, stream=get_npu_stream(input))
+    else:
+        infini.ops.cast(input, out)
+
+    return out
+
+
+def _torch_cast(input, out):
+    out.copy_(input.to(out.dtype))
+
+    return out
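These tests return a Payload instead of asserting directly; the auto_act_and_assert marker presumably invokes both callables on the shared inputs and compares the results with torch.allclose. A hedged reconstruction of the contract implied by the call sites (the real class in tests/utils.py is not shown here):

# Hypothetical reconstruction of the Payload contract implied by the
# call sites above; the real definition lives in tests/utils.py.
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class Payload:
    act: Callable          # wrapper around the infini.ops operator
    ref: Callable          # eager PyTorch reference implementation
    args: tuple            # shared positional inputs
    kwargs: dict = field(default_factory=dict)
    rtol: float = 1e-5     # tolerances presumably forwarded to torch.allclose
    atol: float = 1e-5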

tests/test_cat.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shapes, dim, out_shape",
+    (
+        # 2 inputs, dim=0
+        (((4, 64), (4, 64)), 0, (8, 64)),
+        # 2 inputs, dim=1
+        (((4, 32), (4, 64)), 1, (4, 96)),
+        # 2 inputs, dim=-1 (negative dim)
+        (((4, 32), (4, 64)), -1, (4, 96)),
+        # 3 inputs, dim=1
+        (((4, 16), (4, 32), (4, 16)), 1, (4, 64)),
+        # 2 inputs, dim=0, 3D
+        (((2, 4, 64), (2, 4, 64)), 0, (4, 4, 64)),
+        # 2 inputs, dim=2, 3D
+        (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)),
+        # 4 inputs, dim=1
+        (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-7, 1e-7),
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+    ),
+)
+def test_cat(shapes, dim, out_shape, dtype, device, rtol, atol):
+    inputs = [
+        randn_strided(s, None, dtype=dtype, device=device) for s in shapes
+    ]
+    out = empty_strided(out_shape, None, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args: _cat(*args, dim=dim),
+        lambda *args: _torch_cat(*args, dim=dim),
+        (*inputs, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _cat(*args, dim):
+    inputs = list(args[:-1])
+    out = args[-1]
+
+    first = inputs[0]
+    rest = inputs[1:]
+
+    if first.device.type == "npu":
+        infini.ops.cat(first, rest, dim, out, stream=get_npu_stream(first))
+    else:
+        infini.ops.cat(first, rest, dim, out)
+
+    return out
+
+
+def _torch_cat(*args, dim):
+    inputs = list(args[:-1])
+    out = args[-1]
+
+    result = torch.cat(inputs, dim=dim)
+    out.copy_(result)
+
+    return out
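Worth noting: _cat splits the first tensor from the rest, which suggests infini.ops.cat takes (first, rest, dim, out) rather than a single sequence the way torch.cat does. Purely illustrative, mirroring the second parametrized case:

# Illustrative call shapes only, for the ((4, 32), (4, 64)), dim=1 case;
# the infini.ops.cat signature is inferred from the wrapper above.
import torch

a, b = torch.randn(4, 32), torch.randn(4, 64)
out = torch.empty(4, 96)
out.copy_(torch.cat([a, b], dim=1))  # what _torch_cat computes
# _cat would issue: infini.ops.cat(a, [b], 1, out)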

tests/test_causal_softmax.py

Lines changed: 6 additions & 3 deletions

@@ -2,7 +2,7 @@
 import pytest
 import torch

-from tests.utils import Payload, empty_strided, randn_strided
+from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided


 @pytest.mark.auto_act_and_assert
@@ -40,15 +40,18 @@ def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol,


 def _causal_softmax(input, out):
-    infini.ops.causal_softmax(input, out)
+    if input.device.type == "npu":
+        infini.ops.causal_softmax(input, out, stream=get_npu_stream(input))
+    else:
+        infini.ops.causal_softmax(input, out)

     return out


 def _torch_causal_softmax(input, out):
     mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1])
     masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32))
-    result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype)
+    result = torch.nn.functional.softmax(masked, dim=-1)
     out.copy_(result)

     return out
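Two details in this reference are easy to miss. Dropping dtype=input.dtype keeps the softmax in float32 (masked is already float32), so precision is reduced only once, by the final out.copy_. And the flipped tril builds a bottom-right-aligned causal mask, which a small worked example makes concrete:

# Worked example of the mask construction in _torch_causal_softmax,
# shown for a 3x4 score matrix (runnable as-is with plain PyTorch).
import torch

scores = torch.zeros(3, 4)
mask = torch.tril(torch.ones_like(scores), diagonal=-1).flip(dims=[-2, -1])
print(mask)
# tensor([[0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
# Positions where mask == 1 are set to -inf, so row i attends to the
# first (cols - rows + i + 1) columns: here 2, 3, then all 4.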
