Skip to content

Commit 6743b36

Browse files
committed
feat: init gemma3
1 parent 20fcd7f commit 6743b36

15 files changed

Lines changed: 362 additions & 27 deletions

File tree

python/minisgl/attention/base.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@ def get_last_indices(self, bs: int) -> torch.Tensor: ...
1818
class BaseAttnBackend(ABC):
1919
@abstractmethod
2020
def forward(
21-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
21+
self,
22+
q: torch.Tensor,
23+
k: torch.Tensor,
24+
v: torch.Tensor,
25+
layer_id: int,
26+
batch: Batch,
27+
*,
28+
window_size: tuple[int, int] = (-1, -1),
29+
softmax_scale: float | None = None,
2230
) -> torch.Tensor: ...
2331

2432
@abstractmethod
@@ -44,10 +52,26 @@ def __init__(
4452
self.decode_backend = decode_backend
4553

4654
def forward(
47-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
55+
self,
56+
q: torch.Tensor,
57+
k: torch.Tensor,
58+
v: torch.Tensor,
59+
layer_id: int,
60+
batch: Batch,
61+
*,
62+
window_size: tuple[int, int] = (-1, -1),
63+
softmax_scale: float | None = None,
4864
) -> torch.Tensor:
4965
backend = self.prefill_backend if batch.is_prefill else self.decode_backend
50-
return backend.forward(q, k, v, layer_id, batch)
66+
return backend.forward(
67+
q,
68+
k,
69+
v,
70+
layer_id,
71+
batch,
72+
window_size=window_size,
73+
softmax_scale=softmax_scale,
74+
)
5175

5276
def prepare_metadata(self, batch: Batch) -> None:
5377
backend = self.prefill_backend if batch.is_prefill else self.decode_backend

python/minisgl/attention/fa.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,15 @@ def __init__(self, config: ModelConfig):
4646
self.version = 4 if is_sm100_supported() else 3
4747

4848
def forward(
49-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
49+
self,
50+
q: torch.Tensor,
51+
k: torch.Tensor,
52+
v: torch.Tensor,
53+
layer_id: int,
54+
batch: Batch,
55+
*,
56+
window_size: tuple[int, int] = (-1, -1),
57+
softmax_scale: float | None = None,
5058
) -> torch.Tensor:
5159
metadata = batch.attn_metadata
5260
assert isinstance(metadata, FAMetadata)
@@ -60,8 +68,9 @@ def forward(
6068
cu_seqlens_q=metadata.cu_seqlens_q,
6169
cu_seqlens_k=metadata.cu_seqlens_k,
6270
max_seqlen_q=metadata.max_seqlen_q,
63-
softmax_scale=self.scale,
71+
softmax_scale=self.scale if softmax_scale is None else softmax_scale,
6472
version=self.version,
73+
window_size=window_size,
6574
)
6675

6776
def prepare_metadata(self, batch: Batch) -> None:

python/minisgl/attention/fi.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,22 @@ def _get_ones_cpu(self, bs: int) -> torch.Tensor:
174174
return self.cached_ones_cpu[:bs]
175175

176176
def forward(
177-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
177+
self,
178+
q: torch.Tensor,
179+
k: torch.Tensor,
180+
v: torch.Tensor,
181+
layer_id: int,
182+
batch: Batch,
183+
*,
184+
window_size: tuple[int, int] = (-1, -1),
185+
softmax_scale: float | None = None,
178186
) -> torch.Tensor:
179187
def _flatten_cache(cache: torch.Tensor) -> torch.Tensor: # treat page = 1
180188
return cache.view(-1, 1, cache.shape[2], cache.shape[3])
181189

190+
if window_size != (-1, -1) or softmax_scale is not None:
191+
raise NotImplementedError
192+
182193
metadata = batch.attn_metadata
183194
assert isinstance(metadata, FIMetadata)
184195
self._initialize_metadata_once(metadata)

python/minisgl/attention/trtllm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,22 @@ def __init__(self, config: ModelConfig):
4747
)
4848

4949
def forward(
50-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch
50+
self,
51+
q: torch.Tensor,
52+
k: torch.Tensor,
53+
v: torch.Tensor,
54+
layer_id: int,
55+
batch: Batch,
56+
*,
57+
window_size: tuple[int, int] = (-1, -1),
58+
softmax_scale: float | None = None,
5159
) -> torch.Tensor:
5260
from flashinfer.decode import trtllm_batch_decode_with_kv_cache
5361
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
5462

63+
if window_size != (-1, -1) or softmax_scale is not None:
64+
raise NotImplementedError
65+
5566
metadata = batch.attn_metadata
5667
assert isinstance(metadata, TRTLLMMetadata)
5768
self.kvcache.store_kv(k, v, batch.out_loc, layer_id)

python/minisgl/layers/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .activation import gelu_and_mul, silu_and_mul
1+
from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
22
from .attention import AttentionLayer
33
from .base import BaseOP, OPList, StateLessOP
44
from .embedding import ParallelLMHead, VocabParallelEmbedding
@@ -10,12 +10,13 @@
1010
LinearRowParallel,
1111
)
1212
from .moe import MoELayer
13-
from .norm import RMSNorm, RMSNormFused
13+
from .norm import Gemma3RMSNorm, RMSNorm, RMSNormFused
1414
from .rotary import get_rope, set_rope_device
1515

1616
__all__ = [
1717
"silu_and_mul",
1818
"gelu_and_mul",
19+
"gelu_tanh_and_mul",
1920
"AttentionLayer",
2021
"BaseOP",
2122
"StateLessOP",
@@ -26,6 +27,7 @@
2627
"LinearRowParallel",
2728
"LinearOProj",
2829
"LinearQKVMerged",
30+
"Gemma3RMSNorm",
2931
"RMSNorm",
3032
"RMSNormFused",
3133
"get_rope",

python/minisgl/layers/activation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,10 @@ def gelu_and_mul(x: torch.Tensor, out: torch.Tensor | None = None):
1818
return gelu_and_mul(x, out=out)
1919

2020

21-
__all__ = ["silu_and_mul", "gelu_and_mul"]
21+
def gelu_tanh_and_mul(x: torch.Tensor, out: torch.Tensor | None = None):
22+
from flashinfer import gelu_tanh_and_mul
23+
24+
return gelu_tanh_and_mul(x, out=out)
25+
26+
27+
__all__ = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]

python/minisgl/layers/attention.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__(
2525
rotary_config: RotaryConfig,
2626
q_norm: RMSNorm | None = None,
2727
k_norm: RMSNorm | None = None,
28+
sliding_window_size: int | None = None,
29+
softmax_scale: float | None = None,
2830
):
2931
assert num_qo_heads % num_kv_heads == 0
3032
self.layer_id = layer_id
@@ -43,6 +45,11 @@ def __init__(
4345
)
4446
self.q_norm = q_norm
4547
self.k_norm = k_norm
48+
# sliding_window_size: HF-convention (inclusive). Converted to FA (left, right) here.
49+
self._window_size = (
50+
(sliding_window_size - 1, 0) if sliding_window_size is not None else (-1, -1)
51+
)
52+
self._softmax_scale = softmax_scale
4653

4754
def forward(self, qkv: torch.Tensor) -> torch.Tensor:
4855
ctx = get_global_ctx()
@@ -53,5 +60,13 @@ def forward(self, qkv: torch.Tensor) -> torch.Tensor:
5360
self.k_norm.forward_inplace(k.view(-1, self.num_kv_heads, self.head_dim))
5461
q, k = self.rotary.forward(ctx.batch.positions, q, k)
5562
q = q.view(-1, self.num_qo_heads, self.head_dim)
56-
o = ctx.attn_backend.forward(q, k, v, self.layer_id, ctx.batch)
63+
o = ctx.attn_backend.forward(
64+
q,
65+
k,
66+
v,
67+
self.layer_id,
68+
ctx.batch,
69+
window_size=self._window_size,
70+
softmax_scale=self._softmax_scale,
71+
)
5772
return o.view(-1, self.qo_attn_dim)

python/minisgl/layers/linear.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ def __init__(self, input_size: int, output_size: int, has_bias: bool):
100100
super().__init__(full_isize, full_osize, local_isize, local_osize, has_bias)
101101

102102
def forward(self, x: torch.Tensor) -> torch.Tensor:
103-
y = F.linear(x, self.weight, self.bias)
103+
y = F.linear(x, self.weight, None)
104104
if self._tp_size > 1:
105105
y = self._comm.all_reduce(y)
106+
if self.bias is not None:
107+
y = y + self.bias
106108
return y
107109

108110

@@ -121,7 +123,9 @@ def __init__(
121123
super().__init__(input_size, output_size, local_input_size, local_output_size, has_bias)
122124

123125
def forward(self, x: torch.Tensor) -> torch.Tensor:
124-
y = F.linear(x, self.weight, self.bias)
126+
y = F.linear(x, self.weight, None)
125127
if self._tp_size > 1:
126128
y = self._comm.all_reduce(y)
129+
if self.bias is not None:
130+
y = y + self.bias
127131
return y

python/minisgl/layers/norm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Tuple
24

35
import torch
@@ -36,3 +38,20 @@ def forward(
3638
return self.rmsnorm(x, self.weight, self.eps), x
3739
self.fused_add_rmsnorm(x, residual, self.weight, self.eps)
3840
return x, residual
41+
42+
43+
class Gemma3RMSNorm(BaseOP):
44+
45+
def __init__(self, size: int, eps: float) -> None:
46+
from flashinfer import gemma_rmsnorm
47+
48+
self.eps = eps
49+
self.weight = torch.zeros(size)
50+
self.gemma_rmsnorm = gemma_rmsnorm
51+
52+
def forward(self, x: torch.Tensor) -> torch.Tensor:
53+
return self.gemma_rmsnorm(x, self.weight, self.eps)
54+
55+
def forward_inplace(self, x: torch.Tensor) -> None:
56+
shape = x.shape # [t, h, d]
57+
x.copy_(self.gemma_rmsnorm(x.view(-1, shape[-1]), self.weight, self.eps).view(shape))

python/minisgl/layers/rotary.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def __init__(
2020
) -> None:
2121
super().__init__()
2222
self.head_size = head_size
23-
assert rotary_dim == head_size
2423
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
2524
if post_process is not None:
2625
inv_freq = post_process(inv_freq)
@@ -30,8 +29,8 @@ def __init__(
3029
sin = freqs.sin()
3130
# buffer, so don't load/save
3231
self._cos_sin_cache = torch.cat((cos, sin), dim=-1)
33-
assert self.head_size in [64, 128, 256, 512]
3432

33+
assert self.head_size in [64, 128, 256, 512]
3534
from flashinfer import apply_rope_with_cos_sin_cache_inplace
3635

3736
self.apply_rope_with_cos_sin_cache_inplace = apply_rope_with_cos_sin_cache_inplace
@@ -97,15 +96,20 @@ def post_process(inv_freq: torch.Tensor) -> torch.Tensor:
9796
orig_max_pos: int = rope_scaling["original_max_position_embeddings"]
9897

9998
def _find_correction_dim(num_rotations: float) -> float:
100-
return rotary_dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
99+
return (
100+
rotary_dim
101+
* math.log(orig_max_pos / (num_rotations * 2 * math.pi))
102+
/ (2 * math.log(base))
103+
)
101104

102105
low = max(math.floor(_find_correction_dim(beta_fast)), 0)
103106
high = min(math.ceil(_find_correction_dim(beta_slow)), rotary_dim // 2 - 1)
104107

105108
def post_process(inv_freq: torch.Tensor) -> torch.Tensor:
106109
ramp = torch.clamp(
107110
(torch.arange(rotary_dim // 2, dtype=torch.float32) - low) / max(high - low, 1),
108-
0, 1,
111+
0,
112+
1,
109113
)
110114
return (inv_freq / factor) * ramp + inv_freq * (1 - ramp)
111115

@@ -143,4 +147,4 @@ def get_rope(
143147
return _get_rope(head_dim, rotary_dim, max_position, base, rope_map)
144148

145149

146-
__all__ = ["get_rope", "RotaryEmbedding", "set_rope_device"]
150+
__all__ = ["get_rope", "RotaryEmbedding", "set_rope_device"]

0 commit comments

Comments
 (0)