Skip to content

Commit 1654d1d

Browse files
author
zhangyue
committed
fix: address PR #46 review feedback
- Rename `toAclDtype` → `ToAclDtype`, `isIntegerDtype` → `IsIntegerDtype` (Google C++ Style Guide PascalCase).
- Reorder `switch` cases in `ToAclDtype` to match `DataType` enum definition.
- Simplify `device_.h` include to `#include "device.h"`.
- Add Markdown backticks to code references in comments and help messages.
- Add blank lines before `return`/`if` per CONTRIBUTING.md Python style rules.
- Reorder pybind11 generated params: `Handle` (`stream`) before `Config` (`implementation_index`), matching `Operator::call` signature.
- Rename `Matmul` → `MatMul` (ONNX convention), params → `input`/`other`/`out`, remove `trans_a`/`trans_b` (use `Gemm` for transposed matmul).
- Rename `AddRmsNorm` params: `x1`/`x2`/`gamma` → `input`/`other`/`weight`, `y_out`/`x_out` → `out`/`rstd_out` (PyTorch conventions).
- Rename `skip_unsupported_dtype` → `skip_unsupported_dtypes`.
- Replace `get_npu_stream` with generic `get_stream(device)` using `torch.accelerator.current_stream` with device-specific fallbacks.
- Reorder `_PLATFORM_TO_TORCH_DEVICE` with `nvidia` first.
1 parent 7628b2f commit 1654d1d

File tree

10 files changed

+110
-92
lines changed

10 files changed

+110
-92
lines changed

scripts/generate_wrappers.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def __init__(self, name, constructors, calls):
9494

9595
def _find_optional_tensor_params(op_name):
9696
"""Return a set of parameter names declared as `std::optional<Tensor>` in
97-
the base header. libclang resolves the type to ``int`` when the STL
97+
the base header. `libclang` resolves the type to `int` when the STL
9898
headers are not fully available, so we fall back to a regex scan of the
9999
source text.
100100
"""
101101
source = (_BASE_DIR / f"{op_name}.h").read_text()
102+
102103
return set(re.findall(r"std::optional<Tensor>\s+(\w+)", source))
103104

104105

@@ -108,6 +109,7 @@ def _generate_pybind11(operator):
108109
def _is_optional_tensor(arg):
109110
if arg.spelling in optional_tensor_params:
110111
return True
112+
111113
return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling
112114

113115
def _generate_params(node):
@@ -116,6 +118,7 @@ def _generate_params(node):
116118
for arg in node.get_arguments():
117119
if arg.spelling == "stream":
118120
continue
121+
119122
if _is_optional_tensor(arg):
120123
parts.append(f"std::optional<py::object> {arg.spelling}")
121124
else:
@@ -132,6 +135,7 @@ def _generate_arguments(node):
132135
for arg in node.get_arguments():
133136
if arg.spelling == "stream":
134137
continue
138+
135139
if _is_optional_tensor(arg):
136140
args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})")
137141
elif "Tensor" in arg.type.spelling:
@@ -163,23 +167,23 @@ def _generate_call(op_name, call, method=True):
163167

164168
if not method:
165169
params = (
166-
f"{call_params}, std::size_t implementation_index, std::uintptr_t stream"
170+
f"{call_params}, std::uintptr_t stream, std::size_t implementation_index"
167171
if call_params
168-
else "std::size_t implementation_index, std::uintptr_t stream"
172+
else "std::uintptr_t stream, std::size_t implementation_index"
169173
)
170174
py_args = _generate_py_args(call)
171175
py_args_str = f"{py_args}, " if py_args else ""
172176

173177
return (
174178
f' m.def("{op_name}", []({params}) {{\n'
175-
f" Config config;\n"
176-
f" config.set_implementation_index(implementation_index);\n"
177179
f" Handle handle;\n"
178180
f" if (stream) {{\n"
179181
f" handle.set_stream(reinterpret_cast<void*>(stream));\n"
180182
f" }}\n"
183+
f" Config config;\n"
184+
f" config.set_implementation_index(implementation_index);\n"
181185
f" return Self::call(handle, config, {call_args});\n"
182-
f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);'
186+
f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);'
183187
)
184188

185189
return f""" .def("__call__", [](const Self& self, {call_params}) {{
@@ -438,7 +442,7 @@ def _get_all_ops(devices):
438442
nargs="+",
439443
default="cpu",
440444
type=str,
441-
help="Devices to use. Please pick from cpu, nvidia, cambricon, ascend, metax, moore, iluvatar, kunlun, hygon, and qy. (default: cpu)",
445+
help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. (default: `cpu`)",
442446
)
443447

444448
args = parser.parse_args()

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ if(WITH_ASCEND)
178178
"ascend/*.cc"
179179
"ascend/*.cpp"
180180
)
181-
# Exclude kernel_impl.cpp — AscendC device code, not compiled by the host C++ compiler.
181+
# Exclude `kernel_impl.cpp` — AscendC device code, not compiled by the host C++ compiler.
182182
list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$")
183183

184184
target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1)

src/ascend/common.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
namespace infini::ops::ascend {
1313

14-
// Build an aclTensor descriptor from an InfiniOps Tensor.
14+
// Build an `aclTensor` descriptor from an InfiniOps `Tensor`.
1515
//
1616
// When `transpose_last2` is true the last two dimensions are swapped in the
17-
// descriptor (shape and strides) without copying data. This is used by GEMM
18-
// and Matmul to express a transpose via the view.
17+
// descriptor (shape and strides) without copying data. This is used by `Gemm`
18+
// and `MatMul` to express a transpose via the view.
1919
inline aclTensor* buildAclTensor(const Tensor& t,
2020
bool transpose_last2 = false) {
2121
std::vector<int64_t> shape(t.shape().begin(), t.shape().end());
@@ -45,7 +45,7 @@ inline aclTensor* buildAclTensor(const Tensor& t,
4545
std::vector<int64_t> storage_shape = {storage_elems};
4646

4747
return aclCreateTensor(
48-
shape.data(), static_cast<int64_t>(shape.size()), toAclDtype(t.dtype()),
48+
shape.data(), static_cast<int64_t>(shape.size()), ToAclDtype(t.dtype()),
4949
strides.data(),
5050
/*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(),
5151
static_cast<int64_t>(storage_shape.size()), const_cast<void*>(t.data()));

src/ascend/data_type_.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,8 @@
99

1010
namespace infini::ops::ascend {
1111

12-
inline aclDataType toAclDtype(DataType dt) {
12+
inline aclDataType ToAclDtype(DataType dt) {
1313
switch (dt) {
14-
case DataType::kFloat16:
15-
return ACL_FLOAT16;
16-
case DataType::kBFloat16:
17-
return ACL_BF16;
18-
case DataType::kFloat32:
19-
return ACL_FLOAT;
2014
case DataType::kInt8:
2115
return ACL_INT8;
2216
case DataType::kInt16:
@@ -33,14 +27,20 @@ inline aclDataType toAclDtype(DataType dt) {
3327
return ACL_UINT32;
3428
case DataType::kUInt64:
3529
return ACL_UINT64;
30+
case DataType::kFloat16:
31+
return ACL_FLOAT16;
32+
case DataType::kBFloat16:
33+
return ACL_BF16;
34+
case DataType::kFloat32:
35+
return ACL_FLOAT;
3636
default:
37-
assert(false && "unsupported dtype for Ascend backend");
37+
assert(false && "Unsupported dtype for Ascend backend.");
3838
return ACL_DT_UNDEFINED;
3939
}
4040
}
4141

42-
// Returns true for integer (signed or unsigned) DataType values.
43-
inline bool isIntegerDtype(DataType dt) {
42+
// Returns true for integer (signed or unsigned) `DataType` values.
43+
inline bool IsIntegerDtype(DataType dt) {
4444
switch (dt) {
4545
case DataType::kInt8:
4646
case DataType::kInt16:

src/ascend/device_.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
#ifndef INFINI_OPS_ASCEND_DEVICE__H_
22
#define INFINI_OPS_ASCEND_DEVICE__H_
33

4-
// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes
5-
// relative to the current file first, and `src/ascend/` used to contain a
6-
// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`.
7-
#include "data_type.h"
4+
#include "device.h"
85

96
namespace infini::ops {
107

src/base/add_rms_norm.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,24 @@ namespace infini::ops {
1111

1212
class AddRmsNorm : public Operator<AddRmsNorm> {
1313
public:
14-
AddRmsNorm(const Tensor x1, const Tensor x2, const Tensor gamma, float eps,
15-
Tensor y_out, Tensor x_out)
16-
: input_shape_{x1.shape()},
14+
AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight,
15+
float eps, Tensor out, Tensor rstd_out)
16+
: input_shape_{input.shape()},
1717
eps_{eps},
18-
dim_{x1.size(-1)},
19-
ndim_{x1.ndim()},
20-
batch_size_{ndim_ == 2 ? x1.size(-2) : x1.size(-3)},
21-
nhead_{ndim_ == 2 ? 1 : x1.size(-2)},
18+
dim_{input.size(-1)},
19+
ndim_{input.ndim()},
20+
batch_size_{ndim_ == 2 ? input.size(-2) : input.size(-3)},
21+
nhead_{ndim_ == 2 ? 1 : input.size(-2)},
2222
rstd_shape_{static_cast<int64_t>(batch_size_),
2323
static_cast<int64_t>(nhead_)} {
24-
assert(x1.dtype() == x2.dtype());
25-
assert(x1.dtype() == y_out.dtype());
26-
assert(x1.dtype() == x_out.dtype());
24+
assert(input.dtype() == other.dtype());
25+
assert(input.dtype() == out.dtype());
26+
assert(input.dtype() == rstd_out.dtype());
2727
}
2828

29-
virtual void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
30-
float eps, Tensor y_out, Tensor x_out) const = 0;
29+
virtual void operator()(const Tensor input, const Tensor other,
30+
const Tensor weight, float eps, Tensor out,
31+
Tensor rstd_out) const = 0;
3132

3233
protected:
3334
Tensor::Shape input_shape_;

src/base/matmul.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,24 @@
66

77
namespace infini::ops {
88

9-
class Matmul : public Operator<Matmul> {
9+
class MatMul : public Operator<MatMul> {
1010
public:
11-
// `trans_a` / `trans_b`: If true, transpose the last two dims of `a` / `b`
12-
// before multiplying. These are constructor parameters so the `CacheKey`
13-
// encodes the transposition and distinct descriptors are cached for each
14-
// combination.
15-
Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b)
16-
: a_shape_{a.shape()},
17-
b_shape_{b.shape()},
18-
c_shape_{c.shape()},
19-
trans_a_{trans_a},
20-
trans_b_{trans_b} {
21-
assert(a.dtype() == b.dtype());
11+
MatMul(const Tensor input, const Tensor other, Tensor out)
12+
: input_shape_{input.shape()},
13+
other_shape_{other.shape()},
14+
out_shape_{out.shape()} {
15+
assert(input.dtype() == other.dtype());
2216
}
2317

24-
virtual void operator()(const Tensor a, const Tensor b, Tensor c,
25-
bool trans_a, bool trans_b) const = 0;
18+
virtual void operator()(const Tensor input, const Tensor other,
19+
Tensor out) const = 0;
2620

2721
protected:
28-
Tensor::Shape a_shape_;
22+
Tensor::Shape input_shape_;
2923

30-
Tensor::Shape b_shape_;
24+
Tensor::Shape other_shape_;
3125

32-
Tensor::Shape c_shape_;
33-
34-
bool trans_a_{false};
35-
36-
bool trans_b_{false};
26+
Tensor::Shape out_shape_;
3727
};
3828

3929
} // namespace infini::ops

tests/conftest.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def pytest_addoption(parser):
1616
"--devices",
1717
nargs="+",
1818
default=None,
19-
help="Device(s) to test on (e.g., --devices ascend cpu). Accepts platform names (ascend, nvidia, cambricon, metax, moore, iluvatar) or PyTorch device types (npu, cuda, mlu, musa). Defaults to all available devices.",
19+
help="Device(s) to test on (e.g., `--devices ascend cpu`). Accepts platform names (`nvidia`, `metax`, `iluvatar`, `moore`, `cambricon`, `ascend`) or PyTorch device types (`cuda`, `mlu`, `musa`, `npu`). Defaults to all available devices.",
2020
)
2121

2222

@@ -46,15 +46,17 @@ def set_seed_per_test(request):
4646

4747
_NPU_UNSUPPORTED_DTYPES = {torch.float64}
4848

49-
# `torch_npu` does not implement random number generation for `uint16`/`uint32`/`uint64`.
49+
# `torch_npu` does not implement random number generation for
50+
# `uint16`/`uint32`/`uint64`.
5051
for _bits in (16, 32, 64):
5152
_t = getattr(torch, f"uint{_bits}", None)
5253
if _t is not None:
5354
_NPU_UNSUPPORTED_DTYPES.add(_t)
5455

5556

5657
@pytest.fixture(autouse=True)
57-
def skip_unsupported_dtype(request):
58+
def skip_unsupported_dtypes(request):
59+
5860
if not hasattr(request.node, "callspec"):
5961
return
6062

@@ -71,16 +73,16 @@ def _set_random_seed(seed):
7173

7274
_PLATFORM_TO_TORCH_DEVICE = {
7375
"nvidia": "cuda",
74-
"iluvatar": "cuda",
7576
"metax": "cuda",
76-
"cambricon": "mlu",
77+
"iluvatar": "cuda",
7778
"moore": "musa",
79+
"cambricon": "mlu",
7880
"ascend": "npu",
7981
}
8082

8183

8284
def _resolve_device(name):
83-
"""Map a platform name (e.g., ``ascend``) to a PyTorch device type (e.g., ``npu``)."""
85+
"""Map a platform name (e.g., `ascend`) to a PyTorch device type (e.g., `npu`)."""
8486
return _PLATFORM_TO_TORCH_DEVICE.get(name, name)
8587

8688

tests/test_gemm.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import torch
44

5-
from tests.utils import Payload, get_npu_stream, randn_strided
5+
from tests.utils import Payload, get_stream, randn_strided
66

77

88
@pytest.mark.auto_act_and_assert
@@ -84,28 +84,17 @@ def test_gemm(
8484

8585

8686
def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0):
87-
if a.device.type == "npu":
88-
infini.ops.gemm(
89-
a,
90-
b,
91-
alpha,
92-
beta,
93-
trans_a,
94-
trans_b,
95-
c,
96-
stream=get_npu_stream(a),
97-
)
98-
else:
99-
infini.ops.gemm(
100-
a,
101-
b,
102-
alpha,
103-
beta,
104-
trans_a,
105-
trans_b,
106-
c,
107-
implementation_index=implementation_index,
108-
)
87+
infini.ops.gemm(
88+
a,
89+
b,
90+
alpha,
91+
beta,
92+
trans_a,
93+
trans_b,
94+
c,
95+
stream=get_stream(a.device),
96+
implementation_index=implementation_index,
97+
)
10998

11099
return c
111100

tests/utils.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,47 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None):
8282
return output
8383

8484

85-
def get_npu_stream(tensor):
86-
"""Return the current NPU stream handle for `tensor`, or 0 on other devices."""
87-
if tensor.device.type != "npu":
85+
def get_stream(device):
86+
"""Return the raw stream handle for `device`, or 0 for CPU.
87+
88+
Uses `torch.accelerator.current_stream` when available, falling back to
89+
device-specific APIs for older PyTorch versions.
90+
"""
91+
if isinstance(device, torch.device):
92+
device = device.type
93+
94+
if isinstance(device, str) and ":" in device:
95+
device = device.split(":")[0]
96+
97+
if device == "cpu":
98+
return 0
99+
100+
if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "current_stream"):
101+
stream = torch.accelerator.current_stream()
102+
103+
# Each backend exposes the raw handle under a different attribute name.
104+
for attr in ("npu_stream", "cuda_stream", "mlu_stream", "musa_stream"):
105+
if hasattr(stream, attr):
106+
return getattr(stream, attr)
107+
88108
return 0
89109

90-
return torch.npu.current_stream().npu_stream
110+
# Fallback for older PyTorch builds without `torch.accelerator`.
111+
_STREAM_ACCESSORS = {
112+
"npu": ("npu", "npu_stream"),
113+
"cuda": ("cuda", "cuda_stream"),
114+
"mlu": ("mlu", "mlu_stream"),
115+
"musa": ("musa", "musa_stream"),
116+
}
117+
118+
if device in _STREAM_ACCESSORS:
119+
mod_name, attr = _STREAM_ACCESSORS[device]
120+
mod = getattr(torch, mod_name, None)
121+
122+
if mod is not None and hasattr(mod, "current_stream"):
123+
return getattr(mod.current_stream(), attr)
124+
125+
return 0
91126

92127

93128
def clone_strided(input):

0 commit comments

Comments (0)