Skip to content

Commit 7628b2f

Browse files
author
zhangyue
committed
style: fix lint issues in feat/ascend-framework
- Fix `ruff format` violations in `generate_wrappers.py` and `test_gemm.py`.
- Fix `ruff isort` violation: move `import re` into the stdlib group.
- Add backticks around identifiers in comments (`numel()`, `operator()`, `make()`, `torch_npu`, `uint16`/`uint32`/`uint64`).
- Add missing blank line after the `if` block in `skip_unsupported_dtype`.
- Remove `.worktrees/` from the project `.gitignore` (belongs in the global gitignore).
1 parent 91689d5 commit 7628b2f

File tree

6 files changed

+17
-16
lines changed

6 files changed

+17
-16
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Generated files
22
build/
33
generated/
4-
.worktrees/
54

65
# Prerequisites
76
*.d

scripts/generate_wrappers.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import argparse
22
import json
33
import pathlib
4+
import re
45
import shutil
56
import subprocess
67
import textwrap
78

8-
import re
9-
109
import clang.cindex
1110
from clang.cindex import CursorKind
1211

@@ -120,10 +119,8 @@ def _generate_params(node):
120119
if _is_optional_tensor(arg):
121120
parts.append(f"std::optional<py::object> {arg.spelling}")
122121
else:
123-
param = (
124-
arg.type.spelling
125-
.replace("const Tensor", "py::object")
126-
.replace("Tensor", "py::object")
122+
param = arg.type.spelling.replace("const Tensor", "py::object").replace(
123+
"Tensor", "py::object"
127124
)
128125
parts.append(f"{param} {arg.spelling}")
129126

@@ -136,9 +133,7 @@ def _generate_arguments(node):
136133
if arg.spelling == "stream":
137134
continue
138135
if _is_optional_tensor(arg):
139-
args.append(
140-
f"OptionalTensorFromPybind11Handle({arg.spelling})"
141-
)
136+
args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})")
142137
elif "Tensor" in arg.type.spelling:
143138
args.append(f"TensorFromPybind11Handle({arg.spelling})")
144139
else:
@@ -184,7 +179,7 @@ def _generate_call(op_name, call, method=True):
184179
f" handle.set_stream(reinterpret_cast<void*>(stream));\n"
185180
f" }}\n"
186181
f" return Self::call(handle, config, {call_args});\n"
187-
f" }}, {py_args_str}py::kw_only(), py::arg(\"implementation_index\") = 0, py::arg(\"stream\") = 0);"
182+
f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);'
188183
)
189184

190185
return f""" .def("__call__", [](const Self& self, {call_params}) {{

src/ascend/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ inline aclTensor* buildAclTensor(const Tensor& t,
2828
}
2929

3030
// Compute the minimum physical storage needed for this strided view.
31-
// For contiguous tensors this equals numel(); for non-contiguous (gapped)
31+
// For contiguous tensors this equals `numel()`; for non-contiguous (gapped)
3232
// tensors it may be larger; for broadcast (stride-0) tensors it may be
3333
// smaller. Passing the view shape as the storage shape causes
3434
// "ViewShape overlap" errors in ACLNN for non-contiguous inputs.

src/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ class Operator : public OperatorBase {
177177
auto it{cache.find(key)};
178178

179179
if (it == cache.end()) {
180-
// Pass args as lvalue refs so they remain valid for the operator() call
181-
// below. Forwarding rvalue temporaries into make() would leave the args
180+
// Pass args as lvalue refs so they remain valid for the `operator()` call
181+
// below. Forwarding rvalue temporaries into `make()` would leave the args
182182
// in a moved-from (empty) state before operator() can use them.
183183
it = cache.emplace(std::move(key), make(config, args...)).first;
184184
}

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def set_seed_per_test(request):
4646

4747
_NPU_UNSUPPORTED_DTYPES = {torch.float64}
4848

49-
# torch_npu does not implement random number generation for uint16/uint32/uint64.
49+
# `torch_npu` does not implement random number generation for `uint16`/`uint32`/`uint64`.
5050
for _bits in (16, 32, 64):
5151
_t = getattr(torch, f"uint{_bits}", None)
5252
if _t is not None:
@@ -57,6 +57,7 @@ def set_seed_per_test(request):
5757
def skip_unsupported_dtype(request):
5858
if not hasattr(request.node, "callspec"):
5959
return
60+
6061
params = request.node.callspec.params
6162

6263
if params.get("device") == "npu" and params.get("dtype") in _NPU_UNSUPPORTED_DTYPES:

tests/test_gemm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,13 @@ def test_gemm(
8686
def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0):
8787
if a.device.type == "npu":
8888
infini.ops.gemm(
89-
a, b, alpha, beta, trans_a, trans_b, c,
89+
a,
90+
b,
91+
alpha,
92+
beta,
93+
trans_a,
94+
trans_b,
95+
c,
9096
stream=get_npu_stream(a),
9197
)
9298
else:

0 commit comments

Comments (0)