From 9f42c325bccff8cee2584aa4ddf89754bd73ca6f Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Thu, 16 Jan 2025 12:52:22 +0900 Subject: [PATCH 01/23] chore: update question.jsonl inital -> initial --- examples/data/question.jsonl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/data/question.jsonl b/examples/data/question.jsonl index 464e2c2..1b63e8a 100644 --- a/examples/data/question.jsonl +++ b/examples/data/question.jsonl @@ -39,7 +39,7 @@ {"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]} {"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]} {"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = 
get_files_in_directory(directory)\n word_counts = Counter()\n for file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]} -{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. 
Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} +{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the initial cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} {"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]} {"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. 
Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]} {"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. 
It's still the root of the tree."]} From 8357238586d93aba91991fd8731d9e647644cc60 Mon Sep 17 00:00:00 2001 From: kushaann Date: Tue, 28 Jan 2025 15:58:15 -0500 Subject: [PATCH 02/23] Updated requirements.txt transformers -> 4.48 (includes gemma updates that match llama format) Add gemma.py, gemma_layer.py to models/* for gemma2 support Update model_utils to support gemma-specific norm and attention calculation Add model registration to auto_models Add prompt formats for instruction tuned and regular gemma2 --- examples/generate.py | 2 +- requirements.txt | 4 +- umbrella/attn/cache.py | 14 +-- umbrella/models/auto_model.py | 7 +- umbrella/models/gemma.py | 180 +++++++++++++++++++++++++++++++ umbrella/models/gemma_layer.py | 113 +++++++++++++++++++++ umbrella/models/model_utils.py | 13 ++- umbrella/templates.py | 15 ++- 8 files changed, 333 insertions(+), 15 deletions(-) create mode 100644 umbrella/models/gemma.py create mode 100644 umbrella/models/gemma_layer.py diff --git a/examples/generate.py b/examples/generate.py index 065bf31..f130030 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -91,4 +91,4 @@ t2 = time.time() dec_len = len(generated_tokens) -logger.info(TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2-t1)/dec_len), "magenta")) \ No newline at end of file +logger.info(TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2-t1)/dec_len), "magenta")) diff --git a/requirements.txt b/requirements.txt index 85e36d6..12c4f8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch==2.4.1 -transformers==4.47.0 +transformers==4.48.0 huggingface-hub==0.27.0 transformers-stream-generator==0.0.5 optimum==1.23.3 autoawq==0.2.7.post3 autoawq-kernels==0.0.8 -gradio \ No newline at end of file +gradio diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 533206d..1347080 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -14,11 +14,12 @@ def 
__init__(self, self.max_length = max_length self.device = device self.dtype = dtype + self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) self.k_cache = torch.zeros( config.num_hidden_layers, max_length, config.num_key_value_heads, - config.hidden_size // config.num_attention_heads, + self.head_dim, device=self.device, dtype=self.dtype ) @@ -27,7 +28,7 @@ def __init__(self, config.num_hidden_layers, max_length, config.num_key_value_heads, - config.hidden_size // config.num_attention_heads, + self.head_dim, device=self.device, dtype=self.dtype ) @@ -35,7 +36,6 @@ def __init__(self, self.kv_offset = 0 self.num_key_value_heads = config.num_key_value_heads self.num_attention_heads = config.num_attention_heads - self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): @@ -70,7 +70,8 @@ def compute_attention(self, value_states :torch.Tensor, layer_idx, storage_ids :torch.Tensor, - attention_mask :torch.Tensor): + attention_mask :torch.Tensor, + logits_soft_cap = 0): key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) hidden_states = flashinfer.single_prefill_with_kv_cache( @@ -79,7 +80,8 @@ def compute_attention(self, v = value_states, kv_layout="NHD", custom_mask=attention_mask[:,:self.kv_offset], - allow_fp16_qk_reduction=True + allow_fp16_qk_reduction=True, + logits_soft_cap = logits_soft_cap ) return hidden_states @@ -186,4 +188,4 @@ def compute_attention(self, hidden_states = hidden_states.reshape(bsz, self.num_attention_heads, q_len, -1) hidden_states = hidden_states.transpose(1, 2).contiguous() - return hidden_states \ No newline at end of file + return hidden_states diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 20f9483..5c60c1f 100644 --- a/umbrella/models/auto_model.py +++ 
b/umbrella/models/auto_model.py @@ -1,5 +1,6 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph +from .gemma import Gemma2 class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -102,7 +103,9 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwq, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwq, "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq + "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq, + "google/gemma-2-2b-it": Gemma2, + "google/gemma-2-2b": Gemma2 } _CUDAGRAPH_MODEL_MAPPING = { @@ -162,4 +165,4 @@ def from_pretrained(cls, model_name, offload=False, cuda_graph=False, **kwargs): raise ValueError(f"Model type '{model_name}' is not supported (offload). " f"Supported (offload) types: {list(cls._OFFLOAD_MODEL_MAPPING.keys())}") model_class = cls._OFFLOAD_MODEL_MAPPING[model_name] - return model_class(model_name = model_name, **kwargs) \ No newline at end of file + return model_class(model_name = model_name, **kwargs) diff --git a/umbrella/models/gemma.py b/umbrella/models/gemma.py new file mode 100644 index 0000000..91ae5ce --- /dev/null +++ b/umbrella/models/gemma.py @@ -0,0 +1,180 @@ +from transformers import Gemma2ForCausalLM, Gemma2Config +from transformers.models.gemma2.modeling_gemma2 import Gemma2RotaryEmbedding +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .gemma_layer import Gemma2Layer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph, layer_norm_gemma +from tqdm import tqdm + +class Gemma2(LLMBase): + def __init__(self, + model_name: str, + batch_size :int = 1, + max_length :int = 256, + device :str = 'cuda:0', + dtype = torch.float16) -> None: + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = 
Gemma2Config.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + # self.head_dim = self.hidden_size // self.num_heads + self.head_dim = self.config.head_dim + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + self.sliding_window = self.config.sliding_window + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.final_logit_softcapping = self.config.final_logit_softcapping + + def alloc(self, **kwargs): + + # TODO: can you use a KV cache for Gemma? (probably !?) + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = Gemma2ForCausalLM.from_pretrained(self.model_name, torch_dtype = self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.eps + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + 
emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[Gemma2Layer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = Gemma2Layer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: Gemma2Layer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + ##TODO: fix the attention mask, Gemma uses a sliding window self attention + + if buffer.is_sliding and attention_mask is not None: + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] + + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm_gemma(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + #print("sliding window?", buffer.is_sliding, buffer.sliding_window) + #print("query states size:", hidden_states.shape, buffer.wq.shape, query_states.shape) + #print("query_shape:", self.num_heads, self.head_dim) + #print("Weight shapes:", buffer.wq.shape, buffer.wk.shape, buffer.wv.shape, buffer.wo.shape) + key_states = 
F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask, self.attn_logit_softcapping + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.num_heads * self.head_dim) + + #logit soft_capping + + + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = layer_norm_gemma(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + hidden_states = residual + hidden_states + + residual = hidden_states + + #MLP + hidden_states = layer_norm_gemma(hidden_states, buffer.pre_feedforward_layernorm_variance_epsilon, buffer.pre_feedforward_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.gelu(gate, approximate='tanh') #hidden activation is gelu (tanh approx.) 
+ hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + + hidden_states = layer_norm_gemma(hidden_states, buffer.post_feedforward_layernorm_variance_epsilon, buffer.post_feedforward_layernorm_weight) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + normalizer = torch.tensor(self.hidden_size**.5, dtype = hidden_states.dtype) + hidden_states = hidden_states * normalizer + + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.gemma_rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + if(self.final_logit_softcapping is not None): + logits = logits / self.final_logit_softcapping + logits = F.tanh(logits) + logits = logits * self.final_logit_softcapping + + #print(logits.shape, logits[0,0,:10]) + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() + diff --git a/umbrella/models/gemma_layer.py b/umbrella/models/gemma_layer.py new file mode 100644 index 0000000..08eb737 --- /dev/null +++ b/umbrella/models/gemma_layer.py @@ -0,0 +1,113 @@ +from __future__ import annotations +import torch +from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer +from ..quantization.awq_utils import AwqLinear + +class Gemma2Layer: + def __init__(self, layer_idx, device = "cpu") -> None: + 
self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.pre_feedforward_layernorm_weight :torch.Tensor = None + self.pre_feedforward_layernorm_variance_epsilon: float = 0.0 + + self.post_feedforward_layernorm_weight :torch.Tensor = None + self.post_feedforward_layernorm_variance_epsilon: float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + self.is_sliding = False + self.sliding_window = 0 + + def init_parameters(self, hf_layer: Gemma2DecoderLayer): + + self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.eps + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.eps + + self.pre_feedforward_layernorm_weight :torch.Tensor = hf_layer.pre_feedforward_layernorm.weight.detach() + self.pre_feedforward_layernorm_variance_epsilon: float = hf_layer.pre_feedforward_layernorm.eps + + self.post_feedforward_layernorm_weight :torch.Tensor = hf_layer.post_feedforward_layernorm.weight.detach() + 
self.post_feedforward_layernorm_variance_epsilon: float = hf_layer.post_feedforward_layernorm.eps + + self.is_sliding = not bool(self.layer_idx % 2) + self.sliding_window = hf_layer.sliding_window + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.pre_feedforward_layernorm_weight = self.pre_feedforward_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_feedforward_layernorm_weight = self.post_feedforward_layernorm_weight.to(device, non_blocking=non_blocking) + + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: Gemma2Layer): + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.pre_feedforward_layernorm_weight.copy_(layer.pre_feedforward_layernorm_weight, non_blocking=True) + self.post_feedforward_layernorm_weight.copy_(layer.post_feedforward_layernorm_weight, non_blocking=True) + + 
self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.pre_feedforward_layernorm_variance_epsilon = layer.pre_feedforward_layernorm_variance_epsilon + self.post_feedforward_layernorm_variance_epsilon = layer.post_feedforward_layernorm_variance_epsilon + + self.layer_idx = layer.layer_idx + self.is_sliding = layer.is_sliding + + def alloc_space(self, layer: Gemma2Layer, device): + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + self.pre_feedforward_layernorm_weight = torch.zeros_like(layer.pre_feedforward_layernorm_weight).to(device) + self.post_feedforward_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) diff --git a/umbrella/models/model_utils.py b/umbrella/models/model_utils.py index 50695f1..6c68e6e 100644 --- a/umbrella/models/model_utils.py +++ b/umbrella/models/model_utils.py @@ -63,6 +63,17 @@ def layer_norm( hidden_states = hidden_states.reshape(b, s, h) return hidden_states +def layer_norm_gemma( + hidden_states: torch.Tensor, + layernorm_variance_epsilon: float, + layernorm_weight: torch.Tensor, +): + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.gemma_rmsnorm(hidden_states, layernorm_weight, layernorm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + 
return hidden_states def capture_graph( llm, decoding_seqlen :int =1, mempool=None, n_warmups :int=3 @@ -102,4 +113,4 @@ def run(input_ids, storage_ids, position_ids, attention_mask): graph.replay() return static_logits.clone() - return run \ No newline at end of file + return run diff --git a/umbrella/templates.py b/umbrella/templates.py index 5730deb..c71dc5f 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -14,7 +14,14 @@ 'qwen': """<|im_start|>user {}<|im_end|> <|im_start|>assistant -""" +""", + +'gemma2-it': """user +{} +model +""", + +'gemma2': "{}" } SysPrompts = { @@ -26,10 +33,12 @@ Environment: ipython<|eot_id|>""", 'qwen': """<|im_start|>system You are a helpful assistant.<|im_end|> -""" +""", + 'gemma2': "", + 'gemma2-it': "" } ExtraPrompts = { 'llama3-code': """\nAlways try to wrap what you write in a function.""" -} \ No newline at end of file +} From 4dfe5517bdf6080f4297b28d9697471582a2ac9a Mon Sep 17 00:00:00 2001 From: kushaann Date: Tue, 28 Jan 2025 15:58:15 -0500 Subject: [PATCH 03/23] Clean up comments --- umbrella/models/gemma.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/umbrella/models/gemma.py b/umbrella/models/gemma.py index 91ae5ce..2b0714b 100644 --- a/umbrella/models/gemma.py +++ b/umbrella/models/gemma.py @@ -26,7 +26,6 @@ def __init__(self, self.max_length = max_length self.hidden_size = self.config.hidden_size self.num_heads = self.config.num_attention_heads - # self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.config.head_dim self.num_key_value_heads = self.config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads @@ -38,8 +37,6 @@ def __init__(self, self.final_logit_softcapping = self.config.final_logit_softcapping def alloc(self, **kwargs): - - # TODO: can you use a KV cache for Gemma? (probably !?) 
self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) hf_model = Gemma2ForCausalLM.from_pretrained(self.model_name, torch_dtype = self.dtype) self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) @@ -87,8 +84,6 @@ def layer_compute(self, attention_mask: torch.FloatTensor, storage_ids: torch.LongTensor): - ##TODO: fix the attention mask, Gemma uses a sliding window self attention - if buffer.is_sliding and attention_mask is not None: min_dtype = torch.finfo(hidden_states.dtype).min sliding_window_mask = torch.tril( @@ -105,10 +100,6 @@ def layer_compute(self, hidden_states = layer_norm_gemma(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) bsz, q_len, _ = hidden_states.size() query_states = F.linear(hidden_states, buffer.wq) - #print("sliding window?", buffer.is_sliding, buffer.sliding_window) - #print("query states size:", hidden_states.shape, buffer.wq.shape, query_states.shape) - #print("query_shape:", self.num_heads, self.head_dim) - #print("Weight shapes:", buffer.wq.shape, buffer.wk.shape, buffer.wv.shape, buffer.wo.shape) key_states = F.linear(hidden_states, buffer.wk) value_states = F.linear(hidden_states, buffer.wv) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) @@ -167,7 +158,6 @@ def inference(self, logits = F.tanh(logits) logits = logits * self.final_logit_softcapping - #print(logits.shape, logits[0,0,:10]) return logits def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): From 5b20ff67206c71fd7fff33bd9556fef33195805c Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Thu, 13 Feb 2025 00:54:49 -0500 Subject: [PATCH 04/23] mistral --- examples/hf_generate.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 examples/hf_generate.py diff --git a/examples/hf_generate.py b/examples/hf_generate.py new file mode 100644 index 0000000..a2261d6 --- 
/dev/null +++ b/examples/hf_generate.py @@ -0,0 +1,16 @@ +# Load model directly +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct",help='model') +parser.add_argument('--G', type=int, default=512, help='generation length') +args = parser.parse_args() + +tokenizer = AutoTokenizer.from_pretrained(args.model) +model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16, _attn_implementation="eager").to("cuda:0") +text = "Tell me what you know about Reinforcement Learning in 100 words." +input_ids = tokenizer.encode(text=text, return_tensors="pt").to("cuda:0") + +output = model.generate(input_ids, do_sample=False, max_new_tokens=args.G) +print(tokenizer.decode(output[0], skip_special_tokens=True)) \ No newline at end of file From a8da89955b6913344ff6db6dec01f1309647183e Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Thu, 13 Feb 2025 00:55:32 -0500 Subject: [PATCH 05/23] mistral --- requirements.txt | 2 + umbrella/attn/cache.py | 2 +- umbrella/models/auto_model.py | 19 +- umbrella/models/mistral.py | 531 +++++++++++++++++++++++++++++++ umbrella/models/mistral_layer.py | 258 +++++++++++++++ umbrella/templates.py | 7 +- 6 files changed, 812 insertions(+), 7 deletions(-) create mode 100644 umbrella/models/mistral.py create mode 100644 umbrella/models/mistral_layer.py diff --git a/requirements.txt b/requirements.txt index 12c4f8a..371e1a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,5 @@ optimum==1.23.3 autoawq==0.2.7.post3 autoawq-kernels==0.0.8 gradio +protobuf +sentencepiece \ No newline at end of file diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 1347080..06af93e 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -128,7 +128,7 @@ def __init__(self, self.kv_offset = 0 self.num_key_value_heads = config.num_key_value_heads 
self.num_attention_heads = config.num_attention_heads - self.head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 5c60c1f..9155b0b 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -1,6 +1,7 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph from .gemma import Gemma2 +from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -47,8 +48,11 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwqOffload, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwqOffload, "KirillR/QwQ-32B-Preview-AWQ": QwenAwqOffload, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload - + "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload, + "mistralai/Mistral-7B-Instruct-v0.3": MistralOffload, + "solidrust/Mistral-7B-Instruct-v0.3-AWQ": MistralAwqOffload, + "mistralai/Mistral-Small-24B-Instruct-2501": MistralOffload, + "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwqOffload } _MODEL_MAPPING = { @@ -105,7 +109,13 @@ class AutoModelLM: "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq, "google/gemma-2-2b-it": Gemma2, - "google/gemma-2-2b": Gemma2 + "google/gemma-2-2b": Gemma2, + "mistralai/Mistral-7B-Instruct-v0.3": Mistral, + "solidrust/Mistral-7B-Instruct-v0.3-AWQ": MistralAwq, + "mistralai/Mistral-Small-24B-Instruct-2501": Mistral, + "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq, + "PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq, + "mistralai/Ministral-8B-Instruct-2410": Mistral } _CUDAGRAPH_MODEL_MAPPING 
= { @@ -136,7 +146,8 @@ class AutoModelLM: "Qwen/Qwen2.5-14B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-32B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-72B-Instruct": QwenCudagraph, - "Qwen/QwQ-32B-Preview": QwenCudagraph + "Qwen/QwQ-32B-Preview": QwenCudagraph, + "mistralai/Mistral-7B-Instruct-v0.3": MistralCudagraph, } @classmethod diff --git a/umbrella/models/mistral.py b/umbrella/models/mistral.py new file mode 100644 index 0000000..1b5dc10 --- /dev/null +++ b/umbrella/models/mistral.py @@ -0,0 +1,531 @@ +from transformers import MistralForCausalLM, MistralConfig, AutoModelForCausalLM +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .mistral_layer import MistralLayer, MistralAwqLayer, MistralPackedLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +class Mistral(LLMBase): + def __init__(self, + model_name: str, + batch_size: int=1, + max_length: int=256, + device: str = "cuda:0", + dtype = torch.float16) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = MistralConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = getattr(self.config, 'head_dim', self.config.hidden_size // self.config.num_attention_heads) + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + def alloc(self, **kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, 
device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[MistralLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, 
buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) + + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states 
= hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() + + +class MistralOffload(Mistral): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + self.load_stream = torch.cuda.Stream(device=device) + + def alloc(self, **kwargs): + + + self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + 
self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[MistralLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [MistralLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + if self.buffer[0].layer_idx != 0: + self.buffer[0].copy(self.layers[0]) + torch.cuda.synchronize() + for idx in range(self.num_layers): + with torch.cuda.stream(self.load_stream): + self.buffer[(idx + 1) % 2].copy(self.layers[(idx + 1)% self.num_layers]) + + hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, storage_ids) + torch.cuda.synchronize() + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + +class MistralAwq(Mistral): + def alloc(self, **kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = 
hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + self.layers :list[MistralAwqLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralAwqLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = 
buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + +class MistralAwqOffload(MistralOffload): + def alloc(self, **kwargs): + + 
self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + self.layers :list[MistralAwqLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = MistralAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [MistralAwqLayer(-1, self.device) for _ in range(2)] + 
self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralAwqLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MistralCudagraph(Mistral): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + 
super().__init__(model_name, batch_size, max_length, device, dtype) + + self.callables = {} + self.mempool = None + + def alloc(self, **kwargs): + + exit_layer = kwargs.pop("exit_layer", -1) + self.kv_cache = StaticKV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[MistralPackedLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + if exit_layer > 0 and idx >= exit_layer: + break + layer = MistralPackedLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def 
layer_compute(self, + buffer: MistralPackedLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + qkv = F.linear(hidden_states, buffer.wqkv) + query_states = qkv[...,:self.hidden_size] + key_states = qkv[...,self.hidden_size:self.hidden_size + self.head_dim * self.num_key_value_heads] + value_states = qkv[...,self.hidden_size + self.head_dim * self.num_key_value_heads:] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids, unsqueeze_dim=1) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + + @torch.inference_mode() + def initialize_cuda_graph(self, + decoding_seqlens 
:list[int], + n_warmups=12): + gc.collect() + self.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if decoding_seqlen not in self.callables: + self.callables[decoding_seqlen] = capture_graph( + llm=self, + decoding_seqlen=decoding_seqlen, + mempool=self.mempool, + n_warmups=n_warmups + ) + self.clear() + + @torch.inference_mode() + def graph_inference(self, + input_ids: torch.LongTensor, + storage_ids :torch.LongTensor, + position_ids = None, + attention_mask = None, + ): + dec_length = input_ids.shape[1] + if dec_length in self.callables.keys(): + logits = self.callables[dec_length](input_ids, storage_ids, position_ids, attention_mask) + else: + logits = self.inference(input_ids, position_ids, attention_mask, storage_ids) + return logits \ No newline at end of file diff --git a/umbrella/models/mistral_layer.py b/umbrella/models/mistral_layer.py new file mode 100644 index 0000000..be05520 --- /dev/null +++ b/umbrella/models/mistral_layer.py @@ -0,0 +1,258 @@ +from __future__ import annotations +import torch +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer +from ..quantization.awq_utils import AwqLinear + +class MistralLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: MistralDecoderLayer): + + self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + 
self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: MistralLayer): + + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + 
self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: MistralLayer, device): + + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + +class MistralPackedLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wqkv :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: MistralDecoderLayer): + + self.wqkv :torch.Tensor= torch.cat( + [ + hf_layer.self_attn.q_proj.weight.detach(), + hf_layer.self_attn.k_proj.weight.detach(), + hf_layer.self_attn.v_proj.weight.detach(), + ], + dim=0 + ) + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + 
self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wqkv = self.wqkv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: MistralPackedLayer): + + self.wqkv.copy_(layer.wqkv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: MistralPackedLayer, device): + + self.device = device + self.wqkv = torch.zeros_like(layer.wqkv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = 
torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + +class MistralAwqLayer(): + def __init__(self, layer_idx, device="cpu"): + + self.wq = AwqLinear() + self.wk = AwqLinear() + self.wv = AwqLinear() + self.wo = AwqLinear() + + self.gate_proj = AwqLinear() + self.up_proj = AwqLinear() + self.down_proj = AwqLinear() + + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer): + + self.wq.init_parameters(hf_layer.self_attn.q_proj) + self.wk.init_parameters(hf_layer.self_attn.k_proj) + self.wv.init_parameters(hf_layer.self_attn.v_proj) + self.wo.init_parameters(hf_layer.self_attn.o_proj) + self.gate_proj.init_parameters(hf_layer.mlp.gate_proj) + self.up_proj.init_parameters(hf_layer.mlp.up_proj) + self.down_proj.init_parameters(hf_layer.mlp.down_proj) + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach().pin_memory() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach().pin_memory() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + + self.wq.to(device=device) + self.wk.to(device=device) + self.wv.to(device=device) + 
self.wo.to(device=device) + + self.gate_proj.to(device=device) + self.up_proj.to(device=device) + self.down_proj.to(device=device) + + def alloc_space(self, layer: MistralAwqLayer, device): + + self.device = device + self.wq.empty_like(layer.wq) + self.wk.empty_like(layer.wk) + self.wv.empty_like(layer.wv) + self.wo.empty_like(layer.wo) + + self.gate_proj.empty_like(layer.gate_proj) + self.up_proj.empty_like(layer.up_proj) + self.down_proj.empty_like(layer.down_proj) + + self.wq.to(device=device) + self.wk.to(device=device) + self.wv.to(device=device) + self.wo.to(device=device) + + self.gate_proj.to(device=device) + self.up_proj.to(device=device) + self.down_proj.to(device=device) + + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + def copy(self, layer: MistralAwqLayer): + + self.wq.copy(layer.wq, non_blocking=True) + self.wk.copy(layer.wk, non_blocking=True) + self.wv.copy(layer.wv, non_blocking=True) + self.wo.copy(layer.wo, non_blocking=True) + self.gate_proj.copy(layer.gate_proj, non_blocking=True) + self.up_proj.copy(layer.up_proj, non_blocking=True) + self.down_proj.copy(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx \ No newline at end of file diff --git a/umbrella/templates.py b/umbrella/templates.py index c71dc5f..ac30a80 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -21,7 +21,9 @@ model """, -'gemma2': "{}" +'gemma2': "{}", +'mistral': "[INST] {} [/INST]" + } SysPrompts = { @@ -35,7 +37,8 @@ You are 
a helpful assistant.<|im_end|> """, 'gemma2': "", - 'gemma2-it': "" + 'gemma2-it': "", + 'mistral': "", } From 5c7f080034dd9af34014d742aeedb0ebf9b7110e Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Thu, 13 Feb 2025 12:40:49 -0500 Subject: [PATCH 06/23] ministral --- umbrella/attn/cache.py | 7 ++++--- umbrella/models/auto_model.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 06af93e..81a0800 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -107,11 +107,13 @@ def __init__(self, self.max_length = max_length self.device = device self.dtype = dtype + self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) + self.k_cache = torch.zeros( config.num_hidden_layers, config.num_key_value_heads, max_length, - config.hidden_size // config.num_attention_heads, + self.head_dim, device=self.device, dtype=self.dtype ) @@ -120,7 +122,7 @@ def __init__(self, config.num_hidden_layers, config.num_key_value_heads, max_length, - config.hidden_size // config.num_attention_heads, + self.head_dim, device=self.device, dtype=self.dtype ) @@ -128,7 +130,6 @@ def __init__(self, self.kv_offset = 0 self.num_key_value_heads = config.num_key_value_heads self.num_attention_heads = config.num_attention_heads - self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 9155b0b..17cbbe6 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -148,6 +148,7 @@ class AutoModelLM: "Qwen/Qwen2.5-72B-Instruct": QwenCudagraph, "Qwen/QwQ-32B-Preview": QwenCudagraph, "mistralai/Mistral-7B-Instruct-v0.3": MistralCudagraph, + "mistralai/Ministral-8B-Instruct-2410": MistralCudagraph } @classmethod From 2df60742a86f3be2318866fa9e0095efa553ee95 Mon Sep 17 
00:00:00 2001 From: dreaming-panda Date: Thu, 13 Feb 2025 16:48:56 -0500 Subject: [PATCH 07/23] gemma --- umbrella/models/auto_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 17cbbe6..f7c4105 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -109,6 +109,8 @@ class AutoModelLM: "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq, "google/gemma-2-2b-it": Gemma2, + "google/gemma-2-9b-it": Gemma2, + "google/gemma-2-27b-it": Gemma2, "google/gemma-2-2b": Gemma2, "mistralai/Mistral-7B-Instruct-v0.3": Mistral, "solidrust/Mistral-7B-Instruct-v0.3-AWQ": MistralAwq, From bd11c54f7c788e90b8fd5d815eba177434c1eff7 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Thu, 13 Feb 2025 23:23:46 -0500 Subject: [PATCH 08/23] draft --- draft/config.json | 25 +++++++++++++++ draft/train_draft.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 draft/config.json create mode 100644 draft/train_draft.py diff --git a/draft/config.json b/draft/config.json new file mode 100644 index 0000000..55a9136 --- /dev/null +++ b/draft/config.json @@ -0,0 +1,25 @@ +{ + "architectures": [ + "MistralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 16, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-05, + "rope_theta": 100000000.0, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "use_cache": true, + "vocab_size": 131072 + } \ No newline at end of file diff --git a/draft/train_draft.py b/draft/train_draft.py new file mode 100644 index 0000000..c0beea1 --- /dev/null +++ 
b/draft/train_draft.py @@ -0,0 +1,76 @@ +import torch +from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, Qwen2ForCausalLM +from datasets import load_dataset +import json +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--tokenizer', type=str, default="mistralai/Mistral-Small-24B-Instruct-2501",help='tokenizer') +parser.add_argument('--config', type=str, default="./config.json",help='model config') +parser.add_argument('--bsz', type=int, default=4, help='generation length') +args = parser.parse_args() + +with open(args.config, "r") as f: + config = json.load(f) + +model_name = args.model +tokenizer = AutoTokenizer.from_pretrained(model_name) + +model = AutoModelForCausalLM.from_config(config) + +train_data_files= [ + "train/chunk1/example_train_10*.jsonl.zst", + "train/chunk1/example_train_11*.jsonl.zst", + "train/chunk1/example_train_12*.jsonl.zst", + "train/chunk1/example_train_13*.jsonl.zst" +] +train_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=train_data_files, split="train") + +eval_data_files= ["validation/chunk1/example_holdout_*.jsonl.zst"] +eval_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=eval_data_files, split="train") + +# 定义预处理函数:对句子对进行编码 +def preprocess_function(examples): + + + output = tokenizer( + examples["problem"] + "\n\n" + examples["solution"], + truncation=True, + max_length=2048, + padding="max_length" + ) + + return output + +train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=False) +eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=False) + +data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) + +training_args = TrainingArguments( + output_dir="./results", + learning_rate=1e-4, + per_device_train_batch_size=args.bsz, + per_device_eval_batch_size=args.bsz, + weight_decay=0.01, + lr_scheduler_type="cosine", + 
load_best_model_at_end=True, + logging_dir='./logs', + logging_steps=10, + bf16=True, + save_only_model=True, + save_steps=5000, + save_total_limit=2 +) + +# 初始化 Trainer +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_tokenized_datasets, + eval_dataset=eval_tokenized_datasets, + tokenizer=tokenizer, + data_collator=data_collator +) + +# 开始训练 +trainer.train() From f8116e7e2b1a4375d38696f6634bbf62c0c460c8 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 03:12:19 -0500 Subject: [PATCH 09/23] draft --- draft/train_draft.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index c0beea1..e97d3d6 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -1,5 +1,5 @@ import torch -from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, Qwen2ForCausalLM +from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig from datasets import load_dataset import json import argparse @@ -9,10 +9,8 @@ parser.add_argument('--bsz', type=int, default=4, help='generation length') args = parser.parse_args() -with open(args.config, "r") as f: - config = json.load(f) - -model_name = args.model +config = AutoConfig.from_pretrained(args.config) +model_name = args.tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_config(config) From 7dcbf0d9c1888db8410966f6105497c42a8139d8 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 03:50:01 -0500 Subject: [PATCH 10/23] draft --- draft/train_draft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index e97d3d6..1b8648e 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -31,9 +31,9 @@ def preprocess_function(examples): output = tokenizer( - 
examples["problem"] + "\n\n" + examples["solution"], + examples["text"], truncation=True, - max_length=2048, + max_length=1024, padding="max_length" ) From 8b7c7ea4c73dc4e650096968596fa0fdf4eeb56d Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 03:51:39 -0500 Subject: [PATCH 11/23] draft --- draft/train_draft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/draft/train_draft.py b/draft/train_draft.py index 1b8648e..56a4610 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -12,6 +12,7 @@ config = AutoConfig.from_pretrained(args.config) model_name = args.tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_config(config) From a872e02fb61ac39d56b9afdfdc288752615bb627 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 03:53:47 -0500 Subject: [PATCH 12/23] draft --- draft/train_draft.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index 56a4610..457a258 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -12,7 +12,8 @@ config = AutoConfig.from_pretrained(args.config) model_name = args.tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) -tokenizer.pad_token = tokenizer.eos_token +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_config(config) @@ -40,8 +41,8 @@ def preprocess_function(examples): return output -train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=False) -eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=False) +train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=True) +eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) From 1c146611fd371791f03ebf8140a148d537287d02 Mon Sep 17 
00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 11:22:22 -0500 Subject: [PATCH 13/23] draft --- draft/train_draft.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index 457a258..ed00fe5 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -21,7 +21,8 @@ "train/chunk1/example_train_10*.jsonl.zst", "train/chunk1/example_train_11*.jsonl.zst", "train/chunk1/example_train_12*.jsonl.zst", - "train/chunk1/example_train_13*.jsonl.zst" + "train/chunk1/example_train_13*.jsonl.zst", + "train/chunk1/example_train_14*.jsonl.zst" ] train_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=train_data_files, split="train") @@ -59,7 +60,9 @@ def preprocess_function(examples): bf16=True, save_only_model=True, save_steps=5000, - save_total_limit=2 + save_total_limit=2, + eval_strategy="steps", + save_strategy="steps" ) # 初始化 Trainer From 71aff1b6a0074a8fde7ef7bacc66653c18716d67 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 12:32:05 -0500 Subject: [PATCH 14/23] draft --- draft/train_draft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index ed00fe5..a4c49b6 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -42,8 +42,8 @@ def preprocess_function(examples): return output -train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=True) -eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True) +train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=True, num_proc=8) +eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True, num_proc=8) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) From a706ae1b76cf4b6cd2715c4feb9cf4cfc7449eee Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 13:11:26 -0500 Subject: [PATCH 15/23] draft --- draft/train_draft.py 
| 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index a4c49b6..7e121a7 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -6,6 +6,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--tokenizer', type=str, default="mistralai/Mistral-Small-24B-Instruct-2501",help='tokenizer') parser.add_argument('--config', type=str, default="./config.json",help='model config') +parser.add_argument('--output_dir', type=str, default="mistral",help='output directory') parser.add_argument('--bsz', type=int, default=4, help='generation length') args = parser.parse_args() @@ -62,7 +63,8 @@ def preprocess_function(examples): save_steps=5000, save_total_limit=2, eval_strategy="steps", - save_strategy="steps" + save_strategy="steps", + eval_steps=5000 ) # 初始化 Trainer From 835dabd09671725e1e784e21fa86002b406fff8c Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 13:14:24 -0500 Subject: [PATCH 16/23] draft --- draft/train_draft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/draft/train_draft.py b/draft/train_draft.py index 7e121a7..02cf455 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -17,6 +17,8 @@ tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_config(config) +total_params = sum(p.numel() for p in model.parameters()) +print(f"total_params: {total_params:,}") train_data_files= [ "train/chunk1/example_train_10*.jsonl.zst", From ff1ac1751cf3073fd7f3bd42927658d98de61657 Mon Sep 17 00:00:00 2001 From: dreaming-panda Date: Fri, 14 Feb 2025 13:15:53 -0500 Subject: [PATCH 17/23] draft --- draft/train_draft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/draft/train_draft.py b/draft/train_draft.py index 02cf455..525e767 100644 --- a/draft/train_draft.py +++ b/draft/train_draft.py @@ -51,7 +51,7 @@ def preprocess_function(examples): data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) training_args = 
TrainingArguments( - output_dir="./results", + output_dir=args.output_dir, learning_rate=1e-4, per_device_train_batch_size=args.bsz, per_device_eval_batch_size=args.bsz, From 427a318419011e68d5491c1392ce83a2e6a12134 Mon Sep 17 00:00:00 2001 From: yulinw2 Date: Fri, 21 Feb 2025 05:02:23 -0500 Subject: [PATCH 18/23] mistral 24b --- .DS_Store | Bin 0 -> 6148 bytes .idea/.gitignore | 8 + .idea/UMbreLLa.iml | 12 + .idea/inspectionProfiles/Project_Default.xml | 38 ++ .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + draft/config.json | 25 - draft/train_draft.py | 83 --- examples/generate.py | 76 ++- examples/hf_generate.py | 16 - requirements.txt | 6 +- umbrella/attn/cache.py | 343 ++++++++---- umbrella/models/auto_model.py | 28 +- umbrella/models/gemma.py | 170 ------ umbrella/models/gemma_layer.py | 113 ---- umbrella/models/llama.py | 18 +- umbrella/models/mistral.py | 505 +++++++----------- umbrella/models/mistral_layer.py | 201 ++----- umbrella/models/model_utils.py | 13 +- .../speculation/static_speculation_engine.py | 6 +- umbrella/templates.py | 19 +- 23 files changed, 647 insertions(+), 1057 deletions(-) create mode 100644 .DS_Store create mode 100644 .idea/.gitignore create mode 100644 .idea/UMbreLLa.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml delete mode 100644 draft/config.json delete mode 100644 draft/train_draft.py delete mode 100644 examples/hf_generate.py delete mode 100644 umbrella/models/gemma.py delete mode 100644 umbrella/models/gemma_layer.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b5d9f8165cd02056a6c76be9adf60eb4886b9264 GIT binary patch literal 6148 zcmeHKF;2ul47A~jNEBR2xnJN1(dj6t`2ZK73pxZki839}h8f$SivpAsXdKy-^?I_& 
z6!ET^`Q~tcX|^`A4oxkh=N4`~EUpOTu-8@#tb82O; z6N<;`$hRmr?-LcJfE1W3aGJw4`~MaEhx7lOq>&Vm0{=<@Uu|!;8+=mr*2TxM*EaYq o{L@ey + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..f3032d8 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,38 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..c95407c --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3892f88 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/draft/config.json b/draft/config.json deleted file mode 100644 index 55a9136..0000000 --- a/draft/config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "architectures": [ - "MistralForCausalLM" - ], - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "head_dim": 64, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 6144, - "max_position_embeddings": 32768, - "model_type": "mistral", - "num_attention_heads": 16, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "rms_norm_eps": 1e-05, - "rope_theta": 100000000.0, - "sliding_window": null, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "use_cache": true, - "vocab_size": 131072 - } \ No newline at 
end of file diff --git a/draft/train_draft.py b/draft/train_draft.py deleted file mode 100644 index 525e767..0000000 --- a/draft/train_draft.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig -from datasets import load_dataset -import json -import argparse -parser = argparse.ArgumentParser() -parser.add_argument('--tokenizer', type=str, default="mistralai/Mistral-Small-24B-Instruct-2501",help='tokenizer') -parser.add_argument('--config', type=str, default="./config.json",help='model config') -parser.add_argument('--output_dir', type=str, default="mistral",help='output directory') -parser.add_argument('--bsz', type=int, default=4, help='generation length') -args = parser.parse_args() - -config = AutoConfig.from_pretrained(args.config) -model_name = args.tokenizer -tokenizer = AutoTokenizer.from_pretrained(model_name) -if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - -model = AutoModelForCausalLM.from_config(config) -total_params = sum(p.numel() for p in model.parameters()) -print(f"total_params: {total_params:,}") - -train_data_files= [ - "train/chunk1/example_train_10*.jsonl.zst", - "train/chunk1/example_train_11*.jsonl.zst", - "train/chunk1/example_train_12*.jsonl.zst", - "train/chunk1/example_train_13*.jsonl.zst", - "train/chunk1/example_train_14*.jsonl.zst" -] -train_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=train_data_files, split="train") - -eval_data_files= ["validation/chunk1/example_holdout_*.jsonl.zst"] -eval_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=eval_data_files, split="train") - -# 定义预处理函数:对句子对进行编码 -def preprocess_function(examples): - - - output = tokenizer( - examples["text"], - truncation=True, - max_length=1024, - padding="max_length" - ) - - return output - -train_tokenized_datasets = train_raw_datasets.map(preprocess_function, 
batched=True, num_proc=8) -eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True, num_proc=8) - -data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) - -training_args = TrainingArguments( - output_dir=args.output_dir, - learning_rate=1e-4, - per_device_train_batch_size=args.bsz, - per_device_eval_batch_size=args.bsz, - weight_decay=0.01, - lr_scheduler_type="cosine", - load_best_model_at_end=True, - logging_dir='./logs', - logging_steps=10, - bf16=True, - save_only_model=True, - save_steps=5000, - save_total_limit=2, - eval_strategy="steps", - save_strategy="steps", - eval_steps=5000 -) - -# 初始化 Trainer -trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_tokenized_datasets, - eval_dataset=eval_tokenized_datasets, - tokenizer=tokenizer, - data_collator=data_collator -) - -# 开始训练 -trainer.train() diff --git a/examples/generate.py b/examples/generate.py index f130030..6e6d86d 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,21 +1,26 @@ import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" from umbrella.models.auto_model import AutoModelLM from umbrella.logging_config import setup_logger from umbrella.utils import TextColors + logger = setup_logger() import torch from umbrella.templates import Prompts, SysPrompts -from transformers import AutoTokenizer -from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, find_first_element_position +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM +from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, \ + find_first_element_position import argparse import time + parser = argparse.ArgumentParser() -parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct",help='model') -parser.add_argument('--template', type=str, default="meta-llama3",help='prompt template') +parser.add_argument('--model', type=str, 
default="meta-llama/Llama-3.1-8B-Instruct", help='model') +parser.add_argument('--template', type=str, default="meta-llama3", help='prompt template') parser.add_argument('--G', type=int, default=512, help='generation length') parser.add_argument('--offload', action='store_true', help="offload the model") parser.add_argument('--cuda_graph', action='store_true', help="whether use cuda graph") +parser.add_argument('--d', type=int, default=0, help="whether use debug mode") args = parser.parse_args() DEVICE = "cuda:0" DTYPE = torch.float16 @@ -32,6 +37,12 @@ tokenizer = AutoTokenizer.from_pretrained(args.model) tokens = tokenizer.encode(text=text, return_tensors="pt").to(DEVICE) +config = AutoConfig.from_pretrained(args.model) + +# testing mistral sliding window +# config.sliding_window = True +# config.window_size = 100 + llm = AutoModelLM.from_pretrained( model_name=args.model, offload=args.offload, @@ -39,12 +50,16 @@ batch_size=1, max_length=MAX_LEN, dtype=DTYPE, - device=DEVICE + device=DEVICE, + config=config ) +print('llm', llm.config) + eos_tokens = llm.config.eos_token_id if not isinstance(eos_tokens, list): eos_tokens = [eos_tokens] llm.alloc() + if args.cuda_graph: llm.initialize_cuda_graph([1]) attention_mask = make_causal_mask((MAX_LEN, MAX_LEN), DEVICE) @@ -52,43 +67,48 @@ position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0) prefix_len = tokens.shape[1] -logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:,:prefix_len], - storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0] +logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:, :prefix_len], + storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0] torch.cuda.synchronize() t1 = time.time() + generated_tokens = [] pos = 0 +print('----start generating answers.----') for i in range(GEN_LEN): next_token = logits[-1:].argmax(dim=-1, keepdim=True) generated_tokens.append(next_token.item()) - + 
generated_text = ( - tokenizer.decode( - generated_tokens, - skip_special_tokens=True, - clean_up_tokenization_spaces=True, - spaces_between_special_tokens=False, - ) - .strip() - .split(" ") - ) - - + tokenizer.decode( + generated_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + spaces_between_special_tokens=False, + ) + .strip() + .split(" ") + ) + now = len(generated_text) - 1 if now > pos: - print(" ".join(generated_text[pos:now]), end=" ", flush=True) - pos = now - - if (is_sentence_complete_regex(generated_text[-1]) and (i >= GEN_LEN - 32)) or (find_first_element_position(next_token, eos_tokens) >= 0): - break - - logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:,prefix_len+i:prefix_len+i+1], - storage_ids=storage_ids[prefix_len+i : prefix_len+i+1], attention_mask=attention_mask[prefix_len+i:prefix_len+i+1])[0] + print(" ".join(generated_text[pos:now]), end=" ", flush=True) + pos = now + + if (is_sentence_complete_regex(generated_text[-1]) and (i >= GEN_LEN - 32)) or ( + find_first_element_position(next_token, eos_tokens) >= 0): + break + + logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:, prefix_len + i:prefix_len + i + 1], + storage_ids=storage_ids[prefix_len + i: prefix_len + i + 1], + attention_mask=attention_mask[prefix_len + i:prefix_len + i + 1])[0] print(" ".join(generated_text[pos:]), flush=True) +print('----end generating answers.----') torch.cuda.synchronize() t2 = time.time() dec_len = len(generated_tokens) -logger.info(TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2-t1)/dec_len), "magenta")) +logger.info( + TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2 - t1) / dec_len), "magenta")) \ No newline at end of file diff --git a/examples/hf_generate.py b/examples/hf_generate.py deleted file mode 100644 index a2261d6..0000000 --- a/examples/hf_generate.py +++ /dev/null @@ -1,16 +0,0 @@ -# Load 
model directly -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM -import argparse -parser = argparse.ArgumentParser() -parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct",help='model') -parser.add_argument('--G', type=int, default=512, help='generation length') -args = parser.parse_args() - -tokenizer = AutoTokenizer.from_pretrained(args.model) -model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16, _attn_implementation="eager").to("cuda:0") -text = "Tell me what you know about Reinforcement Learning in 100 words." -input_ids = tokenizer.encode(text=text, return_tensors="pt").to("cuda:0") - -output = model.generate(input_ids, do_sample=False, max_new_tokens=args.G) -print(tokenizer.decode(output[0], skip_special_tokens=True)) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 371e1a7..85e36d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ torch==2.4.1 -transformers==4.48.0 +transformers==4.47.0 huggingface-hub==0.27.0 transformers-stream-generator==0.0.5 optimum==1.23.3 autoawq==0.2.7.post3 autoawq-kernels==0.0.8 -gradio -protobuf -sentencepiece \ No newline at end of file +gradio \ No newline at end of file diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 81a0800..a21e83b 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -2,24 +2,75 @@ import torch import flashinfer import math -class KV_Cache: - def __init__(self, - config :AutoConfig, - batch_size :int = 1, - max_length :int = 256, - device :str = 'cuda:0', - dtype = torch.float16) -> None: + +def mha_flash(q, k, v, kv_layout, attn_mask): + return flashinfer.single_prefill_with_kv_cache( + q=q, + k=k, + v=v, + kv_layout=kv_layout, + custom_mask=attn_mask, + allow_fp16_qk_reduction=True + ) + +def mha(q, k, v, attn_mask): + """ + Args: + q (torch.Tensor): Query tensor of shape (query_len, q_head, head_dim) + k (torch.Tensor): Key tensor 
of shape (kv_len, kv_head, head_dim) + v (torch.Tensor): Value tensor of shape (kv_len, kv_head, head_dim) + attn_mask (torch.Tensor): Value tensor of shape (q_len, kv_len) + + Returns: + + """ + q_len, q_head, head_dim = q.shape + kv_len, kv_head, _ = k.shape + assert (q_head % kv_head == 0) + num_kv_groups = q_head // kv_head + + # Step 1: Reshape Q for GQA + # (q_len, q_head, head_dim) -> (kv_head, num_kv_groups, q_len, head_dim) + q = q.transpose(0,1).reshape(kv_head, num_kv_groups, q_len, head_dim) + k = k.transpose(0,1) + v = v.transpose(0,1) + # Step 2: Compute Attention Scores (kv_head, num_kv_groups, q_len, kv_len) + attn_scores = torch.einsum('hgld,hmd->hglm', q, k) / math.sqrt(head_dim) + + # Step 3: Apply Attention Mask + # (kv_head, num_groups, q_len, kv_len) + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(kv_head, num_kv_groups, -1, -1) + attn_scores.masked_fill_(~attn_mask, torch.finfo(attn_scores.dtype).min) + + # Step 4: Compute Attention Weights + # (kv_head, num_kv_groups, q_len, kv_len) + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + + # Step 5: Compute Context Vector + # (kv_head, num_kv_groups, q_len, head_dim) + hidden_states = torch.einsum('hglm,hmd->hgld', attn_weights, v) + hidden_states = hidden_states.reshape(q_head, q_len, head_dim) + hidden_states = hidden_states.transpose(0, 1).contiguous() + return hidden_states + + +class KV_Cache: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: self.config = config self.max_length = max_length self.device = device self.dtype = dtype - self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) self.k_cache = torch.zeros( config.num_hidden_layers, max_length, config.num_key_value_heads, - self.head_dim, + config.hidden_size // config.num_attention_heads, device=self.device, dtype=self.dtype ) @@ -28,92 
+79,72 @@ def __init__(self, config.num_hidden_layers, max_length, config.num_key_value_heads, - self.head_dim, + config.hidden_size // config.num_attention_heads, device=self.device, dtype=self.dtype ) self.num_layers = config.num_hidden_layers self.kv_offset = 0 - self.num_key_value_heads = config.num_key_value_heads - self.num_attention_heads = config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - - def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): - self.k_cache[:,offset:offset + len(indices), :,:] = self.k_cache[:,indices, :,:] - self.v_cache[:,offset:offset + len(indices), :,:] = self.v_cache[:,indices, :,:] + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): + self.k_cache[:, offset:offset + len(indices), :, :] = self.k_cache[:, indices, :, :] + self.v_cache[:, offset:offset + len(indices), :, :] = self.v_cache[:, indices, :, :] - self.k_cache[:,offset + len(indices):, :,:] = 0.0 - self.v_cache[:,offset + len(indices):, :,:] = 0.0 + self.k_cache[:, offset + len(indices):, :, :] = 0.0 + self.v_cache[:, offset + len(indices):, :, :] = 0.0 self.kv_offset = offset + len(indices) - - - def update_kv_cache(self, - new_k_cache :torch.Tensor, - new_v_cache :torch.Tensor, - layer_idx :int, - storage_ids: torch.LongTensor - ): - + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + storage_ids: torch.LongTensor + ): new_kv_len = storage_ids.shape[0] if layer_idx == 0: self.kv_offset += new_kv_len self.k_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache self.v_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache return self.k_cache[layer_idx][:self.kv_offset], self.v_cache[layer_idx][:self.kv_offset] - - def compute_attention(self, - query_states :torch.Tensor, - key_states :torch.Tensor, - value_states :torch.Tensor, - layer_idx, - storage_ids :torch.Tensor, - 
attention_mask :torch.Tensor, - logits_soft_cap = 0): - + + def compute_attention(self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) - hidden_states = flashinfer.single_prefill_with_kv_cache( - q = query_states[0], - k = key_states, - v = value_states, - kv_layout="NHD", - custom_mask=attention_mask[:,:self.kv_offset], - allow_fp16_qk_reduction=True, - logits_soft_cap = logits_soft_cap - ) - + hidden_states = mha_flash(query_states[0], key_states, value_states, "NHD", attention_mask[:, :self.kv_offset]) + return hidden_states - + def clear(self): self.k_cache.zero_() self.v_cache.zero_() self.kv_offset = 0 - - def set_kv_len(self, kv_len :int): - self.kv_offset = kv_len + + def set_kv_len(self, kv_len: int): + self.kv_offset = kv_len class StaticKV_Cache: - - def __init__(self, - config :AutoConfig, - batch_size :int = 1, - max_length :int = 256, - device :str = 'cuda:0', - dtype = torch.float16) -> None: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: self.config = config self.max_length = max_length self.device = device self.dtype = dtype - self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) - self.k_cache = torch.zeros( config.num_hidden_layers, config.num_key_value_heads, max_length, - self.head_dim, + config.hidden_size // config.num_attention_heads, device=self.device, dtype=self.dtype ) @@ -122,7 +153,7 @@ def __init__(self, config.num_hidden_layers, config.num_key_value_heads, max_length, - self.head_dim, + config.hidden_size // config.num_attention_heads, device=self.device, dtype=self.dtype ) @@ -130,11 +161,10 @@ def __init__(self, self.kv_offset = 0 self.num_key_value_heads = 
config.num_key_value_heads self.num_attention_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - - def gather_kv_incremental(self, indices: list[int], offset:int): - + def gather_kv_incremental(self, indices: list[int], offset: int): self.k_cache[..., offset:offset + len(indices), :] = self.k_cache[..., indices, :] self.v_cache[..., offset:offset + len(indices), :] = self.v_cache[..., indices, :] @@ -143,50 +173,173 @@ def gather_kv_incremental(self, indices: list[int], offset:int): self.kv_offset = offset + len(indices) - - - def update_kv_cache(self, - new_k_cache :torch.Tensor, - new_v_cache :torch.Tensor, - layer_idx :int, - storage_ids: torch.LongTensor - ): - + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + storage_ids: torch.LongTensor + ): self.k_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_k_cache) self.v_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_v_cache) - + return self.k_cache[layer_idx], self.v_cache[layer_idx] - def clear(self): self.k_cache.zero_() self.v_cache.zero_() self.kv_offset = 0 - - def set_kv_len(self, kv_len :int): - self.kv_offset = kv_len - - def compute_attention(self, - query_states :torch.Tensor, - key_states :torch.Tensor, - value_states :torch.Tensor, - layer_idx, - storage_ids :torch.Tensor, - attention_mask :torch.Tensor): + + def set_kv_len(self, kv_len: int): + self.kv_offset = kv_len + + def compute_attention(self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): bsz, _, q_len, _ = query_states.shape - - key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) + + key_states, value_states = self.update_kv_cache(key_states[0], 
value_states[0], layer_idx, storage_ids) query_states = query_states[0] - + query_states = query_states.reshape(self.num_key_value_heads, q_len * self.num_key_value_groups, self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) - mask = attention_mask[None,:,:].repeat(1, self.num_key_value_groups, 1) - + mask = attention_mask[None, :, :].repeat(1, self.num_key_value_groups, 1) + attn_weights.masked_fill_(~mask, torch.finfo(attn_weights.dtype).min) - + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) hidden_states = torch.matmul(attn_weights, value_states) hidden_states = hidden_states.reshape(bsz, self.num_attention_heads, q_len, -1) hidden_states = hidden_states.transpose(1, 2).contiguous() - + + return hidden_states + + +class SlidingWindowKV_Cache: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: + + self.config = config + self.max_length = max_length + self.window_size = self.config.window_size + self.device = device + self.dtype = dtype + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + + # initializing Key-Value Cache + self.k_cache = torch.zeros( + config.num_hidden_layers, + max_length, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + device=self.device, + dtype=self.dtype + ) + + self.v_cache = torch.zeros( + config.num_hidden_layers, + max_length, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + device=self.device, + dtype=self.dtype + ) + self.kv_offset = 0 + + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: 
int, + storage_ids: torch.LongTensor + ): + + new_kv_len = storage_ids.shape[0] + # # calculating new offset + need_len = new_kv_len + min(self.kv_offset, self.window_size) + if need_len > self.max_length: + raise ValueError(f'now kv_offset {self.kv_offset} new_kv_len {new_kv_len} need_len {need_len} exceeds max_length {self.max_length}') + + # update KV Cache + if layer_idx == 0: + self.kv_offset += new_kv_len + if self.kv_offset >= self.max_length: + # print(f'layer idx need shift data, {self.kv_offset}') + self.k_cache[:, :self.window_size, :, :] = self.k_cache[:, self.kv_offset - self.window_size: self.kv_offset, :, :] + self.v_cache[:, :self.window_size, :, :] = self.v_cache[:, self.kv_offset - self.window_size: self.kv_offset, :, :] + self.kv_offset = self.window_size + new_kv_len + # print(f'layer_idx, add, {self.kv_offset}') + self.k_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache + self.v_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache + start_idx = max(0, self.kv_offset - (self.window_size + new_kv_len)) + return self.k_cache[layer_idx][start_idx:self.kv_offset], self.v_cache[layer_idx][start_idx:self.kv_offset] + + def clear(self): + self.k_cache.zero_() + self.v_cache.zero_() + self.kv_offset = 0 + + def set_kv_len(self, kv_len: int): + self.kv_offset = kv_len + + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): + self.k_cache[:, offset:offset + len(indices), :, :] = self.k_cache[:, indices, :, :] + self.v_cache[:, offset:offset + len(indices), :, :] = self.v_cache[:, indices, :, :] + + self.k_cache[:, offset + len(indices):, :, :] = 0.0 + self.v_cache[:, offset + len(indices):, :, :] = 0.0 + + self.kv_offset = offset + len(indices) + + def compute_attention(self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): + """ + Computes the attention output using a sliding 
window mechanism. + + This function updates the KV cache, constructs the appropriate attention mask that considers: + - Sliding window attention (limits visible tokens to a fixed window size) + - Original `attention_mask` (to handle padding tokens) + - Causal Mask (ensures auto-regressive decoding) + + Args: + query_states (torch.Tensor): Query tensor of shape (batch, num_heads, query_len, head_dim) + key_states (torch.Tensor): Key tensor of shape (batch, num_heads, kv_len, head_dim) + value_states (torch.Tensor): Value tensor of shape (batch, num_heads, kv_len, head_dim) + layer_idx (int): Current layer index in the transformer model. + storage_ids (torch.Tensor): Index positions to store the new KV cache. + attention_mask (torch.Tensor): Casual mask. (q_len, squence_len) + + Returns: + hidden_states (torch.Tensor): Output hidden states after applying attention. + """ + + # Step 1: Update KV Cache (keep only the latest window_size tokens) + key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) + q_len = storage_ids.shape[0] + kv_len = key_states.shape[0] + + # Step 2: Generate Sliding Window Attention Mask + # Create a 2D attention mask where each query position can attend only to the latest `window_size` tokens + query_indices = (kv_len - q_len) + torch.arange(q_len, device=self.device).unsqueeze(1) # Shape: (q_len, 1) + kv_indices = torch.arange(kv_len, device=self.device).unsqueeze(0) # Shape: (1, kv_offset) + diff = query_indices - kv_indices + # Compute boolean mask: True if kv_index is within `window_size` of the query_index + attn_mask = torch.logical_and(diff <= self.window_size, diff >= 0) # (q_len, kv_offset) + # Step 3: Compute Attention Both are ok. 
+ # hidden_states = mha(query_states[0], key_states, value_states, attn_mask) + hidden_states = mha_flash(query_states[0], key_states, value_states, "NHD", attn_mask) return hidden_states diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index f7c4105..da2d2d3 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -1,7 +1,6 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph -from .gemma import Gemma2 -from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq +from .mistral import Mistral, MistralOffload, MistralAwq, MistralAwqOffload, MistralCudagraph class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -48,11 +47,8 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwqOffload, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwqOffload, "KirillR/QwQ-32B-Preview-AWQ": QwenAwqOffload, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload, - "mistralai/Mistral-7B-Instruct-v0.3": MistralOffload, - "solidrust/Mistral-7B-Instruct-v0.3-AWQ": MistralAwqOffload, - "mistralai/Mistral-Small-24B-Instruct-2501": MistralOffload, - "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwqOffload + "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload + } _MODEL_MAPPING = { @@ -107,17 +103,9 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwq, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwq, "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq, - "google/gemma-2-2b-it": Gemma2, - "google/gemma-2-9b-it": Gemma2, - "google/gemma-2-27b-it": Gemma2, - "google/gemma-2-2b": Gemma2, - "mistralai/Mistral-7B-Instruct-v0.3": Mistral, - "solidrust/Mistral-7B-Instruct-v0.3-AWQ": MistralAwq, + "casperhansen/deepseek-r1-distill-qwen-32b-awq": QwenAwq, "mistralai/Mistral-Small-24B-Instruct-2501": Mistral, - 
"stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq, - "PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq, - "mistralai/Ministral-8B-Instruct-2410": Mistral + "mistralai/Mistral-7B-Instruct-v0.3": Mistral } _CUDAGRAPH_MODEL_MAPPING = { @@ -148,9 +136,7 @@ class AutoModelLM: "Qwen/Qwen2.5-14B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-32B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-72B-Instruct": QwenCudagraph, - "Qwen/QwQ-32B-Preview": QwenCudagraph, - "mistralai/Mistral-7B-Instruct-v0.3": MistralCudagraph, - "mistralai/Ministral-8B-Instruct-2410": MistralCudagraph + "Qwen/QwQ-32B-Preview": QwenCudagraph } @classmethod @@ -179,4 +165,4 @@ def from_pretrained(cls, model_name, offload=False, cuda_graph=False, **kwargs): raise ValueError(f"Model type '{model_name}' is not supported (offload). " f"Supported (offload) types: {list(cls._OFFLOAD_MODEL_MAPPING.keys())}") model_class = cls._OFFLOAD_MODEL_MAPPING[model_name] - return model_class(model_name = model_name, **kwargs) + return model_class(model_name = model_name, **kwargs) \ No newline at end of file diff --git a/umbrella/models/gemma.py b/umbrella/models/gemma.py deleted file mode 100644 index 2b0714b..0000000 --- a/umbrella/models/gemma.py +++ /dev/null @@ -1,170 +0,0 @@ -from transformers import Gemma2ForCausalLM, Gemma2Config -from transformers.models.gemma2.modeling_gemma2 import Gemma2RotaryEmbedding -import torch -import torch.nn.functional as F -import gc -import flashinfer -from ..attn.cache import KV_Cache, StaticKV_Cache -from .gemma_layer import Gemma2Layer -from .base import LLMBase -from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph, layer_norm_gemma -from tqdm import tqdm - -class Gemma2(LLMBase): - def __init__(self, - model_name: str, - batch_size :int = 1, - max_length :int = 256, - device :str = 'cuda:0', - dtype = torch.float16) -> None: - super().__init__() - self.batch_size = batch_size - self.device = device - self.dtype = dtype - self.config = 
Gemma2Config.from_pretrained(model_name) - self.model_name = model_name - self.max_length = max_length - self.hidden_size = self.config.hidden_size - self.num_heads = self.config.num_attention_heads - self.head_dim = self.config.head_dim - self.num_key_value_heads = self.config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = self.config.max_position_embeddings - self.rope_theta = self.config.rope_theta - self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] - self.sliding_window = self.config.sliding_window - self.attn_logit_softcapping = self.config.attn_logit_softcapping - self.final_logit_softcapping = self.config.final_logit_softcapping - - def alloc(self, **kwargs): - self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) - hf_model = Gemma2ForCausalLM.from_pretrained(self.model_name, torch_dtype = self.dtype) - self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) - if self.config.tie_word_embeddings: - self.lm_head = self.embed_tokens - else: - self.lm_head = hf_model.lm_head.weight.detach().to(self.device) - - self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) - self.norm_variance_epsilon = hf_model.model.norm.eps - - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling - - position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cache = emb.cos()[0] - self.sin_cache = emb.sin()[0] - 
self.cos_cache = self.cos_cache * self.attention_scaling - self.sin_cache = self.sin_cache * self.attention_scaling - self.cos_cache = self.cos_cache.to(self.dtype) - self.sin_cache = self.sin_cache.to(self.dtype) - - self.layers :list[Gemma2Layer] = [] - - for idx, hf_layer in enumerate(hf_model.model.layers): - layer = Gemma2Layer(idx) - layer.init_parameters(hf_layer=hf_layer) - layer.to(self.device) - self.layers.append(layer) - hf_model.model.layers[idx] = None - gc.collect() - - self.num_layers = len(self.layers) - - @torch.inference_mode() - def layer_compute(self, - buffer: Gemma2Layer, - layer_idx :int, - hidden_states: torch.FloatTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - - if buffer.is_sliding and attention_mask is not None: - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] - - - residual = hidden_states - bsz, q_len, _ = hidden_states.size() - - hidden_states = layer_norm_gemma(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) - bsz, q_len, _ = hidden_states.size() - query_states = F.linear(hidden_states, buffer.wq) - key_states = F.linear(hidden_states, buffer.wk) - value_states = F.linear(hidden_states, buffer.wv) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) - hidden_states = 
self.kv_cache.compute_attention( - query_states, key_states, value_states, layer_idx, storage_ids, attention_mask, self.attn_logit_softcapping - ) - hidden_states = hidden_states.reshape(bsz, q_len, self.num_heads * self.head_dim) - - #logit soft_capping - - - hidden_states = F.linear(hidden_states, buffer.wo) - hidden_states = layer_norm_gemma(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) - hidden_states = residual + hidden_states - - residual = hidden_states - - #MLP - hidden_states = layer_norm_gemma(hidden_states, buffer.pre_feedforward_layernorm_variance_epsilon, buffer.pre_feedforward_layernorm_weight) - up = F.linear(hidden_states, buffer.up_proj) - gate = F.linear(hidden_states, buffer.gate_proj) - gate = F.gelu(gate, approximate='tanh') #hidden activation is gelu (tanh approx.) - hidden_states = gate * up - hidden_states = F.linear(hidden_states, buffer.down_proj) - - hidden_states = layer_norm_gemma(hidden_states, buffer.post_feedforward_layernorm_variance_epsilon, buffer.post_feedforward_layernorm_weight) - hidden_states = residual + hidden_states - - return hidden_states - - @torch.inference_mode() - def inference(self, - input_ids: torch.LongTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - - hidden_states = F.embedding(input_ids, self.embed_tokens) - normalizer = torch.tensor(self.hidden_size**.5, dtype = hidden_states.dtype) - hidden_states = hidden_states * normalizer - - for idx in range(self.num_layers): - hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) - - b, s, h = hidden_states.shape - hidden_states = hidden_states.reshape(b * s, h) - hidden_states = flashinfer.gemma_rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) - hidden_states = hidden_states.reshape(b, s, h) - logits = F.linear(hidden_states, self.lm_head).float() - 
if(self.final_logit_softcapping is not None): - logits = logits / self.final_logit_softcapping - logits = F.tanh(logits) - logits = logits * self.final_logit_softcapping - - return logits - - def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): - - self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) - - def clear(self): - - self.kv_cache.clear() - diff --git a/umbrella/models/gemma_layer.py b/umbrella/models/gemma_layer.py deleted file mode 100644 index 08eb737..0000000 --- a/umbrella/models/gemma_layer.py +++ /dev/null @@ -1,113 +0,0 @@ -from __future__ import annotations -import torch -from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer -from ..quantization.awq_utils import AwqLinear - -class Gemma2Layer: - def __init__(self, layer_idx, device = "cpu") -> None: - self.wq :torch.Tensor = None - self.wk :torch.Tensor = None - self.wv :torch.Tensor = None - self.wo :torch.Tensor = None - - self.gate_proj :torch.Tensor = None - self.up_proj :torch.Tensor = None - self.down_proj :torch.Tensor = None - - self.input_layernorm_weight :torch.Tensor = None - self.input_layernorm_variance_epsilon :float = 0.0 - - self.post_attention_layernorm_weight :torch.Tensor = None - self.post_attention_layernorm_variance_epsilon :float = 0.0 - - self.pre_feedforward_layernorm_weight :torch.Tensor = None - self.pre_feedforward_layernorm_variance_epsilon: float = 0.0 - - self.post_feedforward_layernorm_weight :torch.Tensor = None - self.post_feedforward_layernorm_variance_epsilon: float = 0.0 - - self.layer_idx = layer_idx - self.device = device - - self.is_sliding = False - self.sliding_window = 0 - - def init_parameters(self, hf_layer: Gemma2DecoderLayer): - - self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() - self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() - self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() - self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() - - 
self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() - self.up_proj = hf_layer.mlp.up_proj.weight.detach() - self.down_proj = hf_layer.mlp.down_proj.weight.detach() - - self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() - self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.eps - - self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() - self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.eps - - self.pre_feedforward_layernorm_weight :torch.Tensor = hf_layer.pre_feedforward_layernorm.weight.detach() - self.pre_feedforward_layernorm_variance_epsilon: float = hf_layer.pre_feedforward_layernorm.eps - - self.post_feedforward_layernorm_weight :torch.Tensor = hf_layer.post_feedforward_layernorm.weight.detach() - self.post_feedforward_layernorm_variance_epsilon: float = hf_layer.post_feedforward_layernorm.eps - - self.is_sliding = not bool(self.layer_idx % 2) - self.sliding_window = hf_layer.sliding_window - - def to(self, device:str = 'cuda:0', non_blocking = True): - - self.device = device - self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) - self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) - self.pre_feedforward_layernorm_weight = self.pre_feedforward_layernorm_weight.to(device, non_blocking=non_blocking) - self.post_feedforward_layernorm_weight = self.post_feedforward_layernorm_weight.to(device, non_blocking=non_blocking) - - self.wq = self.wq.to(device, non_blocking=non_blocking) - self.wk = self.wk.to(device, non_blocking=non_blocking) - self.wv = self.wv.to(device, non_blocking=non_blocking) - self.wo = self.wo.to(device, non_blocking=non_blocking) - self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) - self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) - self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) - 
- def copy(self, layer: Gemma2Layer): - self.wq.copy_(layer.wq, non_blocking=True) - self.wk.copy_(layer.wk, non_blocking=True) - self.wv.copy_(layer.wv, non_blocking=True) - self.wo.copy_(layer.wo, non_blocking=True) - self.gate_proj.copy_(layer.gate_proj, non_blocking=True) - self.up_proj.copy_(layer.up_proj, non_blocking=True) - self.down_proj.copy_(layer.down_proj, non_blocking=True) - - self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) - self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) - self.pre_feedforward_layernorm_weight.copy_(layer.pre_feedforward_layernorm_weight, non_blocking=True) - self.post_feedforward_layernorm_weight.copy_(layer.post_feedforward_layernorm_weight, non_blocking=True) - - self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon - self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon - self.pre_feedforward_layernorm_variance_epsilon = layer.pre_feedforward_layernorm_variance_epsilon - self.post_feedforward_layernorm_variance_epsilon = layer.post_feedforward_layernorm_variance_epsilon - - self.layer_idx = layer.layer_idx - self.is_sliding = layer.is_sliding - - def alloc_space(self, layer: Gemma2Layer, device): - self.device = device - self.wq = torch.zeros_like(layer.wq).to(device) - self.wk = torch.zeros_like(layer.wk).to(device) - self.wv = torch.zeros_like(layer.wv).to(device) - self.wo = torch.zeros_like(layer.wo).to(device) - - - self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) - self.up_proj = torch.zeros_like(layer.up_proj).to(device) - self.down_proj = torch.zeros_like(layer.down_proj).to(device) - self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) - self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) - self.pre_feedforward_layernorm_weight = 
torch.zeros_like(layer.pre_feedforward_layernorm_weight).to(device) - self.post_feedforward_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) diff --git a/umbrella/models/llama.py b/umbrella/models/llama.py index 3c18b46..5f6478c 100644 --- a/umbrella/models/llama.py +++ b/umbrella/models/llama.py @@ -8,6 +8,9 @@ from .base import LLMBase from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph from tqdm import tqdm +""" +Standard LLmMA +""" class Llama(LLMBase): def __init__(self, model_name: str, @@ -60,7 +63,8 @@ def alloc(self, **kwargs): self.sin_cache = self.sin_cache.to(self.dtype) self.layers :list[LlamaLayer] = [] - + + ## loop hf_model.model.layers, and transfer to format LlamaLayer, and store it to self.layers for idx, hf_layer in enumerate(hf_model.model.layers): layer = LlamaLayer(idx) layer.init_parameters(hf_layer=hf_layer) @@ -120,7 +124,6 @@ def inference(self, position_ids: torch.LongTensor, attention_mask: torch.FloatTensor, storage_ids: torch.LongTensor): - hidden_states = F.embedding(input_ids, self.embed_tokens) for idx in range(self.num_layers): hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) @@ -141,7 +144,9 @@ def clear(self): self.kv_cache.clear() - +""" +Load weights of different layers to different devices. 
+""" class LlamaOffload(Llama): def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): super().__init__(model_name, batch_size, max_length, device, dtype) @@ -218,7 +223,9 @@ def inference(self, logits = F.linear(hidden_states, self.lm_head).float() return logits - +""" +Acitivation-aware Quantization +""" class LlamaAwq(Llama): @@ -320,7 +327,8 @@ def inference(self, hidden_states = hidden_states.reshape(b, s, h) logits = F.linear(hidden_states, self.lm_head).float() return logits - + +"Combination of Offload and AWQ" class LlamaAwqOffload(LlamaOffload): def alloc(self, **kwargs): diff --git a/umbrella/models/mistral.py b/umbrella/models/mistral.py index 1b5dc10..a7837fc 100644 --- a/umbrella/models/mistral.py +++ b/umbrella/models/mistral.py @@ -1,41 +1,56 @@ -from transformers import MistralForCausalLM, MistralConfig, AutoModelForCausalLM +# from _typeshed import NoneType +from transformers import MistralForCausalLM, MistralConfig import torch import torch.nn.functional as F import gc import flashinfer -from ..attn.cache import KV_Cache, StaticKV_Cache -from .mistral_layer import MistralLayer, MistralAwqLayer, MistralPackedLayer +from ..attn.cache import KV_Cache, SlidingWindowKV_Cache +from .mistral_layer import MistralLayer, MistralPackedLayer from .base import LLMBase from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph from tqdm import tqdm +""" +Standard Mistral +Compared to Llama, Support sliding window + GQA +""" class Mistral(LLMBase): def __init__(self, - model_name: str, - batch_size: int=1, - max_length: int=256, - device: str = "cuda:0", - dtype = torch.float16) -> None: + model_name: str, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16, + config=None) -> None: super().__init__() self.batch_size = batch_size self.device = device - self.dtype = dtype - self.config = MistralConfig.from_pretrained(model_name) + self.dtype = dtype + if 
config: + self.config = config + else: + self.config = MistralConfig.from_pretrained(model_name) self.model_name = model_name self.max_length = max_length self.hidden_size = self.config.hidden_size self.num_heads = self.config.num_attention_heads - self.head_dim = getattr(self.config, 'head_dim', self.config.hidden_size // self.config.num_attention_heads) + self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = self.config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads # here > 1 in Mistral self.max_position_embeddings = self.config.max_position_embeddings self.rope_theta = self.config.rope_theta - self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [ + self.config.eos_token_id] def alloc(self, **kwargs): - - self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + # if using sliding winodw, use KV_static_Cache to update past key/value + if self.config.sliding_window is not None and self.config.sliding_window == True: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + batch_size=self.batch_size) hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) if self.config.tie_word_embeddings: @@ -45,9 +60,17 @@ def alloc(self, **kwargs): self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon 
- - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: + rotary_emb = hf_model.model.rotary_emb + self.inv_freq = rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -59,9 +82,10 @@ def alloc(self, **kwargs): self.sin_cache = self.sin_cache * self.attention_scaling self.cos_cache = self.cos_cache.to(self.dtype) self.sin_cache = self.sin_cache.to(self.dtype) - - self.layers :list[MistralLayer] = [] - + + self.layers: list[MistralLayer] = [] + + ## loop hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers for idx, hf_layer in enumerate(hf_model.model.layers): layer = MistralLayer(idx) layer.init_parameters(hf_layer=hf_layer) @@ -69,61 +93,66 @@ def alloc(self, **kwargs): self.layers.append(layer) hf_model.model.layers[idx] = None gc.collect() - + self.num_layers = len(self.layers) @torch.inference_mode() - def layer_compute(self, - buffer: MistralLayer, - layer_idx :int, - hidden_states: torch.FloatTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): + def layer_compute(self, + buffer: MistralLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): residual = hidden_states bsz, q_len, _ = hidden_states.size() - - hidden_states = layer_norm(hidden_states, 
buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight) bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) key_states = F.linear(hidden_states, buffer.wk) value_states = F.linear(hidden_states, buffer.wv) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, + position_ids) hidden_states = self.kv_cache.compute_attention( query_states, key_states, value_states, layer_idx, storage_ids, attention_mask ) - hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) - + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) hidden_states = residual + hidden_states residual = hidden_states - hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight) up = F.linear(hidden_states, buffer.up_proj) gate = F.linear(hidden_states, buffer.gate_proj) gate = F.silu(gate) hidden_states = gate * up hidden_states = F.linear(hidden_states, buffer.down_proj) hidden_states = residual + hidden_states - + return hidden_states - + @torch.inference_mode() def inference(self, - input_ids: torch.LongTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - 
- hidden_states = F.embedding(input_ids, self.embed_tokens) + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + # print('ifnerence', attention_mask) for idx in range(self.num_layers): - hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) - + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, + storage_ids) + b, s, h = hidden_states.shape hidden_states = hidden_states.reshape(b * s, h) @@ -132,25 +161,25 @@ def inference(self, logits = F.linear(hidden_states, self.lm_head).float() return logits - def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): - + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) - + def clear(self): - - self.kv_cache.clear() + self.kv_cache.clear() class MistralOffload(Mistral): - def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + def __init__(self, model_name, batch_size=1, max_length=256, device='cuda:0', dtype=torch.float16): super().__init__(model_name, batch_size, max_length, device, dtype) self.load_stream = torch.cuda.Stream(device=device) - + def alloc(self, **kwargs): - - - self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 - self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + if self.config.sliding_window is not None: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + 
batch_size=self.batch_size) hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) if self.config.tie_word_embeddings: @@ -160,9 +189,18 @@ def alloc(self, **kwargs): self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon - - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + + print('hf_model', hf_model.model) + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: + rotary_emb = hf_model.model.rotary_emb + self.inv_freq = rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -174,40 +212,42 @@ def alloc(self, **kwargs): self.sin_cache = self.sin_cache * self.attention_scaling self.cos_cache = self.cos_cache.to(self.dtype) self.sin_cache = self.sin_cache.to(self.dtype) - - self.layers :list[MistralLayer] = [] - - for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + + self.layers: list[MistralLayer] = [] + + ## loop hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers + for idx, hf_layer in enumerate(hf_model.model.layers): layer = MistralLayer(idx) layer.init_parameters(hf_layer=hf_layer) - if idx < self.num_cache_layers: - layer.to(self.device) + layer.to(self.device) self.layers.append(layer) hf_model.model.layers[idx] = None gc.collect() - + self.num_layers = 
len(self.layers) + assert self.num_layers % 2 == 0 self.buffer = [MistralLayer(-1, self.device) for _ in range(2)] self.buffer[0].alloc_space(self.layers[0], self.device) self.buffer[1].alloc_space(self.layers[0], self.device) - + @torch.inference_mode() def inference(self, - input_ids: torch.LongTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + hidden_states = F.embedding(input_ids, self.embed_tokens) if self.buffer[0].layer_idx != 0: self.buffer[0].copy(self.layers[0]) torch.cuda.synchronize() for idx in range(self.num_layers): with torch.cuda.stream(self.load_stream): - self.buffer[(idx + 1) % 2].copy(self.layers[(idx + 1)% self.num_layers]) - - hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, storage_ids) + self.buffer[(idx + 1) % 2].copy(self.layers[(idx + 1) % self.num_layers]) + + hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, + storage_ids) torch.cuda.synchronize() b, s, h = hidden_states.shape @@ -217,207 +257,30 @@ def inference(self, logits = F.linear(hidden_states, self.lm_head).float() return logits - class MistralAwq(Mistral): - def alloc(self, **kwargs): - - self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) - - hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) - self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) - if self.config.tie_word_embeddings: - self.lm_head = self.embed_tokens - else: - self.lm_head = hf_model.lm_head.weight.detach().to(self.device) - self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) - self.norm_variance_epsilon = 
hf_model.model.norm.variance_epsilon - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling - position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cache = emb.cos()[0] - self.sin_cache = emb.sin()[0] - self.cos_cache = self.cos_cache * self.attention_scaling - self.sin_cache = self.sin_cache * self.attention_scaling - self.cos_cache = self.cos_cache.to(self.dtype) - self.sin_cache = self.sin_cache.to(self.dtype) - self.layers :list[MistralAwqLayer] = [] - - for idx, hf_layer in enumerate(hf_model.model.layers): - layer = MistralAwqLayer(idx) - layer.init_parameters(hf_layer=hf_layer) - layer.to(self.device) - self.layers.append(layer) - hf_model.model.layers[idx] = None - gc.collect() - self.num_layers = len(self.layers) - - - @torch.inference_mode() - def layer_compute(self, - buffer: MistralAwqLayer, - layer_idx :int, - hidden_states: torch.FloatTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - - residual = hidden_states - bsz, q_len, _ = hidden_states.size() - - hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) - bsz, q_len, _ = hidden_states.size() - query_states = buffer.wq.apply(hidden_states) - key_states = buffer.wk.apply(hidden_states) - value_states = buffer.wv.apply(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 
self.head_dim) - - - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) - - hidden_states = self.kv_cache.compute_attention( - query_states, key_states, value_states, layer_idx, storage_ids, attention_mask - ) - - hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) - - hidden_states = buffer.wo.apply(hidden_states) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) - up = buffer.up_proj.apply(hidden_states) - gate = buffer.gate_proj.apply(hidden_states) - gate = F.silu(gate) - hidden_states = gate * up - hidden_states = buffer.down_proj.apply(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - @torch.inference_mode() - def inference(self, - input_ids: torch.LongTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): - - hidden_states = F.embedding(input_ids, self.embed_tokens) - - for idx in range(self.num_layers): - hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) - - b, s, h = hidden_states.shape - hidden_states = hidden_states.reshape(b * s, h) - hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) - hidden_states = hidden_states.reshape(b, s, h) - logits = F.linear(hidden_states, self.lm_head).float() - return logits - - -class MistralAwqOffload(MistralOffload): - def alloc(self, **kwargs): - - self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 - self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) - - hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) - 
self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) - if self.config.tie_word_embeddings: - self.lm_head = self.embed_tokens - else: - self.lm_head = hf_model.lm_head.weight.detach().to(self.device) - self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) - self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling - position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cache = emb.cos()[0] - self.sin_cache = emb.sin()[0] - self.cos_cache = self.cos_cache * self.attention_scaling - self.sin_cache = self.sin_cache * self.attention_scaling - self.cos_cache = self.cos_cache.to(self.dtype) - self.sin_cache = self.sin_cache.to(self.dtype) - self.layers :list[MistralAwqLayer] = [] - - for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): - layer = MistralAwqLayer(idx) - layer.init_parameters(hf_layer=hf_layer) - if idx < self.num_cache_layers: - layer.to(self.device) - self.layers.append(layer) - hf_model.model.layers[idx] = None - gc.collect() - self.num_layers = len(self.layers) - assert self.num_layers % 2 == 0 - self.buffer = [MistralAwqLayer(-1, self.device) for _ in range(2)] - self.buffer[0].alloc_space(self.layers[0], self.device) - self.buffer[1].alloc_space(self.layers[0], self.device) - - @torch.inference_mode() - def layer_compute(self, - buffer: MistralAwqLayer, - layer_idx :int, - hidden_states: torch.FloatTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: 
torch.LongTensor): - - residual = hidden_states - bsz, q_len, _ = hidden_states.size() - - hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) - bsz, q_len, _ = hidden_states.size() - query_states = buffer.wq.apply(hidden_states) - key_states = buffer.wk.apply(hidden_states) - value_states = buffer.wv.apply(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) - hidden_states = self.kv_cache.compute_attention( - query_states, key_states, value_states, layer_idx, storage_ids, attention_mask - ) - hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) - - hidden_states = buffer.wo.apply(hidden_states) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) - up = buffer.up_proj.apply(hidden_states) - gate = buffer.gate_proj.apply(hidden_states) - gate = F.silu(gate) - hidden_states = gate * up - hidden_states = buffer.down_proj.apply(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states + pass +class MistralAwqOffload(Mistral): + pass class MistralCudagraph(Mistral): - def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + def __init__(self, model_name, batch_size=1, max_length=256, device='cuda:0', dtype=torch.float16): super().__init__(model_name, batch_size, max_length, device, dtype) - + self.callables = {} self.mempool = None + def alloc(self, **kwargs): - exit_layer = kwargs.pop("exit_layer", -1) - self.kv_cache = 
StaticKV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + # if using sliding winodw, use KV_static_Cache to update past key/value + if self.config.sliding_window is not None: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + batch_size=self.batch_size) hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) if self.config.tie_word_embeddings: @@ -427,9 +290,17 @@ def alloc(self, **kwargs): self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon - - self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) - self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: + rotary_emb = hf_model.model.rotary_emb + self.inv_freq = rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -441,9 +312,10 @@ def alloc(self, **kwargs): self.sin_cache = self.sin_cache * self.attention_scaling self.cos_cache = self.cos_cache.to(self.dtype) self.sin_cache = self.sin_cache.to(self.dtype) - - self.layers :list[MistralPackedLayer] = [] - + + self.layers: list[MistralLayer] = [] + + ## loop 
hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers for idx, hf_layer in enumerate(hf_model.model.layers): if exit_layer > 0 and idx >= exit_layer: break @@ -453,57 +325,58 @@ def alloc(self, **kwargs): self.layers.append(layer) hf_model.model.layers[idx] = None gc.collect() - + self.num_layers = len(self.layers) - + @torch.inference_mode() - def layer_compute(self, - buffer: MistralPackedLayer, - layer_idx :int, - hidden_states: torch.FloatTensor, - position_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - storage_ids: torch.LongTensor): + def layer_compute(self, + buffer: MistralPackedLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): residual = hidden_states bsz, q_len, _ = hidden_states.size() - - hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight) bsz, q_len, _ = hidden_states.size() qkv = F.linear(hidden_states, buffer.wqkv) - query_states = qkv[...,:self.hidden_size] - key_states = qkv[...,self.hidden_size:self.hidden_size + self.head_dim * self.num_key_value_heads] - value_states = qkv[...,self.hidden_size + self.head_dim * self.num_key_value_heads:] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) - - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids, unsqueeze_dim=1) + query_states = qkv[..., :self.hidden_size] + key_states = qkv[..., self.hidden_size:self.hidden_size + self.head_dim * 
self.num_key_value_heads] + value_states = qkv[..., self.hidden_size + self.head_dim * self.num_key_value_heads:] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, + position_ids, unsqueeze_dim=1) hidden_states = self.kv_cache.compute_attention( query_states, key_states, value_states, layer_idx, storage_ids, attention_mask ) - - hidden_states = hidden_states.reshape(bsz, q_len, self.head_dim * self.num_heads) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) hidden_states = F.linear(hidden_states, buffer.wo) hidden_states = residual + hidden_states residual = hidden_states - hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight) up = F.linear(hidden_states, buffer.up_proj) gate = F.linear(hidden_states, buffer.gate_proj) gate = F.silu(gate) hidden_states = gate * up hidden_states = F.linear(hidden_states, buffer.down_proj) hidden_states = residual + hidden_states - + return hidden_states - - + @torch.inference_mode() - def initialize_cuda_graph(self, - decoding_seqlens :list[int], - n_warmups=12): + def initialize_cuda_graph(self, + decoding_seqlens: list[int], + n_warmups=12): gc.collect() self.mempool = torch.cuda.graphs.graph_pool_handle() for decoding_seqlen in decoding_seqlens: @@ -515,17 +388,17 @@ def initialize_cuda_graph(self, n_warmups=n_warmups ) self.clear() - + @torch.inference_mode() def graph_inference(self, - input_ids: torch.LongTensor, - storage_ids 
:torch.LongTensor, - position_ids = None, - attention_mask = None, - ): - dec_length = input_ids.shape[1] - if dec_length in self.callables.keys(): - logits = self.callables[dec_length](input_ids, storage_ids, position_ids, attention_mask) - else: - logits = self.inference(input_ids, position_ids, attention_mask, storage_ids) - return logits \ No newline at end of file + input_ids: torch.LongTensor, + storage_ids: torch.LongTensor, + position_ids=None, + attention_mask=None, + ): + dec_length = input_ids.shape[1] + if dec_length in self.callables.keys(): + logits = self.callables[dec_length](input_ids, storage_ids, position_ids, attention_mask) + else: + logits = self.inference(input_ids, position_ids, attention_mask, storage_ids) + return logits \ No newline at end of file diff --git a/umbrella/models/mistral_layer.py b/umbrella/models/mistral_layer.py index be05520..9a65db6 100644 --- a/umbrella/models/mistral_layer.py +++ b/umbrella/models/mistral_layer.py @@ -3,33 +3,32 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer from ..quantization.awq_utils import AwqLinear + class MistralLayer: - def __init__(self, layer_idx, device = "cpu") -> None: - - self.wq :torch.Tensor = None - self.wk :torch.Tensor = None - self.wv :torch.Tensor = None - self.wo :torch.Tensor = None + def __init__(self, layer_idx, device="cpu") -> None: + self.wq: torch.Tensor = None + self.wk: torch.Tensor = None + self.wv: torch.Tensor = None + self.wo: torch.Tensor = None - self.gate_proj :torch.Tensor = None - self.up_proj :torch.Tensor = None - self.down_proj :torch.Tensor = None + self.gate_proj: torch.Tensor = None + self.up_proj: torch.Tensor = None + self.down_proj: torch.Tensor = None - self.input_layernorm_weight :torch.Tensor = None - self.input_layernorm_variance_epsilon :float = 0.0 + self.input_layernorm_weight: torch.Tensor = None + self.input_layernorm_variance_epsilon: float = 0.0 - self.post_attention_layernorm_weight :torch.Tensor = None - 
self.post_attention_layernorm_variance_epsilon :float = 0.0 + self.post_attention_layernorm_weight: torch.Tensor = None + self.post_attention_layernorm_variance_epsilon: float = 0.0 self.layer_idx = layer_idx self.device = device def init_parameters(self, hf_layer: MistralDecoderLayer): - - self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() - self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() - self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() - self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + self.wq: torch.Tensor = hf_layer.self_attn.q_proj.weight.detach() + self.wk: torch.Tensor = hf_layer.self_attn.k_proj.weight.detach() + self.wv: torch.Tensor = hf_layer.self_attn.v_proj.weight.detach() + self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach() self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() self.up_proj = hf_layer.mlp.up_proj.weight.detach() @@ -41,21 +40,20 @@ def init_parameters(self, hf_layer: MistralDecoderLayer): self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon - def to(self, device:str = 'cuda:0', non_blocking = True): - + def to(self, device: str = 'cuda:0', non_blocking=True): self.device = device self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) - self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, + non_blocking=non_blocking) self.wq = self.wq.to(device, non_blocking=non_blocking) self.wk = self.wk.to(device, non_blocking=non_blocking) self.wv = self.wv.to(device, non_blocking=non_blocking) self.wo = self.wo.to(device, non_blocking=non_blocking) self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) self.up_proj = 
self.up_proj.to(device, non_blocking=non_blocking) - self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) def copy(self, layer: MistralLayer): - self.wq.copy_(layer.wq, non_blocking=True) self.wk.copy_(layer.wk, non_blocking=True) self.wv.copy_(layer.wv, non_blocking=True) @@ -63,22 +61,20 @@ def copy(self, layer: MistralLayer): self.gate_proj.copy_(layer.gate_proj, non_blocking=True) self.up_proj.copy_(layer.up_proj, non_blocking=True) self.down_proj.copy_(layer.down_proj, non_blocking=True) - + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) - self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon self.layer_idx = layer.layer_idx - - def alloc_space(self, layer: MistralLayer, device): + def alloc_space(self, layer: MistralLayer, device): self.device = device self.wq = torch.zeros_like(layer.wq).to(device) self.wk = torch.zeros_like(layer.wk).to(device) self.wv = torch.zeros_like(layer.wv).to(device) self.wo = torch.zeros_like(layer.wo).to(device) - self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) self.up_proj = torch.zeros_like(layer.up_proj).to(device) self.down_proj = torch.zeros_like(layer.down_proj).to(device) @@ -87,34 +83,32 @@ def alloc_space(self, layer: MistralLayer, device): class MistralPackedLayer: - def __init__(self, layer_idx, device = "cpu") -> None: - - self.wqkv :torch.Tensor = None + def __init__(self, layer_idx, device="cpu") -> None: + self.wqkv: torch.Tensor = None - self.gate_proj :torch.Tensor = None - self.up_proj :torch.Tensor = None - self.down_proj :torch.Tensor = None + self.gate_proj: torch.Tensor = 
None + self.up_proj: torch.Tensor = None + self.down_proj: torch.Tensor = None - self.input_layernorm_weight :torch.Tensor = None - self.input_layernorm_variance_epsilon :float = 0.0 + self.input_layernorm_weight: torch.Tensor = None + self.input_layernorm_variance_epsilon: float = 0.0 - self.post_attention_layernorm_weight :torch.Tensor = None - self.post_attention_layernorm_variance_epsilon :float = 0.0 + self.post_attention_layernorm_weight: torch.Tensor = None + self.post_attention_layernorm_variance_epsilon: float = 0.0 self.layer_idx = layer_idx self.device = device def init_parameters(self, hf_layer: MistralDecoderLayer): - - self.wqkv :torch.Tensor= torch.cat( - [ - hf_layer.self_attn.q_proj.weight.detach(), - hf_layer.self_attn.k_proj.weight.detach(), - hf_layer.self_attn.v_proj.weight.detach(), - ], - dim=0 + self.wqkv: torch.Tensor = torch.cat( + [ + hf_layer.self_attn.q_proj.weight.detach(), + hf_layer.self_attn.k_proj.weight.detach(), + hf_layer.self_attn.v_proj.weight.detach(), + ], + dim=0 ) - self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach() self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() self.up_proj = hf_layer.mlp.up_proj.weight.detach() self.down_proj = hf_layer.mlp.down_proj.weight.detach() @@ -124,135 +118,38 @@ def init_parameters(self, hf_layer: MistralDecoderLayer): self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon - - def to(self, device:str = 'cuda:0', non_blocking = True): + def to(self, device: str = 'cuda:0', non_blocking=True): self.device = device self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) - self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = 
self.post_attention_layernorm_weight.to(device, + non_blocking=non_blocking) self.wqkv = self.wqkv.to(device, non_blocking=non_blocking) self.wo = self.wo.to(device, non_blocking=non_blocking) self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) - self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) - - def copy(self, layer: MistralPackedLayer): + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + def copy(self, layer: LlamaPackedLayer): self.wqkv.copy_(layer.wqkv, non_blocking=True) self.wo.copy_(layer.wo, non_blocking=True) self.gate_proj.copy_(layer.gate_proj, non_blocking=True) self.up_proj.copy_(layer.up_proj, non_blocking=True) self.down_proj.copy_(layer.down_proj, non_blocking=True) - + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) - self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon self.layer_idx = layer.layer_idx - - def alloc_space(self, layer: MistralPackedLayer, device): + def alloc_space(self, layer: MistralPackedLayer, device): self.device = device self.wqkv = torch.zeros_like(layer.wqkv).to(device) self.wo = torch.zeros_like(layer.wo).to(device) - self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) self.up_proj = torch.zeros_like(layer.up_proj).to(device) self.down_proj = torch.zeros_like(layer.down_proj).to(device) self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) - - -class MistralAwqLayer(): - def __init__(self, layer_idx, device="cpu"): - 
- self.wq = AwqLinear() - self.wk = AwqLinear() - self.wv = AwqLinear() - self.wo = AwqLinear() - - self.gate_proj = AwqLinear() - self.up_proj = AwqLinear() - self.down_proj = AwqLinear() - - - self.input_layernorm_weight :torch.Tensor = None - self.input_layernorm_variance_epsilon :float = 0.0 - - self.post_attention_layernorm_weight :torch.Tensor = None - self.post_attention_layernorm_variance_epsilon :float = 0.0 - - self.layer_idx = layer_idx - self.device = device - - def init_parameters(self, hf_layer): - - self.wq.init_parameters(hf_layer.self_attn.q_proj) - self.wk.init_parameters(hf_layer.self_attn.k_proj) - self.wv.init_parameters(hf_layer.self_attn.v_proj) - self.wo.init_parameters(hf_layer.self_attn.o_proj) - self.gate_proj.init_parameters(hf_layer.mlp.gate_proj) - self.up_proj.init_parameters(hf_layer.mlp.up_proj) - self.down_proj.init_parameters(hf_layer.mlp.down_proj) - - self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach().pin_memory() - self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon - - self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach().pin_memory() - self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon - - def to(self, device:str = 'cuda:0', non_blocking = True): - - self.device = device - self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) - self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) - - self.wq.to(device=device) - self.wk.to(device=device) - self.wv.to(device=device) - self.wo.to(device=device) - - self.gate_proj.to(device=device) - self.up_proj.to(device=device) - self.down_proj.to(device=device) - - def alloc_space(self, layer: MistralAwqLayer, device): - - self.device = device - self.wq.empty_like(layer.wq) - self.wk.empty_like(layer.wk) - self.wv.empty_like(layer.wv) - 
self.wo.empty_like(layer.wo) - - self.gate_proj.empty_like(layer.gate_proj) - self.up_proj.empty_like(layer.up_proj) - self.down_proj.empty_like(layer.down_proj) - - self.wq.to(device=device) - self.wk.to(device=device) - self.wv.to(device=device) - self.wo.to(device=device) - - self.gate_proj.to(device=device) - self.up_proj.to(device=device) - self.down_proj.to(device=device) - - self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) - self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) - - def copy(self, layer: MistralAwqLayer): - - self.wq.copy(layer.wq, non_blocking=True) - self.wk.copy(layer.wk, non_blocking=True) - self.wv.copy(layer.wv, non_blocking=True) - self.wo.copy(layer.wo, non_blocking=True) - self.gate_proj.copy(layer.gate_proj, non_blocking=True) - self.up_proj.copy(layer.up_proj, non_blocking=True) - self.down_proj.copy(layer.down_proj, non_blocking=True) - - self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) - self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) - self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon - self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon - self.layer_idx = layer.layer_idx \ No newline at end of file diff --git a/umbrella/models/model_utils.py b/umbrella/models/model_utils.py index 6c68e6e..50695f1 100644 --- a/umbrella/models/model_utils.py +++ b/umbrella/models/model_utils.py @@ -63,17 +63,6 @@ def layer_norm( hidden_states = hidden_states.reshape(b, s, h) return hidden_states -def layer_norm_gemma( - hidden_states: torch.Tensor, - layernorm_variance_epsilon: float, - layernorm_weight: torch.Tensor, -): - b, s, h = hidden_states.shape - - hidden_states = hidden_states.reshape(b * s, h) - hidden_states = flashinfer.gemma_rmsnorm(hidden_states, layernorm_weight, 
layernorm_variance_epsilon) - hidden_states = hidden_states.reshape(b, s, h) - return hidden_states def capture_graph( llm, decoding_seqlen :int =1, mempool=None, n_warmups :int=3 @@ -113,4 +102,4 @@ def run(input_ids, storage_ids, position_ids, attention_mask): graph.replay() return static_logits.clone() - return run + return run \ No newline at end of file diff --git a/umbrella/speculation/static_speculation_engine.py b/umbrella/speculation/static_speculation_engine.py index 1693d24..7f03aa7 100644 --- a/umbrella/speculation/static_speculation_engine.py +++ b/umbrella/speculation/static_speculation_engine.py @@ -86,17 +86,17 @@ def initialize(self): graph_capture_list.append(1) self.draft_model = AutoModelLM.from_pretrained( - model_name=self.draft_model_name, offload=False, cuda_graph=True, batch_size=1, + model_name=self.draft_model_name, offload=False, cuda_graph=True, batch_size=1, max_length=self.max_length, device=self.device, dtype=self.dtype) self.draft_model.alloc(**self.config) self.target_model = AutoModelLM.from_pretrained( - model_name=self.target_model_name, offload=False, batch_size=1, + model_name=self.target_model_name, offload=False, batch_size=1, max_length=self.max_length, device=self.device, dtype=self.dtype) - + self.target_model.alloc(**self.config) self.draft_model.initialize_cuda_graph(graph_capture_list) diff --git a/umbrella/templates.py b/umbrella/templates.py index ac30a80..b163a45 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -15,15 +15,8 @@ {}<|im_end|> <|im_start|>assistant """, - -'gemma2-it': """user -{} -model -""", - -'gemma2': "{}", -'mistral': "[INST] {} [/INST]" - + 'mistral-24b': """[INST] {} [/INST]""", + 'mistral-7b': """[INST] {} [/INST]""" } SysPrompts = { @@ -36,12 +29,10 @@ 'qwen': """<|im_start|>system You are a helpful assistant.<|im_end|> """, - 'gemma2': "", - 'gemma2-it': "", - 'mistral': "", - + 'mistral-24b': """[INST] You are a knowledgeable, efficient, and direct AI assistant. 
Provide concise answers and focus on key information. [/INST]""", + 'mistral-7b': """[INST] You are a helpful AI assistant. Provide concise and informative responses. [/INST]""" } ExtraPrompts = { 'llama3-code': """\nAlways try to wrap what you write in a function.""" -} +} \ No newline at end of file From aae6d8372d4df9dd7108ebb61b204d0731713a1d Mon Sep 17 00:00:00 2001 From: yulinw2 Date: Fri, 21 Feb 2025 05:15:38 -0500 Subject: [PATCH 19/23] debug --- .gitignore | 3 +- .idea/.gitignore | 8 -- .idea/UMbreLLa.iml | 12 --- .idea/inspectionProfiles/Project_Default.xml | 38 --------- .../inspectionProfiles/profiles_settings.xml | 6 -- .idea/misc.xml | 4 - .idea/modules.xml | 8 -- .idea/vcs.xml | 6 -- draft/config.json | 25 ++++++ draft/train_draft.py | 83 +++++++++++++++++++ 10 files changed, 110 insertions(+), 83 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/UMbreLLa.iml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml create mode 100644 draft/config.json create mode 100644 draft/train_draft.py diff --git a/.gitignore b/.gitignore index 54a5a08..c72b144 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,5 @@ cython_debug/ .pypirc .vscode/ app/.gradio/ -test* \ No newline at end of file +test* +.idea/ \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/UMbreLLa.iml b/.idea/UMbreLLa.iml deleted file mode 100644 index 8a05c6e..0000000 --- a/.idea/UMbreLLa.iml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - \ No newline at end of 
file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index f3032d8..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,38 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index c95407c..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 3892f88..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/draft/config.json b/draft/config.json new file mode 100644 index 0000000..55a9136 --- /dev/null +++ b/draft/config.json @@ -0,0 +1,25 @@ +{ + "architectures": [ + "MistralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 16, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-05, + "rope_theta": 100000000.0, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "use_cache": true, + "vocab_size": 131072 + } \ No newline at end of file diff --git a/draft/train_draft.py b/draft/train_draft.py new file mode 100644 index 0000000..525e767 --- /dev/null +++ 
b/draft/train_draft.py @@ -0,0 +1,83 @@ +import torch +from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig +from datasets import load_dataset +import json +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--tokenizer', type=str, default="mistralai/Mistral-Small-24B-Instruct-2501",help='tokenizer') +parser.add_argument('--config', type=str, default="./config.json",help='model config') +parser.add_argument('--output_dir', type=str, default="mistral",help='output directory') +parser.add_argument('--bsz', type=int, default=4, help='generation length') +args = parser.parse_args() + +config = AutoConfig.from_pretrained(args.config) +model_name = args.tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +model = AutoModelForCausalLM.from_config(config) +total_params = sum(p.numel() for p in model.parameters()) +print(f"total_params: {total_params:,}") + +train_data_files= [ + "train/chunk1/example_train_10*.jsonl.zst", + "train/chunk1/example_train_11*.jsonl.zst", + "train/chunk1/example_train_12*.jsonl.zst", + "train/chunk1/example_train_13*.jsonl.zst", + "train/chunk1/example_train_14*.jsonl.zst" +] +train_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=train_data_files, split="train") + +eval_data_files= ["validation/chunk1/example_holdout_*.jsonl.zst"] +eval_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=eval_data_files, split="train") + +# 定义预处理函数:对句子对进行编码 +def preprocess_function(examples): + + + output = tokenizer( + examples["text"], + truncation=True, + max_length=1024, + padding="max_length" + ) + + return output + +train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=True, num_proc=8) +eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True, num_proc=8) + +data_collator = 
DataCollatorForLanguageModeling(tokenizer, mlm=False) + +training_args = TrainingArguments( + output_dir=args.output_dir, + learning_rate=1e-4, + per_device_train_batch_size=args.bsz, + per_device_eval_batch_size=args.bsz, + weight_decay=0.01, + lr_scheduler_type="cosine", + load_best_model_at_end=True, + logging_dir='./logs', + logging_steps=10, + bf16=True, + save_only_model=True, + save_steps=5000, + save_total_limit=2, + eval_strategy="steps", + save_strategy="steps", + eval_steps=5000 +) + +# 初始化 Trainer +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_tokenized_datasets, + eval_dataset=eval_tokenized_datasets, + tokenizer=tokenizer, + data_collator=data_collator +) + +# 开始训练 +trainer.train() From 88322f5ae6628e02a800d41df5aed47be5d7c0d5 Mon Sep 17 00:00:00 2001 From: yulinw2 Date: Fri, 21 Feb 2025 05:16:07 -0500 Subject: [PATCH 20/23] debug --- draft/config.json | 25 ------------- draft/train_draft.py | 83 -------------------------------------------- 2 files changed, 108 deletions(-) delete mode 100644 draft/config.json delete mode 100644 draft/train_draft.py diff --git a/draft/config.json b/draft/config.json deleted file mode 100644 index 55a9136..0000000 --- a/draft/config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "architectures": [ - "MistralForCausalLM" - ], - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "head_dim": 64, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 6144, - "max_position_embeddings": 32768, - "model_type": "mistral", - "num_attention_heads": 16, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "rms_norm_eps": 1e-05, - "rope_theta": 100000000.0, - "sliding_window": null, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "use_cache": true, - "vocab_size": 131072 - } \ No newline at end of file diff --git a/draft/train_draft.py b/draft/train_draft.py deleted file mode 100644 index 525e767..0000000 --- 
a/draft/train_draft.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig -from datasets import load_dataset -import json -import argparse -parser = argparse.ArgumentParser() -parser.add_argument('--tokenizer', type=str, default="mistralai/Mistral-Small-24B-Instruct-2501",help='tokenizer') -parser.add_argument('--config', type=str, default="./config.json",help='model config') -parser.add_argument('--output_dir', type=str, default="mistral",help='output directory') -parser.add_argument('--bsz', type=int, default=4, help='generation length') -args = parser.parse_args() - -config = AutoConfig.from_pretrained(args.config) -model_name = args.tokenizer -tokenizer = AutoTokenizer.from_pretrained(model_name) -if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - -model = AutoModelForCausalLM.from_config(config) -total_params = sum(p.numel() for p in model.parameters()) -print(f"total_params: {total_params:,}") - -train_data_files= [ - "train/chunk1/example_train_10*.jsonl.zst", - "train/chunk1/example_train_11*.jsonl.zst", - "train/chunk1/example_train_12*.jsonl.zst", - "train/chunk1/example_train_13*.jsonl.zst", - "train/chunk1/example_train_14*.jsonl.zst" -] -train_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=train_data_files, split="train") - -eval_data_files= ["validation/chunk1/example_holdout_*.jsonl.zst"] -eval_raw_datasets = load_dataset("cerebras/SlimPajama-627B", data_files=eval_data_files, split="train") - -# 定义预处理函数:对句子对进行编码 -def preprocess_function(examples): - - - output = tokenizer( - examples["text"], - truncation=True, - max_length=1024, - padding="max_length" - ) - - return output - -train_tokenized_datasets = train_raw_datasets.map(preprocess_function, batched=True, num_proc=8) -eval_tokenized_datasets = eval_raw_datasets.map(preprocess_function, batched=True, num_proc=8) - 
-data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) - -training_args = TrainingArguments( - output_dir=args.output_dir, - learning_rate=1e-4, - per_device_train_batch_size=args.bsz, - per_device_eval_batch_size=args.bsz, - weight_decay=0.01, - lr_scheduler_type="cosine", - load_best_model_at_end=True, - logging_dir='./logs', - logging_steps=10, - bf16=True, - save_only_model=True, - save_steps=5000, - save_total_limit=2, - eval_strategy="steps", - save_strategy="steps", - eval_steps=5000 -) - -# 初始化 Trainer -trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_tokenized_datasets, - eval_dataset=eval_tokenized_datasets, - tokenizer=tokenizer, - data_collator=data_collator -) - -# 开始训练 -trainer.train() From 25aa294abde581dc790e3a39c07e765fb18b9f1c Mon Sep 17 00:00:00 2001 From: yulinw2 Date: Fri, 21 Feb 2025 05:20:15 -0500 Subject: [PATCH 21/23] debug --- examples/generate.py | 2 -- umbrella/models/llama.py | 18 +++++------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/examples/generate.py b/examples/generate.py index 6e6d86d..7895574 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -20,7 +20,6 @@ parser.add_argument('--G', type=int, default=512, help='generation length') parser.add_argument('--offload', action='store_true', help="offload the model") parser.add_argument('--cuda_graph', action='store_true', help="whether use cuda graph") -parser.add_argument('--d', type=int, default=0, help="whether use debug mode") args = parser.parse_args() DEVICE = "cuda:0" DTYPE = torch.float16 @@ -53,7 +52,6 @@ device=DEVICE, config=config ) -print('llm', llm.config) eos_tokens = llm.config.eos_token_id if not isinstance(eos_tokens, list): diff --git a/umbrella/models/llama.py b/umbrella/models/llama.py index 5f6478c..3c18b46 100644 --- a/umbrella/models/llama.py +++ b/umbrella/models/llama.py @@ -8,9 +8,6 @@ from .base import LLMBase from .model_utils import apply_rotary_pos_emb, layer_norm, 
capture_graph from tqdm import tqdm -""" -Standard LLmMA -""" class Llama(LLMBase): def __init__(self, model_name: str, @@ -63,8 +60,7 @@ def alloc(self, **kwargs): self.sin_cache = self.sin_cache.to(self.dtype) self.layers :list[LlamaLayer] = [] - - ## loop hf_model.model.layers, and transfer to format LlamaLayer, and store it to self.layers + for idx, hf_layer in enumerate(hf_model.model.layers): layer = LlamaLayer(idx) layer.init_parameters(hf_layer=hf_layer) @@ -124,6 +120,7 @@ def inference(self, position_ids: torch.LongTensor, attention_mask: torch.FloatTensor, storage_ids: torch.LongTensor): + hidden_states = F.embedding(input_ids, self.embed_tokens) for idx in range(self.num_layers): hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) @@ -144,9 +141,7 @@ def clear(self): self.kv_cache.clear() -""" -Load weights of different layers to different devices. -""" + class LlamaOffload(Llama): def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): super().__init__(model_name, batch_size, max_length, device, dtype) @@ -223,9 +218,7 @@ def inference(self, logits = F.linear(hidden_states, self.lm_head).float() return logits -""" -Acitivation-aware Quantization -""" + class LlamaAwq(Llama): @@ -327,8 +320,7 @@ def inference(self, hidden_states = hidden_states.reshape(b, s, h) logits = F.linear(hidden_states, self.lm_head).float() return logits - -"Combination of Offload and AWQ" + class LlamaAwqOffload(LlamaOffload): def alloc(self, **kwargs): From 5fe65f3770fa463ba655835fea60e43ebc7dcd8b Mon Sep 17 00:00:00 2001 From: yulinw2 Date: Fri, 21 Feb 2025 05:20:38 -0500 Subject: [PATCH 22/23] delete unncessary files --- .DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 
b5d9f8165cd02056a6c76be9adf60eb4886b9264..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKF;2ul47A~jNEBR2xnJN1(dj6t`2ZK73pxZki839}h8f$SivpAsXdKy-^?I_& z6!ET^`Q~tcX|^`A4oxkh=N4`~EUpOTu-8@#tb82O; z6N<;`$hRmr?-LcJfE1W3aGJw4`~MaEhx7lOq>&Vm0{=<@Uu|!;8+=mr*2TxM*EaYq o{L@ey Date: Fri, 21 Feb 2025 05:21:49 -0500 Subject: [PATCH 23/23] delete unncessary files --- umbrella/speculation/static_speculation_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/umbrella/speculation/static_speculation_engine.py b/umbrella/speculation/static_speculation_engine.py index 7f03aa7..1693d24 100644 --- a/umbrella/speculation/static_speculation_engine.py +++ b/umbrella/speculation/static_speculation_engine.py @@ -86,17 +86,17 @@ def initialize(self): graph_capture_list.append(1) self.draft_model = AutoModelLM.from_pretrained( - model_name=self.draft_model_name, offload=False, cuda_graph=True, batch_size=1, + model_name=self.draft_model_name, offload=False, cuda_graph=True, batch_size=1, max_length=self.max_length, device=self.device, dtype=self.dtype) self.draft_model.alloc(**self.config) self.target_model = AutoModelLM.from_pretrained( - model_name=self.target_model_name, offload=False, batch_size=1, + model_name=self.target_model_name, offload=False, batch_size=1, max_length=self.max_length, device=self.device, dtype=self.dtype) - + self.target_model.alloc(**self.config) self.draft_model.initialize_cuda_graph(graph_capture_list)