Add native swiglu-silu-clamp-mul #507

Open
McZyWu wants to merge 2 commits into sgl-project:main from McZyWu:swiglu-silu-clamp-mul
Conversation

@McZyWu
Contributor

@McZyWu McZyWu commented May 20, 2026

Add native swiglu-silu-clamp-mul

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the swiglu_silu_clamp_mul activation function and its native PyTorch implementation. Feedback suggests correcting misleading docstrings regarding output behavior, optimizing memory usage through in-place clamping, and improving type safety and parameter naming consistency.
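To make the reviewed formula concrete, here is a small standard-library sketch of what the activation computes for a single row. The helper names `silu` and `swiglu_silu_clamp_mul_scalar` are hypothetical and not part of the PR; the split into gate and up halves and the default `limit` of 7.0 follow the code under review.

```python
import math


def silu(v: float) -> float:
    """SiLU of a scalar, v * sigmoid(v), written to avoid overflow for large |v|."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def swiglu_silu_clamp_mul_scalar(row, limit=7.0):
    """Hypothetical pure-Python mirror of the activation for one row.

    The first half of `row` is the gate, the second half is the up projection,
    matching x.chunk(2, dim=-1) for an even-length last dimension.
    """
    d = len(row) // 2
    gate, up = row[:d], row[d:]
    return [
        # The gate is clamped from above only; up is clamped on both sides.
        min(silu(g), limit) * max(-limit, min(u, limit))
        for g, u in zip(gate, up)
    ]


row = [0.0, 10.0, 3.0, -20.0]  # gate = [0.0, 10.0], up = [3.0, -20.0]
out = swiglu_silu_clamp_mul_scalar(row)
print(out)  # [0.0, -49.0]
```

The second element shows both clamps at work: silu(10.0) is just under 10 and is cut to 7.0, while -20.0 is cut to -7.0, giving -49.0.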

Comment on lines +5 to +16
def swiglu_silu_clamp_mul_native(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    """Out-variant of swiglustep activation.

    Writes into `out`:
        silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    """
    gate, up = x.chunk(2, dim=-1)
    gate = F.silu(gate)
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    out = gate * up
    return out
medium

The docstring for swiglu_silu_clamp_mul_native is slightly misleading as it refers to an "Out-variant" and "Writes into out", which typically implies a function that accepts an output buffer as an argument. Since this is a standard PyTorch implementation returning a new tensor, the docstring should be updated for clarity. Additionally, we can optimize memory usage by using an in-place clamp on the temporary tensor produced by F.silu, and rename the input parameter to hidden_states for consistency with the rest of the module.

Suggested change
- def swiglu_silu_clamp_mul_native(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
-     """Out-variant of swiglustep activation.
-     Writes into `out`:
-         silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
-     """
-     gate, up = x.chunk(2, dim=-1)
-     gate = F.silu(gate)
-     gate = gate.clamp(max=limit)
-     up = up.clamp(min=-limit, max=limit)
-     out = gate * up
-     return out
+ def swiglu_silu_clamp_mul_native(hidden_states: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
+     """Reference implementation of swiglu-silu-clamp-mul activation.
+     Computes:
+         silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
+     """
+     gate, up = hidden_states.chunk(2, dim=-1)
+     return F.silu(gate).clamp_(max=limit) * up.clamp(min=-limit, max=limit)
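One way to sanity-check the suggestion before accepting it is to compare both variants on random input. This is only a sketch: the function names `original` and `suggested` are placeholders for the two versions shown in this thread, and the input shape and scale are arbitrary assumptions. It also checks that the in-place `clamp_` does not modify the caller's tensor, which holds because `F.silu` (with its default `inplace=False`) returns a new tensor rather than a view of the input.

```python
import torch
import torch.nn.functional as F


def original(x: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    # Version from the PR: out-of-place clamp on the SiLU result.
    gate, up = x.chunk(2, dim=-1)
    gate = F.silu(gate).clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * up


def suggested(hidden_states: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    # Version from the review: clamp_ reuses the temporary produced by F.silu.
    gate, up = hidden_states.chunk(2, dim=-1)
    return F.silu(gate).clamp_(max=limit) * up.clamp(min=-limit, max=limit)


torch.manual_seed(0)
x = torch.randn(4, 16) * 10  # scaled so that both clamps are actually exercised
x_before = x.clone()

same = torch.equal(original(x), suggested(x))
input_untouched = torch.equal(x, x_before)
print(same, input_untouched)
```

The output has half the size of the input along the last dimension, so a `(4, 16)` input yields a `(4, 8)` result in both variants.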

    return out


def swiglu_silu_clamp_mul(hidden_states, limit: float = 7.0) -> torch.Tensor:
medium

Add a type hint to the hidden_states parameter for consistency with the native implementation and better type safety.

Suggested change
- def swiglu_silu_clamp_mul(hidden_states, limit: float = 7.0) -> torch.Tensor:
+ def swiglu_silu_clamp_mul(hidden_states: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
