Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ These approaches allow you to obtain results in seconds. However, selecting opti

This project includes code modified or inspired from the following open-source repositories:

* [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers)
* [https://github.com/triton-lang/triton](https://github.com/triton-lang/triton)
* [https://github.com/ROCm/triton](https://github.com/ROCm/triton)
* [https://github.com/l1351868270/implicit_gemm.triton](https://github.com/l1351868270/implicit_gemm.triton)
Expand Down
109 changes: 99 additions & 10 deletions attention.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import functools
import math

import ninetoothed
import ninetoothed.language as ntl
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from ninetoothed import Symbol, Tensor
from transformers.models.llama.modeling_llama import repeat_kv

import rope


def arrangement(q, k, v, o):
def arrangement(q, k, v, scale, o):
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)

Expand All @@ -30,11 +37,11 @@ def arrange_k_or_v(input):

q_arranged = arrange_q_or_o(q)

return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o)


def application(q, k, v, o):
q_loaded = (q * 1.44269504089).to(ntl.float16)
def application(q, k, v, scale, o):
q_loaded = (q * scale * 1.44269504089).to(ntl.float16)

acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
Expand All @@ -60,17 +67,79 @@ def application(q, k, v, o):
Tensor(4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}))
for _ in range(4)
)
attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, o))
attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, Tensor(0), o))


def attention(q, k, v, scale=None):
if scale is None:
scale = 1 / math.sqrt(q.shape[-1])

def attention(q, k, v):
o = torch.empty_like(q, dtype=v.dtype)

attention_kernel(q, k, v, o)
attention_kernel(q, k, v, scale, o)

return o


class Attention(nn.Module):
def __init__(self, other):
super().__init__()

self.__dict__ = other.__dict__

def forward(
self,
hidden_states,
position_embeddings,
attention_mask,
past_key_value,
cache_position,
**kwargs,
):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape)
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape)

cos_table, sin_table = position_embeddings

_rope(query_states, sin_table, cos_table)
_rope(key_states, sin_table, cos_table)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

if past_key_value is not None:
cache_kwargs = {
"sin": sin_table,
"cos": cos_table,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_dtype = torch.float16
attn_output = attention(
query_states.to(attn_dtype),
key_states.to(attn_dtype),
value_states.to(attn_dtype),
scale=self.scaling,
).to(query_states.dtype)
attn_output = attn_output.transpose(1, 2)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)

return attn_output, None


@triton.autotune(
configs=[
triton.Config(
Expand Down Expand Up @@ -116,6 +185,7 @@ def triton_attention_kernel(
o_stride_h,
o_stride_m,
o_stride_n,
scale,
SEQ_LEN: tl.constexpr,
EMB_DIM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
Expand Down Expand Up @@ -164,7 +234,7 @@ def triton_attention_kernel(
order=(1, 0),
)

q = (tl.load(q_block_ptr) * 1.44269504089).to(q_block_ptr.type.element_ty)
q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty)

acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32)
l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
Expand Down Expand Up @@ -194,11 +264,14 @@ def triton_attention_kernel(
tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty))


def triton_attention(q, k, v):
def triton_attention(q, k, v, scale=None):
o = torch.empty_like(q)

batch_size, num_heads, seq_len, emb_dim = q.shape

if scale is None:
scale = 1 / math.sqrt(emb_dim)

def grid(meta):
return (
triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]),
Expand All @@ -215,13 +288,29 @@ def grid(meta):
*k.stride(),
*v.stride(),
*o.stride(),
scale=scale,
SEQ_LEN=seq_len,
EMB_DIM=emb_dim,
)

return o


_rope_kernel = ninetoothed.make(
functools.partial(rope.arrangement, interleaved=False),
rope.application,
rope.tensors,
)


def _rope(x, sin_table, cos_table):
_, _, num_heads, _ = x.shape
sin_table = sin_table.unsqueeze(2).expand(-1, -1, num_heads, -1)
cos_table = cos_table.unsqueeze(2).expand(-1, -1, num_heads, -1)

_rope_kernel(x, sin_table, cos_table)


if __name__ == "__main__":
torch.manual_seed(0)
shape = (2, 4, 1024, 64)
Expand All @@ -231,7 +320,7 @@ def grid(meta):
v = torch.randn(shape, dtype=dtype, device="cuda")

ninetoothed_output = attention(q, k, v)
torch_output = F.scaled_dot_product_attention(q, k, v, scale=1)
torch_output = F.scaled_dot_product_attention(q, k, v)
triton_output = triton_attention(q, k, v)
print(ninetoothed_output)
print(torch_output)
Expand Down
71 changes: 71 additions & 0 deletions bmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import ninetoothed
import torch
from ninetoothed import Symbol, Tensor

import matmul

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)


def arrangement(
lhs,
rhs,
output,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
):
output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N))
output_arranged.dtype = output_arranged.dtype.squeeze(0)

lhs_arranged = lhs.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K))
lhs_arranged = lhs_arranged.tile((1, 1, -1))
lhs_arranged = lhs_arranged.expand((-1, -1, output_arranged.shape[-1]))
lhs_arranged.dtype = lhs_arranged.dtype.squeeze((0, 1))
lhs_arranged.dtype.dtype = lhs_arranged.dtype.dtype.squeeze(0)

rhs_arranged = rhs.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N))
rhs_arranged = rhs_arranged.tile((1, -1, 1))
rhs_arranged = rhs_arranged.expand((-1, output_arranged.shape[-2], -1))
rhs_arranged.dtype = rhs_arranged.dtype.squeeze((0, 2))
rhs_arranged.dtype.dtype = rhs_arranged.dtype.dtype.squeeze(0)

return lhs_arranged, rhs_arranged, output_arranged


bmm_kernel = ninetoothed.make(
arrangement, matmul.application, (Tensor(3), Tensor(3), Tensor(3))
)


def bmm(lhs, rhs):
output = torch.empty(
(lhs.shape[0], lhs.shape[-2], rhs.shape[-1]), dtype=lhs.dtype, device=lhs.device
)

bmm_kernel(lhs, rhs, output)

return output


if __name__ == "__main__":
torch.manual_seed(0)

batch_size, m, n, k = 4, 512, 2028, 1024
dtype = torch.float16
device = "cuda"
lhs = torch.randn(batch_size, m, k, dtype=dtype, device=device)
rhs = torch.randn(batch_size, k, n, dtype=dtype, device=device)

ninetoothed_output = bmm(lhs, rhs)
torch_output = torch.bmm(lhs, rhs)

print(ninetoothed_output)
print(torch_output)

if torch.allclose(ninetoothed_output, torch_output):
print("✅ NineToothed and PyTorch match.")
else:
print("❌ NineToothed and PyTorch differ.")
64 changes: 64 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from attention import Attention
from linear import Linear
from rms_norm import RMSNorm
from silu import SiLU
from utils import replace_module

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate text using a causal language model."
)

parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the model or model identifier from Hugging Face.",
)
parser.add_argument(
"--prompts",
type=str,
nargs="+",
required=True,
help="List of prompts for text generation.",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=64,
help="Maximum number of new tokens to generate.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help='Device to use for inference (e.g., "cuda", "cpu").',
)

args = parser.parse_args()

model_name_or_path = args.model
prompts = args.prompts
max_new_tokens = args.max_new_tokens
device = args.device

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)

tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

replace_module(model, Attention)
replace_module(model, Linear)
replace_module(model, RMSNorm)
replace_module(model, SiLU)

inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
strings = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(strings)
13 changes: 13 additions & 0 deletions linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch.nn as nn

from bmm import bmm


class Linear(nn.Module):
def __init__(self, other):
super().__init__()

self.__dict__ = other.__dict__

def forward(self, input):
return bmm(input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1))
41 changes: 41 additions & 0 deletions rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import ninetoothed
import ninetoothed.language as ntl
import torch
import torch.nn as nn
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)


@ninetoothed.jit
def fused_rms_norm_kernel(
x: Tensor(2).tile((1, BLOCK_SIZE)),
w: Tensor(2).tile((1, BLOCK_SIZE)),
y: Tensor(2).tile((1, BLOCK_SIZE)),
eps: Tensor(0),
):
x_fp32 = ntl.cast(x, ntl.float32)
y = x_fp32 * ntl.rsqrt(ntl.sum(x_fp32 * x_fp32) / x.shape[-1] + eps) * w # noqa: F841


def fused_rms_norm(x, w, eps=None):
if eps is None:
eps = torch.finfo(x.dtype).eps()

x_2d = x.view(-1, x.shape[-1])
w_2d = w.expand_as(x_2d)
y_2d = torch.empty_like(x_2d)

fused_rms_norm_kernel(x_2d, w_2d, y_2d, eps, BLOCK_SIZE=x.shape[-1])

return y_2d.view(x.shape)


class RMSNorm(nn.Module):
def __init__(self, other):
super().__init__()

self.__dict__ = other.__dict__

def forward(self, x):
return fused_rms_norm(x, self.weight, self.variance_epsilon)
Loading