Skip to content

Commit ef6d635

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 996f24c commit ef6d635

2 files changed

Lines changed: 2 additions & 6 deletions

File tree

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,7 @@ def __init__(
317317
validate_grouped_mlp_dims(fc1, swiglu, fc2)
318318
is_geglu = isinstance(swiglu, ScaledClampedQGeGLU)
319319
self._cudnn_dact_func: str = "dgeglu" if is_geglu else "dswiglu"
320-
self._pass_geglu_runtime_params: bool = (
321-
is_geglu and _cudnn_frontend_geglu_runtime_params()
322-
)
320+
self._pass_geglu_runtime_params: bool = is_geglu and _cudnn_frontend_geglu_runtime_params()
323321
if self._pass_geglu_runtime_params:
324322
self._cudnn_linear_offset: float = swiglu._clamped.glu_linear_offset
325323
self._cudnn_geglu_alpha: float = swiglu._clamped.alpha

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ def __init__(
100100
validate_grouped_mlp_dims(fc1, swiglu, fc2)
101101
is_geglu = isinstance(swiglu, ScaledClampedQGeGLU)
102102
self._cudnn_act_func: str = "geglu" if is_geglu else "swiglu"
103-
self._pass_geglu_runtime_params: bool = (
104-
is_geglu and _cudnn_frontend_geglu_runtime_params()
105-
)
103+
self._pass_geglu_runtime_params: bool = is_geglu and _cudnn_frontend_geglu_runtime_params()
106104
if self._pass_geglu_runtime_params:
107105
self._cudnn_linear_offset: float = swiglu._clamped.glu_linear_offset
108106
self._cudnn_geglu_alpha: float = swiglu._clamped.alpha

0 commit comments

Comments
 (0)