diff --git a/configs/chat_config_24gb.json b/configs/chat_config_24gb.json index b3811d7..2ab26bf 100644 --- a/configs/chat_config_24gb.json +++ b/configs/chat_config_24gb.json @@ -1,7 +1,7 @@ { - "model": "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", - "draft_model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "offload": true, + "model": "Qwen/Qwen2.5-3B-Instruct", + "draft_model": "Qwen/Qwen2.5-0.5B-Instruct", + "offload": false, "max_length": 8192, "num_cache_layers": 16, "generation_length": 256, @@ -14,5 +14,5 @@ "num_beams": 24, "depth": 24, "engine": "dynamic", - "template": "meta-llama3" + "template": "qwen" } diff --git a/examples/generate.py b/examples/generate.py index d63238b..43c5ef3 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,4 +1,8 @@ import os +import sys +# Find local UMbreLLa first +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + os.environ["TOKENIZERS_PARALLELISM"] = "false" from umbrella.models.auto_model import AutoModelLM from umbrella.logging_config import setup_logger diff --git a/requirements.txt b/requirements.txt index 85e36d6..f0fcdde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,4 @@ transformers==4.47.0 huggingface-hub==0.27.0 transformers-stream-generator==0.0.5 optimum==1.23.3 -autoawq==0.2.7.post3 -autoawq-kernels==0.0.8 gradio \ No newline at end of file diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index dab1a2f..d13cc59 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -1,4 +1,5 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph +from .qwen import Qwen, QwenOffload, QwenCudagraph class AutoModelLM: """ @@ -15,6 +16,8 @@ class AutoModelLM: "meta-llama/Llama-3.1-8B-Instruct": LlamaOffload, "meta-llama/Meta-Llama-3-70B-Instruct": LlamaOffload, "meta-llama/Meta-Llama-3-8B-Instruct": LlamaOffload, + "Qwen/Qwen2.5-3B-Instruct": QwenOffload, + 
"Qwen/Qwen2.5-0.5B-Instruct": QwenOffload } _MODEL_MAPPING = { @@ -37,7 +40,9 @@ class AutoModelLM: "Zhuominc/Coder-400M": Llama, "Zhuominc/Coder-400M-IT": Llama, "Zhuominc/FastCode-500M": Llama, - "InfiniAILab/CodeDrafter-500M": Llama + "InfiniAILab/CodeDrafter-500M": Llama, + "Qwen/Qwen2.5-3B-Instruct": Qwen, + "Qwen/Qwen2.5-0.5B-Instruct": Qwen } _CUDAGRAPH_MODEL_MAPPING = { @@ -53,7 +58,9 @@ class AutoModelLM: "Zhuominc/Coder-400M": LlamaCudagraph, "Zhuominc/Coder-400M-IT": LlamaCudagraph, "Zhuominc/FastCode-500M": LlamaCudagraph, - "InfiniAILab/CodeDrafter-500M": LlamaCudagraph + "InfiniAILab/CodeDrafter-500M": LlamaCudagraph, + "Qwen/Qwen2.5-3B-Instruct": QwenCudagraph, + "Qwen/Qwen2.5-0.5B-Instruct": QwenCudagraph } @classmethod diff --git a/umbrella/models/qwen.py b/umbrella/models/qwen.py new file mode 100644 index 0000000..7f993b2 --- /dev/null +++ b/umbrella/models/qwen.py @@ -0,0 +1,533 @@ +from transformers import Qwen2ForCausalLM, Qwen2Config, AutoModelForCausalLM +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .qwen_layer import QwenLayer, QwenAwqLayer, QwenPackedLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +class Qwen(LLMBase): + def __init__(self, + model_name: str, + batch_size :int = 1, + max_length :int = 256, + device :str = 'cuda:0', + dtype = torch.float16) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = Qwen2Config.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + 
self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + def alloc(self, **kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = Qwen2ForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[QwenLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = QwenLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def 
layer_compute(self, + buffer: QwenLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + for idx in range(self.num_layers): + hidden_states 
= self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() + +class QwenOffload(Qwen): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + self.load_stream = torch.cuda.Stream(device=device) + + def alloc(self, **kwargs): + + + self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = Qwen2ForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = 
(inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[QwenLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = QwenLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [QwenLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + if self.buffer[0].layer_idx != 0: + self.buffer[0].copy(self.layers[0]) + torch.cuda.synchronize() + for idx in range(self.num_layers): + with torch.cuda.stream(self.load_stream): + self.buffer[(idx + 1) % 2].copy(self.layers[(idx + 1)% self.num_layers]) + + hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, storage_ids) + torch.cuda.synchronize() + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + +class QwenAwq(Qwen): + def alloc(self, 
**kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[QwenAwqLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = QwenAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def layer_compute(self, + buffer: QwenAwqLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = 
hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, 
self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + +class QwenAwqOffload(QwenOffload): + + def alloc(self, **kwargs): + + self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + self.layers: list[QwenAwqLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = QwenAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + 
hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [QwenAwqLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def layer_compute(self, + buffer: QwenAwqLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return 
hidden_states + +class QwenCudagraph(Qwen): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + + self.callables = {} + self.mempool = None + + def alloc(self, **kwargs): + + exit_layer = kwargs.pop("exit_layer", -1) + self.kv_cache = StaticKV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = Qwen2ForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[QwenPackedLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + if exit_layer > 0 and idx >= exit_layer: + break + layer = QwenPackedLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + 
self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: QwenPackedLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + qkv = F.linear(hidden_states, buffer.wqkv) + query_states = qkv[...,:self.hidden_size] + key_states = qkv[...,self.hidden_size:self.hidden_size + self.head_dim * self.num_key_value_heads] + value_states = qkv[...,self.hidden_size + self.head_dim * self.num_key_value_heads:] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids, unsqueeze_dim=1) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + 
hidden_states + + return hidden_states + + + @torch.inference_mode() + def initialize_cuda_graph(self, + decoding_seqlens :list[int], + n_warmups=12): + gc.collect() + self.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if decoding_seqlen not in self.callables: + self.callables[decoding_seqlen] = capture_graph( + llm=self, + decoding_seqlen=decoding_seqlen, + mempool=self.mempool, + n_warmups=n_warmups + ) + self.clear() + + @torch.inference_mode() + def graph_inference(self, + input_ids: torch.LongTensor, + storage_ids :torch.LongTensor, + position_ids = None, + attention_mask = None, + ): + dec_length = input_ids.shape[1] + if dec_length in self.callables.keys(): + logits = self.callables[dec_length](input_ids, storage_ids, position_ids, attention_mask) + else: + logits = self.inference(input_ids, position_ids, attention_mask, storage_ids) + return logits + diff --git a/umbrella/models/qwen_layer.py b/umbrella/models/qwen_layer.py new file mode 100644 index 0000000..9dfabdd --- /dev/null +++ b/umbrella/models/qwen_layer.py @@ -0,0 +1,261 @@ +from __future__ import annotations +import torch +from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer +from ..quantization.awq_utils import AwqLinear + +class QwenLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: Qwen2DecoderLayer): + + self.wq :torch.Tensor= 
hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: QwenLayer): + + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + 
self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: QwenLayer, device): + + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + +class QwenPackedLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wqkv :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: Qwen2DecoderLayer): + + self.wqkv :torch.Tensor= torch.cat( + [ + hf_layer.self_attn.q_proj.weight.detach(), + hf_layer.self_attn.k_proj.weight.detach(), + hf_layer.self_attn.v_proj.weight.detach(), + ], + dim=0 + ) + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + 
self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wqkv = self.wqkv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: QwenPackedLayer): + + self.wqkv.copy_(layer.wqkv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: QwenPackedLayer, device): + + self.device = device + self.wqkv = torch.zeros_like(layer.wqkv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = 
class QwenAwqLayer():
    """One Qwen decoder layer whose seven projections are ``AwqLinear`` modules.

    Besides the projections, the layer keeps the two layer-norm weight
    tensors, their variance epsilons, the layer index and the device the
    layer currently lives on.
    """

    # Attribute names of the projections, in the order they are processed.
    _LINEAR_NAMES = ("wq", "wk", "wv", "wo", "gate_proj", "up_proj", "down_proj")

    def __init__(self, layer_idx, device="cpu"):
        """Create an empty layer; weights are filled in later.

        Args:
            layer_idx: Index of this layer within the model.
            device: Device label recorded on the instance.
        """
        self.wq = AwqLinear()
        self.wk = AwqLinear()
        self.wv = AwqLinear()
        self.wo = AwqLinear()

        self.gate_proj = AwqLinear()
        self.up_proj = AwqLinear()
        self.down_proj = AwqLinear()

        self.input_layernorm_weight: torch.Tensor = None
        self.input_layernorm_variance_epsilon: float = 0.0

        self.post_attention_layernorm_weight: torch.Tensor = None
        self.post_attention_layernorm_variance_epsilon: float = 0.0

        self.layer_idx = layer_idx
        self.device = device

    def init_parameters(self, hf_layer):
        """Load weights from a Hugging Face decoder layer.

        Args:
            hf_layer: Source layer exposing ``self_attn``, ``mlp``,
                ``input_layernorm`` and ``post_attention_layernorm``.
        """
        attn = hf_layer.self_attn
        mlp = hf_layer.mlp

        self.wq.init_parameters(attn.q_proj)
        self.wk.init_parameters(attn.k_proj)
        self.wv.init_parameters(attn.v_proj)
        self.wo.init_parameters(attn.o_proj)
        self.gate_proj.init_parameters(mlp.gate_proj)
        self.up_proj.init_parameters(mlp.up_proj)
        self.down_proj.init_parameters(mlp.down_proj)

        # The norm weights are detached and placed in pinned host memory.
        input_norm = hf_layer.input_layernorm
        self.input_layernorm_weight = input_norm.weight.detach().pin_memory()
        self.input_layernorm_variance_epsilon = input_norm.variance_epsilon

        post_norm = hf_layer.post_attention_layernorm
        self.post_attention_layernorm_weight = post_norm.weight.detach().pin_memory()
        self.post_attention_layernorm_variance_epsilon = post_norm.variance_epsilon

    def to(self, device: str = 'cuda:0', non_blocking=True):
        """Move the layer to ``device``.

        Args:
            device: Target device.
            non_blocking: Forwarded to ``torch.Tensor.to`` for the two
                layer-norm weights only; the projections are moved with
                their own ``to(device=...)``.
        """
        self.device = device
        self.input_layernorm_weight = self.input_layernorm_weight.to(
            device, non_blocking=non_blocking)
        self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(
            device, non_blocking=non_blocking)

        for name in self._LINEAR_NAMES:
            getattr(self, name).to(device=device)

    def alloc_space(self, layer: "QwenAwqLayer", device):
        """Allocate buffers on ``device`` mirroring the layout of ``layer``.

        Args:
            layer: Template layer describing what has to be allocated.
            device: Device on which the buffers are placed.
        """
        self.device = device
        # Keep the original two-phase order: size every projection from its
        # counterpart first, then move all of them to the target device.
        for name in self._LINEAR_NAMES:
            getattr(self, name).empty_like(getattr(layer, name))
        for name in self._LINEAR_NAMES:
            getattr(self, name).to(device=device)

        # Allocate the norm buffers directly on the target device rather than
        # creating them next to the source tensor and copying them over.
        self.input_layernorm_weight = torch.zeros_like(
            layer.input_layernorm_weight, device=device)
        self.post_attention_layernorm_weight = torch.zeros_like(
            layer.post_attention_layernorm_weight, device=device)

    def copy(self, layer: "QwenAwqLayer"):
        """Overwrite this layer's contents with those of ``layer``.

        Projections are copied via their own ``copy(..., non_blocking=True)``
        and the norm weights via ``Tensor.copy_``, so the destination buffers
        must already exist (see ``alloc_space``). Epsilons and ``layer_idx``
        are rebound.

        Args:
            layer: Source layer.
        """
        for name in self._LINEAR_NAMES:
            getattr(self, name).copy(getattr(layer, name), non_blocking=True)

        self.input_layernorm_weight.copy_(
            layer.input_layernorm_weight, non_blocking=True)
        self.post_attention_layernorm_weight.copy_(
            layer.post_attention_layernorm_weight, non_blocking=True)
        self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
        self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon
        self.layer_idx = layer.layer_idx
b/umbrella/speculation/dynamic_speculation_engine.py @@ -541,4 +541,4 @@ def generate_stream(self, **api_args): self.reset() - + \ No newline at end of file diff --git a/umbrella/templates.py b/umbrella/templates.py index 7e5463c..9162f15 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -10,8 +10,8 @@ {}<|eot_id|><|start_header_id|>assistant<|end_header_id|> -<|python_tag|>""" - +<|python_tag|>""", +'qwen': """<|im_start|>user{}<|im_end|><|im_start|>assistant<|im_end|>""" } SysPrompts = { @@ -21,7 +21,8 @@ 'llama3-code': """<|begin_of_text|><|start_header_id|>system<|end_header_id|> Environment: ipython<|eot_id|>""", - + 'qwen': """<|im_start|>system +You are a helpful assistant.<|im_end|>""" } ExtraPrompts = {