diff --git a/.gitignore b/.gitignore index 54a5a08..c72b144 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,5 @@ cython_debug/ .pypirc .vscode/ app/.gradio/ -test* \ No newline at end of file +test* +.idea/ \ No newline at end of file diff --git a/examples/data/question.jsonl b/examples/data/question.jsonl index 464e2c2..1b63e8a 100644 --- a/examples/data/question.jsonl +++ b/examples/data/question.jsonl @@ -39,7 +39,7 @@ {"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]} {"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]} {"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n for 
file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]} -{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. 
Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} +{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the initial cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} {"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]} {"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. 
Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]} {"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. 
It's still the root of the tree."]} diff --git a/examples/generate.py b/examples/generate.py index 065bf31..7895574 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,18 +1,22 @@ import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" from umbrella.models.auto_model import AutoModelLM from umbrella.logging_config import setup_logger from umbrella.utils import TextColors + logger = setup_logger() import torch from umbrella.templates import Prompts, SysPrompts -from transformers import AutoTokenizer -from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, find_first_element_position +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM +from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, \ + find_first_element_position import argparse import time + parser = argparse.ArgumentParser() -parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct",help='model') -parser.add_argument('--template', type=str, default="meta-llama3",help='prompt template') +parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct", help='model') +parser.add_argument('--template', type=str, default="meta-llama3", help='prompt template') parser.add_argument('--G', type=int, default=512, help='generation length') parser.add_argument('--offload', action='store_true', help="offload the model") parser.add_argument('--cuda_graph', action='store_true', help="whether use cuda graph") @@ -32,6 +36,12 @@ tokenizer = AutoTokenizer.from_pretrained(args.model) tokens = tokenizer.encode(text=text, return_tensors="pt").to(DEVICE) +config = AutoConfig.from_pretrained(args.model) + +# testing mistral sliding window +# config.sliding_window = True +# config.window_size = 100 + llm = AutoModelLM.from_pretrained( model_name=args.model, offload=args.offload, @@ -39,12 +49,15 @@ batch_size=1, max_length=MAX_LEN, dtype=DTYPE, - device=DEVICE + 
device=DEVICE, + config=config ) + eos_tokens = llm.config.eos_token_id if not isinstance(eos_tokens, list): eos_tokens = [eos_tokens] llm.alloc() + if args.cuda_graph: llm.initialize_cuda_graph([1]) attention_mask = make_causal_mask((MAX_LEN, MAX_LEN), DEVICE) @@ -52,43 +65,48 @@ position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0) prefix_len = tokens.shape[1] -logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:,:prefix_len], - storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0] +logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:, :prefix_len], + storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0] torch.cuda.synchronize() t1 = time.time() + generated_tokens = [] pos = 0 +print('----start generating answers.----') for i in range(GEN_LEN): next_token = logits[-1:].argmax(dim=-1, keepdim=True) generated_tokens.append(next_token.item()) - + generated_text = ( - tokenizer.decode( - generated_tokens, - skip_special_tokens=True, - clean_up_tokenization_spaces=True, - spaces_between_special_tokens=False, - ) - .strip() - .split(" ") - ) - - + tokenizer.decode( + generated_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + spaces_between_special_tokens=False, + ) + .strip() + .split(" ") + ) + now = len(generated_text) - 1 if now > pos: - print(" ".join(generated_text[pos:now]), end=" ", flush=True) - pos = now - - if (is_sentence_complete_regex(generated_text[-1]) and (i >= GEN_LEN - 32)) or (find_first_element_position(next_token, eos_tokens) >= 0): - break - - logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:,prefix_len+i:prefix_len+i+1], - storage_ids=storage_ids[prefix_len+i : prefix_len+i+1], attention_mask=attention_mask[prefix_len+i:prefix_len+i+1])[0] + print(" ".join(generated_text[pos:now]), end=" ", flush=True) + pos = now + + if (is_sentence_complete_regex(generated_text[-1]) and (i >= 
GEN_LEN - 32)) or ( + find_first_element_position(next_token, eos_tokens) >= 0): + break + + logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:, prefix_len + i:prefix_len + i + 1], + storage_ids=storage_ids[prefix_len + i: prefix_len + i + 1], + attention_mask=attention_mask[prefix_len + i:prefix_len + i + 1])[0] print(" ".join(generated_text[pos:]), flush=True) +print('----end generating answers.----') torch.cuda.synchronize() t2 = time.time() dec_len = len(generated_tokens) -logger.info(TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2-t1)/dec_len), "magenta")) \ No newline at end of file +logger.info( + TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2 - t1) / dec_len), "magenta")) \ No newline at end of file diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 533206d..a21e83b 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -2,14 +2,66 @@ import torch import flashinfer import math -class KV_Cache: - def __init__(self, - config :AutoConfig, - batch_size :int = 1, - max_length :int = 256, - device :str = 'cuda:0', - dtype = torch.float16) -> None: + +def mha_flash(q, k, v, kv_layout, attn_mask): + return flashinfer.single_prefill_with_kv_cache( + q=q, + k=k, + v=v, + kv_layout=kv_layout, + custom_mask=attn_mask, + allow_fp16_qk_reduction=True + ) + +def mha(q, k, v, attn_mask): + """ + Args: + q (torch.Tensor): Query tensor of shape (query_len, q_head, head_dim) + k (torch.Tensor): Key tensor of shape (kv_len, kv_head, head_dim) + v (torch.Tensor): Value tensor of shape (kv_len, kv_head, head_dim) + attn_mask (torch.Tensor): Value tensor of shape (q_len, kv_len) + + Returns: + + """ + q_len, q_head, head_dim = q.shape + kv_len, kv_head, _ = k.shape + assert (q_head % kv_head == 0) + num_kv_groups = q_head // kv_head + + # Step 1: Reshape Q for GQA + # (q_len, q_head, head_dim) -> (kv_head, num_kv_groups, q_len, head_dim) + q = 
q.transpose(0,1).reshape(kv_head, num_kv_groups, q_len, head_dim) + k = k.transpose(0,1) + v = v.transpose(0,1) + # Step 2: Compute Attention Scores (kv_head, num_kv_groups, q_len, kv_len) + attn_scores = torch.einsum('hgld,hmd->hglm', q, k) / math.sqrt(head_dim) + + # Step 3: Apply Attention Mask + # (kv_head, num_groups, q_len, kv_len) + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(kv_head, num_kv_groups, -1, -1) + attn_scores.masked_fill_(~attn_mask, torch.finfo(attn_scores.dtype).min) + + # Step 4: Compute Attention Weights + # (kv_head, num_kv_groups, q_len, kv_len) + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) + + # Step 5: Compute Context Vector + # (kv_head, num_kv_groups, q_len, head_dim) + hidden_states = torch.einsum('hglm,hmd->hgld', attn_weights, v) + hidden_states = hidden_states.reshape(q_head, q_len, head_dim) + hidden_states = hidden_states.transpose(0, 1).contiguous() + return hidden_states + + +class KV_Cache: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: self.config = config self.max_length = max_length self.device = device @@ -33,74 +85,57 @@ def __init__(self, ) self.num_layers = config.num_hidden_layers self.kv_offset = 0 - self.num_key_value_heads = config.num_key_value_heads - self.num_attention_heads = config.num_attention_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - - def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): - self.k_cache[:,offset:offset + len(indices), :,:] = self.k_cache[:,indices, :,:] - self.v_cache[:,offset:offset + len(indices), :,:] = self.v_cache[:,indices, :,:] + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): + self.k_cache[:, offset:offset + len(indices), :, :] = self.k_cache[:, indices, :, :] 
+ self.v_cache[:, offset:offset + len(indices), :, :] = self.v_cache[:, indices, :, :] - self.k_cache[:,offset + len(indices):, :,:] = 0.0 - self.v_cache[:,offset + len(indices):, :,:] = 0.0 + self.k_cache[:, offset + len(indices):, :, :] = 0.0 + self.v_cache[:, offset + len(indices):, :, :] = 0.0 self.kv_offset = offset + len(indices) - - - def update_kv_cache(self, - new_k_cache :torch.Tensor, - new_v_cache :torch.Tensor, - layer_idx :int, - storage_ids: torch.LongTensor - ): - + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + storage_ids: torch.LongTensor + ): new_kv_len = storage_ids.shape[0] if layer_idx == 0: self.kv_offset += new_kv_len self.k_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache self.v_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache return self.k_cache[layer_idx][:self.kv_offset], self.v_cache[layer_idx][:self.kv_offset] - - def compute_attention(self, - query_states :torch.Tensor, - key_states :torch.Tensor, - value_states :torch.Tensor, - layer_idx, - storage_ids :torch.Tensor, - attention_mask :torch.Tensor): - + + def compute_attention(self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) - hidden_states = flashinfer.single_prefill_with_kv_cache( - q = query_states[0], - k = key_states, - v = value_states, - kv_layout="NHD", - custom_mask=attention_mask[:,:self.kv_offset], - allow_fp16_qk_reduction=True - ) - + hidden_states = mha_flash(query_states[0], key_states, value_states, "NHD", attention_mask[:, :self.kv_offset]) + return hidden_states - + def clear(self): self.k_cache.zero_() self.v_cache.zero_() self.kv_offset = 0 - - def set_kv_len(self, kv_len :int): - self.kv_offset = kv_len + + def set_kv_len(self, kv_len: 
int): + self.kv_offset = kv_len class StaticKV_Cache: - - def __init__(self, - config :AutoConfig, - batch_size :int = 1, - max_length :int = 256, - device :str = 'cuda:0', - dtype = torch.float16) -> None: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: self.config = config self.max_length = max_length self.device = device @@ -129,9 +164,7 @@ def __init__(self, self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - - def gather_kv_incremental(self, indices: list[int], offset:int): - + def gather_kv_incremental(self, indices: list[int], offset: int): self.k_cache[..., offset:offset + len(indices), :] = self.k_cache[..., indices, :] self.v_cache[..., offset:offset + len(indices), :] = self.v_cache[..., indices, :] @@ -140,50 +173,173 @@ def gather_kv_incremental(self, indices: list[int], offset:int): self.kv_offset = offset + len(indices) - - - def update_kv_cache(self, - new_k_cache :torch.Tensor, - new_v_cache :torch.Tensor, - layer_idx :int, - storage_ids: torch.LongTensor - ): - + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + storage_ids: torch.LongTensor + ): self.k_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_k_cache) self.v_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_v_cache) - + return self.k_cache[layer_idx], self.v_cache[layer_idx] - def clear(self): self.k_cache.zero_() self.v_cache.zero_() self.kv_offset = 0 - - def set_kv_len(self, kv_len :int): - self.kv_offset = kv_len - - def compute_attention(self, - query_states :torch.Tensor, - key_states :torch.Tensor, - value_states :torch.Tensor, - layer_idx, - storage_ids :torch.Tensor, - attention_mask :torch.Tensor): + + def set_kv_len(self, kv_len: int): + self.kv_offset = kv_len + + def compute_attention(self, 
+ query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): bsz, _, q_len, _ = query_states.shape - - key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) + + key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) query_states = query_states[0] - + query_states = query_states.reshape(self.num_key_value_heads, q_len * self.num_key_value_groups, self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) - mask = attention_mask[None,:,:].repeat(1, self.num_key_value_groups, 1) - + mask = attention_mask[None, :, :].repeat(1, self.num_key_value_groups, 1) + attn_weights.masked_fill_(~mask, torch.finfo(attn_weights.dtype).min) - + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) hidden_states = torch.matmul(attn_weights, value_states) hidden_states = hidden_states.reshape(bsz, self.num_attention_heads, q_len, -1) hidden_states = hidden_states.transpose(1, 2).contiguous() - - return hidden_states \ No newline at end of file + + return hidden_states + + +class SlidingWindowKV_Cache: + def __init__(self, + config: AutoConfig, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16) -> None: + + self.config = config + self.max_length = max_length + self.window_size = self.config.window_size + self.device = device + self.dtype = dtype + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + + # initializing Key-Value Cache + self.k_cache = torch.zeros( + config.num_hidden_layers, + max_length, + 
config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + device=self.device, + dtype=self.dtype + ) + + self.v_cache = torch.zeros( + config.num_hidden_layers, + max_length, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + device=self.device, + dtype=self.dtype + ) + self.kv_offset = 0 + + def update_kv_cache(self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + storage_ids: torch.LongTensor + ): + + new_kv_len = storage_ids.shape[0] + # # calculating new offset + need_len = new_kv_len + min(self.kv_offset, self.window_size) + if need_len > self.max_length: + raise ValueError(f'now kv_offset {self.kv_offset} new_kv_len {new_kv_len} need_len {need_len} exceeds max_length {self.max_length}') + + # update KV Cache + if layer_idx == 0: + self.kv_offset += new_kv_len + if self.kv_offset >= self.max_length: + # print(f'layer idx need shift data, {self.kv_offset}') + self.k_cache[:, :self.window_size, :, :] = self.k_cache[:, self.kv_offset - self.window_size: self.kv_offset, :, :] + self.v_cache[:, :self.window_size, :, :] = self.v_cache[:, self.kv_offset - self.window_size: self.kv_offset, :, :] + self.kv_offset = self.window_size + new_kv_len + # print(f'layer_idx, add, {self.kv_offset}') + self.k_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache + self.v_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache + start_idx = max(0, self.kv_offset - (self.window_size + new_kv_len)) + return self.k_cache[layer_idx][start_idx:self.kv_offset], self.v_cache[layer_idx][start_idx:self.kv_offset] + + def clear(self): + self.k_cache.zero_() + self.v_cache.zero_() + self.kv_offset = 0 + + def set_kv_len(self, kv_len: int): + self.kv_offset = kv_len + + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): + self.k_cache[:, offset:offset + len(indices), :, :] = self.k_cache[:, indices, :, :] + self.v_cache[:, offset:offset + 
len(indices), :, :] = self.v_cache[:, indices, :, :] + + self.k_cache[:, offset + len(indices):, :, :] = 0.0 + self.v_cache[:, offset + len(indices):, :, :] = 0.0 + + self.kv_offset = offset + len(indices) + + def compute_attention(self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx, + storage_ids: torch.Tensor, + attention_mask: torch.Tensor): + """ + Computes the attention output using a sliding window mechanism. + + This function updates the KV cache, constructs the appropriate attention mask that considers: + - Sliding window attention (limits visible tokens to a fixed window size) + - Original `attention_mask` (to handle padding tokens) + - Causal Mask (ensures auto-regressive decoding) + + Args: + query_states (torch.Tensor): Query tensor of shape (batch, num_heads, query_len, head_dim) + key_states (torch.Tensor): Key tensor of shape (batch, num_heads, kv_len, head_dim) + value_states (torch.Tensor): Value tensor of shape (batch, num_heads, kv_len, head_dim) + layer_idx (int): Current layer index in the transformer model. + storage_ids (torch.Tensor): Index positions to store the new KV cache. + attention_mask (torch.Tensor): Casual mask. (q_len, squence_len) + + Returns: + hidden_states (torch.Tensor): Output hidden states after applying attention. 
+ """ + + # Step 1: Update KV Cache (keep only the latest window_size tokens) + key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) + q_len = storage_ids.shape[0] + kv_len = key_states.shape[0] + + # Step 2: Generate Sliding Window Attention Mask + # Create a 2D attention mask where each query position can attend only to the latest `window_size` tokens + query_indices = (kv_len - q_len) + torch.arange(q_len, device=self.device).unsqueeze(1) # Shape: (q_len, 1) + kv_indices = torch.arange(kv_len, device=self.device).unsqueeze(0) # Shape: (1, kv_offset) + diff = query_indices - kv_indices + # Compute boolean mask: True if kv_index is within `window_size` of the query_index + attn_mask = torch.logical_and(diff <= self.window_size, diff >= 0) # (q_len, kv_offset) + # Step 3: Compute Attention Both are ok. + # hidden_states = mha(query_states[0], key_states, value_states, attn_mask) + hidden_states = mha_flash(query_states[0], key_states, value_states, "NHD", attn_mask) + return hidden_states diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 20f9483..da2d2d3 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -1,5 +1,6 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph +from .mistral import Mistral, MistralOffload, MistralAwq, MistralAwqOffload, MistralCudagraph class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -102,7 +103,9 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwq, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwq, "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq + "casperhansen/deepseek-r1-distill-qwen-32b-awq": QwenAwq, + "mistralai/Mistral-Small-24B-Instruct-2501": Mistral, + "mistralai/Mistral-7B-Instruct-v0.3": Mistral } _CUDAGRAPH_MODEL_MAPPING = { diff --git 
a/umbrella/models/mistral.py b/umbrella/models/mistral.py new file mode 100644 index 0000000..a7837fc --- /dev/null +++ b/umbrella/models/mistral.py @@ -0,0 +1,404 @@ +# from _typeshed import NoneType +from transformers import MistralForCausalLM, MistralConfig +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, SlidingWindowKV_Cache +from .mistral_layer import MistralLayer, MistralPackedLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +""" +Standard Mistral +Compared to Llama, Support sliding window + GQA +""" +class Mistral(LLMBase): + def __init__(self, + model_name: str, + batch_size: int = 1, + max_length: int = 256, + device: str = 'cuda:0', + dtype=torch.float16, + config=None) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + if config: + self.config = config + else: + self.config = MistralConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads # here > 1 in Mistral + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [ + self.config.eos_token_id] + + def alloc(self, **kwargs): + # if using sliding winodw, use KV_static_Cache to update past key/value + if self.config.sliding_window is not None and self.config.sliding_window == True: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: 
+ self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: + rotary_emb = hf_model.model.rotary_emb + self.inv_freq = rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[MistralLayer] = [] + + ## loop hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() 
+ + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, + position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = 
F.embedding(input_ids, self.embed_tokens) + # print('ifnerence', attention_mask) + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, + storage_ids) + + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset: int): + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() + +class MistralOffload(Mistral): + def __init__(self, model_name, batch_size=1, max_length=256, device='cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + self.load_stream = torch.cuda.Stream(device=device) + + def alloc(self, **kwargs): + if self.config.sliding_window is not None: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + print('hf_model', hf_model.model) + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: 
+ rotary_emb = hf_model.model.rotary_emb + self.inv_freq = rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[MistralLayer] = [] + + ## loop hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + assert self.num_layers % 2 == 0 + self.buffer = [MistralLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + if self.buffer[0].layer_idx != 0: + self.buffer[0].copy(self.layers[0]) + torch.cuda.synchronize() + for idx in range(self.num_layers): + with torch.cuda.stream(self.load_stream): + self.buffer[(idx + 1) % 
2].copy(self.layers[(idx + 1) % self.num_layers]) + + hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, + storage_ids) + torch.cuda.synchronize() + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + +class MistralAwq(Mistral): + pass + +class MistralAwqOffload(Mistral): + pass + +class MistralCudagraph(Mistral): + def __init__(self, model_name, batch_size=1, max_length=256, device='cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + + self.callables = {} + self.mempool = None + + + def alloc(self, **kwargs): + exit_layer = kwargs.pop("exit_layer", -1) + + # if using sliding winodw, use KV_static_Cache to update past key/value + if self.config.sliding_window is not None: + self.kv_cache = SlidingWindowKV_Cache(self.config, max_length=self.max_length, device=self.device, + dtype=self.dtype, batch_size=self.batch_size) + else: + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, + batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + if hasattr(hf_model.model.layers[0].self_attn, "rotary_emb"): + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + else: + rotary_emb = hf_model.model.rotary_emb + self.inv_freq = 
rotary_emb.inv_freq.detach().to(self.device) + if hasattr(rotary_emb, "attention_scaling"): + self.attention_scaling = rotary_emb.attention_scaling + else: + self.attention_scaling = 1.0 # 默认值 + + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers: list[MistralLayer] = [] + + ## loop hf_model.model.layers, and transfer to format MistralLayer, and store it to self.layers + for idx, hf_layer in enumerate(hf_model.model.layers): + if exit_layer > 0 and idx >= exit_layer: + break + layer = MistralPackedLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralPackedLayer, + layer_idx: int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + qkv = F.linear(hidden_states, buffer.wqkv) + query_states = qkv[..., :self.hidden_size] + key_states = qkv[..., self.hidden_size:self.hidden_size + self.head_dim * self.num_key_value_heads] + value_states = qkv[..., 
self.hidden_size + self.head_dim * self.num_key_value_heads:] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, + position_ids, unsqueeze_dim=1) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def initialize_cuda_graph(self, + decoding_seqlens: list[int], + n_warmups=12): + gc.collect() + self.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if decoding_seqlen not in self.callables: + self.callables[decoding_seqlen] = capture_graph( + llm=self, + decoding_seqlen=decoding_seqlen, + mempool=self.mempool, + n_warmups=n_warmups + ) + self.clear() + + @torch.inference_mode() + def graph_inference(self, + input_ids: torch.LongTensor, + storage_ids: torch.LongTensor, + position_ids=None, + attention_mask=None, + ): + dec_length = input_ids.shape[1] + if dec_length in self.callables.keys(): + logits = self.callables[dec_length](input_ids, storage_ids, position_ids, 
    def __init__(self, layer_idx, device="cpu") -> None:
        """Container for one Mistral decoder layer's weights, detached from the
        HF modules so they can be moved/copied as raw tensors.

        All tensors start as ``None``; populate them with
        ``init_parameters(hf_layer)`` or with ``alloc_space(...)`` + ``copy(...)``.
        """
        # Attention projection weights (q/k/v/output).
        self.wq: torch.Tensor = None
        self.wk: torch.Tensor = None
        self.wv: torch.Tensor = None
        self.wo: torch.Tensor = None

        # MLP (SwiGLU) projection weights.
        self.gate_proj: torch.Tensor = None
        self.up_proj: torch.Tensor = None
        self.down_proj: torch.Tensor = None

        # RMSNorm weight + epsilon for the pre-attention norm.
        self.input_layernorm_weight: torch.Tensor = None
        self.input_layernorm_variance_epsilon: float = 0.0

        # RMSNorm weight + epsilon for the pre-MLP norm.
        self.post_attention_layernorm_weight: torch.Tensor = None
        self.post_attention_layernorm_variance_epsilon: float = 0.0

        # Position of this layer in the model; -1 is used for scratch buffers.
        self.layer_idx = layer_idx
        self.device = device
device: str = 'cuda:0', non_blocking=True): + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, + non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: MistralLayer): + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: MistralLayer, device): + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + 
class MistralPackedLayer:
    """Decoder-layer weights with Q/K/V fused into a single ``wqkv`` matrix.

    Packing the three attention projections into one weight lets the model run
    a single GEMM per layer for Q/K/V, which suits CUDA-graph capture
    (see ``MistralCudagraph.layer_compute``).

    Fixes vs. previous version:
    - ``copy`` was annotated with ``LlamaPackedLayer``, a name that does not
      exist in this module (copy-paste from the Llama variant); corrected to
      ``MistralPackedLayer``.
    - ``__init__`` now declares ``self.wo`` like ``MistralLayer`` does, so the
      attribute set exists before ``init_parameters`` is called.
    """

    def __init__(self, layer_idx, device="cpu") -> None:
        # Fused attention projection: rows are [q_proj; k_proj; v_proj].
        self.wqkv: torch.Tensor = None
        # Attention output projection (declared here for consistency with
        # MistralLayer; previously only created inside init_parameters).
        self.wo: torch.Tensor = None

        # MLP (SwiGLU) projection weights.
        self.gate_proj: torch.Tensor = None
        self.up_proj: torch.Tensor = None
        self.down_proj: torch.Tensor = None

        # RMSNorm parameters for the pre-attention and pre-MLP norms.
        self.input_layernorm_weight: torch.Tensor = None
        self.input_layernorm_variance_epsilon: float = 0.0
        self.post_attention_layernorm_weight: torch.Tensor = None
        self.post_attention_layernorm_variance_epsilon: float = 0.0

        self.layer_idx = layer_idx
        self.device = device

    def init_parameters(self, hf_layer: MistralDecoderLayer):
        """Detach the HF layer's weights, concatenating Q/K/V row-wise into ``wqkv``."""
        self.wqkv: torch.Tensor = torch.cat(
            [
                hf_layer.self_attn.q_proj.weight.detach(),
                hf_layer.self_attn.k_proj.weight.detach(),
                hf_layer.self_attn.v_proj.weight.detach(),
            ],
            dim=0
        )
        self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach()
        self.gate_proj = hf_layer.mlp.gate_proj.weight.detach()
        self.up_proj = hf_layer.mlp.up_proj.weight.detach()
        self.down_proj = hf_layer.mlp.down_proj.weight.detach()

        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon

        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon

    def to(self, device: str = 'cuda:0', non_blocking=True):
        """Move every weight tensor to ``device`` (non-blocking by default)."""
        self.device = device
        self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking)
        self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device,
                                                                                       non_blocking=non_blocking)
        self.wqkv = self.wqkv.to(device, non_blocking=non_blocking)
        self.wo = self.wo.to(device, non_blocking=non_blocking)
        self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking)
        self.up_proj = self.up_proj.to(device, non_blocking=non_blocking)
        self.down_proj = self.down_proj.to(device, non_blocking=non_blocking)

    def copy(self, layer: MistralPackedLayer):
        """Copy ``layer``'s weights into this layer's pre-allocated buffers.

        Tensor copies use ``non_blocking=True`` so they can overlap compute when
        issued from a side CUDA stream.
        """
        self.wqkv.copy_(layer.wqkv, non_blocking=True)
        self.wo.copy_(layer.wo, non_blocking=True)
        self.gate_proj.copy_(layer.gate_proj, non_blocking=True)
        self.up_proj.copy_(layer.up_proj, non_blocking=True)
        self.down_proj.copy_(layer.down_proj, non_blocking=True)

        self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True)
        self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True)
        self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
        self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon
        self.layer_idx = layer.layer_idx

    def alloc_space(self, layer: MistralPackedLayer, device):
        """Allocate zero-filled tensors on ``device`` shaped like ``layer``'s weights."""
        self.device = device
        self.wqkv = torch.zeros_like(layer.wqkv).to(device)
        self.wo = torch.zeros_like(layer.wo).to(device)

        self.gate_proj = torch.zeros_like(layer.gate_proj).to(device)
        self.up_proj = torch.zeros_like(layer.up_proj).to(device)
        self.down_proj = torch.zeros_like(layer.down_proj).to(device)
        self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device)
        self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device)
knowledgeable, efficient, and direct AI assistant. Provide concise answers and focus on key information. [/INST]""", + 'mistral-7b': """[INST] You are a helpful AI assistant. Provide concise and informative responses. [/INST]""" } ExtraPrompts = {