Skip to content

Commit ee8ecec

Browse files
Ziminlivoltjia
authored andcommitted
Add inference profiling support
1 parent d18e73f commit ee8ecec

1 file changed

Lines changed: 38 additions & 1 deletion

File tree

infer.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
2+
import time
23

4+
import torch
35
from transformers import AutoModelForCausalLM, AutoTokenizer
46

57
from fused_rms_norm import RMSNorm
@@ -38,13 +40,30 @@
3840
default="cpu",
3941
help='Device to use for inference (e.g., "cuda", "cpu").',
4042
)
43+
parser.add_argument(
44+
"--num-warmup-iterations",
45+
type=int,
46+
default=0,
47+
help="For profiling. The number of warmup iterations to run before measuring performance.",
48+
)
49+
parser.add_argument(
50+
"--num-profiling-iterations",
51+
type=int,
52+
default=1,
53+
help="For profiling. The number of iterations to run for performance measurement.",
54+
)
4155

4256
args = parser.parse_args()
4357

4458
model_name_or_path = args.model
4559
prompts = args.prompts
4660
max_new_tokens = args.max_new_tokens
4761
device = args.device
62+
num_warmup_iterations = args.num_warmup_iterations
63+
num_profiling_iterations = args.num_profiling_iterations
64+
65+
assert num_profiling_iterations >= 1
66+
assert num_warmup_iterations >= 0
4867

4968
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
5069
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
@@ -58,7 +77,25 @@
5877
replace_module(model, SiLU)
5978

6079
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
61-
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
80+
81+
for _ in range(num_warmup_iterations):
82+
model.generate(**inputs, max_new_tokens=max_new_tokens)
83+
84+
if device == "cuda":
85+
torch.cuda.synchronize()
86+
87+
start_time = time.time()
88+
89+
for _ in range(num_profiling_iterations):
90+
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
91+
92+
if device == "cuda":
93+
torch.cuda.synchronize()
94+
95+
end_time = time.time()
96+
avg_time_ms = (end_time - start_time) * 1000 / num_profiling_iterations
97+
6298
strings = tokenizer.batch_decode(outputs, skip_special_tokens=True)
6399

64100
print(strings)
101+
print(f"\nAverage inference time: {avg_time_ms:.4f} ms.")

0 commit comments

Comments
 (0)