Skip to content

Commit c34771d

Browse files
committed
jax/ep: apply clang-format and silence pylint unused-arg in lowering
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 07f928c commit c34771d

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

  • transformer_engine/jax

transformer_engine/jax/cpp_extensions/ep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def lowering(
514514
out_leading_shape,
515515
out_partition_spec,
516516
):
517-
del out_partition_spec
517+
del out_leading_shape, out_partition_spec
518518
return ffi.ffi_lowering(EpCombinePrimitive.name)(
519519
ctx,
520520
handle_mem,
@@ -650,7 +650,7 @@ def lowering(
650650
out_leading_shape,
651651
out_partition_spec,
652652
):
653-
del out_partition_spec
653+
del out_leading_shape, out_partition_spec
654654
return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)(
655655
ctx,
656656
handle_mem,

transformer_engine/jax/csrc/extensions/ep.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI,
446446
// ── ep_combine_bwd ────────────────────────────────────────────────────────────
447447

448448
Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem,
449-
Buffer_Type grad, Result_Type grad_expert_out,
450-
EpConfig config) {
449+
Buffer_Type grad, Result_Type grad_expert_out, EpConfig config) {
451450
(void)ep_state;
452451
auto grad_dims = grad.dimensions();
453452
NVTE_CHECK(grad_dims.size() >= 2,

0 commit comments

Comments
 (0)