Skip to content

Commit 60a8724

Browse files
author
zhangyue
committed
style: apply ruff format to test and utility files
1 parent a9523de commit 60a8724

File tree

4 files changed

+19
-24
lines changed

4 files changed

+19
-24
lines changed

tests/test_flash_attention.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -307,9 +307,7 @@ def test_flash_attention_decode_cpu_cuseqlens(
307307
dtype=dtype,
308308
device=device,
309309
)
310-
output = torch.empty(
311-
(num_reqs, num_heads, head_size), dtype=dtype, device=device
312-
)
310+
output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device)
313311

314312
block_table = torch.zeros(
315313
(num_reqs, num_blocks_per_req), dtype=torch.int32, device=device
@@ -319,9 +317,7 @@ def test_flash_attention_decode_cpu_cuseqlens(
319317
for j in range(num_blocks_per_req):
320318
block_table[i, j] = i * num_blocks_per_req + j
321319

322-
cu_seqlens_q = torch.arange(
323-
0, num_reqs + 1, dtype=torch.int64, device=device
324-
)
320+
cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device)
325321

326322
# CPU cu_seqlens_kv — exercises `detail::extractSeqLengths` host path
327323
# (direct pointer read, no D2H copy).

tests/test_paged_attention.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -384,9 +384,7 @@ def test_paged_attention_host_tensors(
384384
dtype=dtype,
385385
device=device,
386386
)
387-
output = torch.empty(
388-
(num_reqs, num_heads, head_size), dtype=dtype, device=device
389-
)
387+
output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device)
390388

391389
block_table = torch.zeros(
392390
(num_reqs, num_blocks_per_req), dtype=torch.int32, device=device
@@ -396,9 +394,7 @@ def test_paged_attention_host_tensors(
396394
for j in range(num_blocks_per_req):
397395
block_table[i, j] = i * num_blocks_per_req + j
398396

399-
seq_lens = torch.full(
400-
(num_reqs,), kv_len, dtype=torch.int32, device=device
401-
)
397+
seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device)
402398

403399
# CPU copies for the D2H-free path.
404400
seq_lens_cpu = seq_lens.cpu().contiguous()

tests/test_reshape_and_cache.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -55,9 +55,7 @@ def test_reshape_and_cache_contiguous(
5555
active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device)
5656

5757
if implementation_index not in active_indices:
58-
pytest.skip(
59-
f"implementation `{implementation_index}` not active on `{device}`"
60-
)
58+
pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
6159

6260
key = randn_strided(
6361
(num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device
@@ -125,9 +123,7 @@ def test_reshape_and_cache_noncontiguous_slots(
125123
active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device)
126124

127125
if implementation_index not in active_indices:
128-
pytest.skip(
129-
f"implementation `{implementation_index}` not active on `{device}`"
130-
)
126+
pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
131127

132128
key = randn_strided(
133129
(num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device
@@ -157,17 +153,26 @@ def test_reshape_and_cache_noncontiguous_slots(
157153
)
158154

159155

160-
def _reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out,
161-
implementation_index=0):
156+
def _reshape_and_cache(
157+
key, value, kv_cache, slot_mapping, kv_cache_out, implementation_index=0
158+
):
162159
if key.device.type == "npu":
163160
infini.ops.reshape_and_cache(
164-
key, value, kv_cache, slot_mapping, kv_cache_out,
161+
key,
162+
value,
163+
kv_cache,
164+
slot_mapping,
165+
kv_cache_out,
165166
implementation_index=implementation_index,
166167
stream=get_npu_stream(key),
167168
)
168169
else:
169170
infini.ops.reshape_and_cache(
170-
key, value, kv_cache, slot_mapping, kv_cache_out,
171+
key,
172+
value,
173+
kv_cache,
174+
slot_mapping,
175+
kv_cache_out,
171176
implementation_index=implementation_index,
172177
)
173178

tests/utils.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,5 +100,3 @@ def clone_strided(input):
100100
output.as_strided(*as_strided_args).copy_(input.as_strided(*as_strided_args))
101101

102102
return output
103-
104-

0 commit comments

Comments (0)