Add native swiglu-silu-clamp-mul#507
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the swiglu_silu_clamp_mul activation function and its native PyTorch implementation. Feedback suggests correcting misleading docstrings regarding output behavior, optimizing memory usage through in-place clamping, and improving type safety and parameter naming consistency.
| def swiglu_silu_clamp_mul_native(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor: | ||
| """Out-variant of swiglustep activation. | ||
|
|
||
| Writes into `out`: | ||
| silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) | ||
| """ | ||
| gate, up = x.chunk(2, dim=-1) | ||
| gate = F.silu(gate) | ||
| gate = gate.clamp(max=limit) | ||
| up = up.clamp(min=-limit, max=limit) | ||
| out = gate * up | ||
| return out |
There was a problem hiding this comment.
The docstring for swiglu_silu_clamp_mul_native is slightly misleading as it refers to an "Out-variant" and "Writes into out", which typically implies a function that accepts an output buffer as an argument. Since this is a standard PyTorch implementation returning a new tensor, the docstring should be updated for clarity. Additionally, we can optimize memory usage by using an in-place clamp on the temporary tensor produced by F.silu, and rename the input parameter to hidden_states for consistency with the rest of the module.
| def swiglu_silu_clamp_mul_native(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor: | |
| """Out-variant of swiglustep activation. | |
| Writes into `out`: | |
| silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) | |
| """ | |
| gate, up = x.chunk(2, dim=-1) | |
| gate = F.silu(gate) | |
| gate = gate.clamp(max=limit) | |
| up = up.clamp(min=-limit, max=limit) | |
| out = gate * up | |
| return out | |
| def swiglu_silu_clamp_mul_native(hidden_states: torch.Tensor, limit: float = 7.0) -> torch.Tensor: | |
| """Reference implementation of swiglu-silu-clamp-mul activation. | |
| Computes: | |
| silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) | |
| """ | |
| gate, up = hidden_states.chunk(2, dim=-1) | |
| return F.silu(gate).clamp_(max=limit) * up.clamp(min=-limit, max=limit) |
| return out | ||
|
|
||
|
|
||
| def swiglu_silu_clamp_mul(hidden_states, limit: float = 7.0) -> torch.Tensor: |
There was a problem hiding this comment.
Add a type hint to the hidden_states parameter for consistency with the native implementation and better type safety.
| def swiglu_silu_clamp_mul(hidden_states, limit: float = 7.0) -> torch.Tensor: | |
| def swiglu_silu_clamp_mul(hidden_states: torch.Tensor, limit: float = 7.0) -> torch.Tensor: |
Add native swiglu-silu-clamp-mul