Commit 16b9606

fix: rewrite flux swiglu activation to avoid gather op in neuron IR
1 parent cbe8f28 commit 16b9606

1 file changed

Lines changed: 2 additions & 2 deletions


src/diffusers/models/transformers/transformer_flux2.py

@@ -291,8 +291,8 @@ def __init__(self):
         self.gate_fn = nn.SiLU()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x1, x2 = x.chunk(2, dim=-1)
-        x = self.gate_fn(x1) * x2
+        half = x.shape[-1] // 2
+        x = self.gate_fn(x[..., :half]) * x[..., half:]
         return x


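Context for the change: per the commit message, torch.Tensor.chunk on the last dimension lowers to a gather op in the Neuron IR, whereas plain slicing does not, so the activation is rewritten with explicit slices. Below is a minimal illustrative sketch of the same idea; the SliceSwiGLU class name and the equivalence check are assumptions for demonstration, not part of the diffusers module.

import torch
import torch.nn as nn

# Illustrative stand-in for the patched activation: slice the last dimension
# instead of calling torch.chunk, so the traced graph contains slice ops
# rather than a gather op (relevant when compiling for AWS Neuron).
class SliceSwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        return self.gate_fn(x[..., :half]) * x[..., half:]

# Both formulations compute the same result for an even-sized last dimension.
x = torch.randn(2, 4, 8)
x1, x2 = x.chunk(2, dim=-1)
assert torch.allclose(SliceSwiGLU()(x), nn.functional.silu(x1) * x2)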
