Skip to content

Commit 0d15dde

Browse files
committed
black format
1 parent 169731a commit 0d15dde

2 files changed

Lines changed: 15 additions & 9 deletions

File tree

aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,15 @@ def _fused_routing_from_topk_hist_kernel(
6767
# Clamp the offset for masked-out lanes to 0 so the pointer arithmetic
6868
# below stays within the allocated buffers.
6969
safe_item = tl.where(item_mask, item_offs, 0)
70-
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32)
70+
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(
71+
tl.int32
72+
)
7173
if HAS_EXPERT_MAP:
7274
map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel)
7375
safe_global_expt = tl.where(map_mask, global_expt, 0)
74-
local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to(
75-
tl.int32
76-
)
76+
local_expt = tl.load(
77+
expert_map_ptr + safe_global_expt, mask=map_mask, other=-1
78+
).to(tl.int32)
7779
# Match reference semantics: invalid experts are redirected to bucket 0
7880
# and later zeroed in gate_scal.
7981
expt = tl.where(local_expt >= 0, local_expt, 0)
@@ -141,14 +143,16 @@ def _fused_routing_from_topk_place_kernel(
141143
item_offs = tl.arange(0, BLOCK_NK)
142144
item_mask = item_offs < NK
143145
safe_item = tl.where(item_mask, item_offs, 0)
144-
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32)
146+
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(
147+
tl.int32
148+
)
145149
weights = tl.load(topk_weights_ptr + safe_item, mask=item_mask, other=0.0)
146150
if HAS_EXPERT_MAP:
147151
map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel)
148152
safe_global_expt = tl.where(map_mask, global_expt, 0)
149-
local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to(
150-
tl.int32
151-
)
153+
local_expt = tl.load(
154+
expert_map_ptr + safe_global_expt, mask=map_mask, other=-1
155+
).to(tl.int32)
152156
invalid = local_expt < 0
153157
expt = tl.where(invalid, 0, local_expt)
154158
weights = tl.where(invalid, 0.0, weights)

op_tests/triton_tests/fusions/test_fused_routing_from_topk.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ def test_fused_routing_from_topk_with_expert_map(
308308
bucket_unsorted_layout=False,
309309
)
310310

311-
assert torch.equal(ref_hist, test_hist), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}"
311+
assert torch.equal(
312+
ref_hist, test_hist
313+
), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}"
312314

313315
# Intra-expert ordering can differ between fused and reference,
314316
# especially in expert-0 bucket where invalid experts are redirected.

0 commit comments

Comments
 (0)