Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,5 @@ cython_debug/
.pypirc
.vscode/
app/.gradio/
test*
test*
.idea/
2 changes: 1 addition & 1 deletion examples/data/question.jsonl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
{"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]}
{"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]}
{"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n for file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]}
{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector<int> sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]}
{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the initial cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector<int> sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]}
{"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]}
{"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]}
{"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. It's still the root of the tree."]}
Expand Down
74 changes: 46 additions & 28 deletions examples/generate.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
from umbrella.models.auto_model import AutoModelLM
from umbrella.logging_config import setup_logger
from umbrella.utils import TextColors

logger = setup_logger()
import torch
from umbrella.templates import Prompts, SysPrompts
from transformers import AutoTokenizer
from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, find_first_element_position
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, \
find_first_element_position
import argparse
import time

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct",help='model')
parser.add_argument('--template', type=str, default="meta-llama3",help='prompt template')
parser.add_argument('--model', type=str, default="meta-llama/Llama-3.1-8B-Instruct", help='model')
parser.add_argument('--template', type=str, default="meta-llama3", help='prompt template')
parser.add_argument('--G', type=int, default=512, help='generation length')
parser.add_argument('--offload', action='store_true', help="offload the model")
parser.add_argument('--cuda_graph', action='store_true', help="whether use cuda graph")
Expand All @@ -32,63 +36,77 @@
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokens = tokenizer.encode(text=text, return_tensors="pt").to(DEVICE)

config = AutoConfig.from_pretrained(args.model)

# testing mistral sliding window
# config.sliding_window = True
# config.window_size = 100

llm = AutoModelLM.from_pretrained(
model_name=args.model,
offload=args.offload,
cuda_graph=args.cuda_graph,
batch_size=1,
max_length=MAX_LEN,
dtype=DTYPE,
device=DEVICE
device=DEVICE,
config=config
)

eos_tokens = llm.config.eos_token_id
if not isinstance(eos_tokens, list):
eos_tokens = [eos_tokens]
llm.alloc()

if args.cuda_graph:
llm.initialize_cuda_graph([1])
attention_mask = make_causal_mask((MAX_LEN, MAX_LEN), DEVICE)
storage_ids = torch.arange(MAX_LEN, device=DEVICE)
position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0)

prefix_len = tokens.shape[1]
logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:,:prefix_len],
storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0]
logits = llm.graph_inference(input_ids=tokens, position_ids=position_ids[:, :prefix_len],
storage_ids=storage_ids[:prefix_len], attention_mask=attention_mask[:prefix_len])[0]

torch.cuda.synchronize()
t1 = time.time()

generated_tokens = []
pos = 0
print('----start generating answers.----')
for i in range(GEN_LEN):
next_token = logits[-1:].argmax(dim=-1, keepdim=True)
generated_tokens.append(next_token.item())

generated_text = (
tokenizer.decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
spaces_between_special_tokens=False,
)
.strip()
.split(" ")
)


tokenizer.decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
spaces_between_special_tokens=False,
)
.strip()
.split(" ")
)

now = len(generated_text) - 1
if now > pos:
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
pos = now

if (is_sentence_complete_regex(generated_text[-1]) and (i >= GEN_LEN - 32)) or (find_first_element_position(next_token, eos_tokens) >= 0):
break

logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:,prefix_len+i:prefix_len+i+1],
storage_ids=storage_ids[prefix_len+i : prefix_len+i+1], attention_mask=attention_mask[prefix_len+i:prefix_len+i+1])[0]
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
pos = now

if (is_sentence_complete_regex(generated_text[-1]) and (i >= GEN_LEN - 32)) or (
find_first_element_position(next_token, eos_tokens) >= 0):
break

logits = llm.graph_inference(input_ids=next_token, position_ids=position_ids[:, prefix_len + i:prefix_len + i + 1],
storage_ids=storage_ids[prefix_len + i: prefix_len + i + 1],
attention_mask=attention_mask[prefix_len + i:prefix_len + i + 1])[0]

print(" ".join(generated_text[pos:]), flush=True)
print('----end generating answers.----')
torch.cuda.synchronize()
t2 = time.time()

dec_len = len(generated_tokens)
logger.info(TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2-t1)/dec_len), "magenta"))
logger.info(
TextColors.colorize("Avg Accept Tokens {:.2f} | TPOT {:.2f} ms ".format(1, 1000 * (t2 - t1) / dec_len), "magenta"))
Loading