Skip to content

Commit f931000

Browse files
committed
Add inference profiling support
1 parent 3d8910c commit f931000

1 file changed

Lines changed: 35 additions & 1 deletion

File tree

infer.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import argparse
2+
import time
3+
import torch
24

35
from transformers import AutoModelForCausalLM, AutoTokenizer
46

@@ -38,13 +40,30 @@
3840
default="cpu",
3941
help='Device to use for inference (e.g., "cuda", "cpu").',
4042
)
43+
parser.add_argument(
44+
"--num-warmups",
45+
type=int,
46+
default=0,
47+
help="For profiling. The number of warmup iterations to run before measuring performance.",
48+
)
49+
parser.add_argument(
50+
"--num-iterations",
51+
type=int,
52+
default=1,
53+
help="For profiling. The number of iterations to run for performance measurement.",
54+
)
4155

4256
args = parser.parse_args()
4357

4458
model_name_or_path = args.model
4559
prompts = args.prompts
4660
max_new_tokens = args.max_new_tokens
4761
device = args.device
62+
num_warmups = args.num_warmups
63+
num_iterations = args.num_iterations
64+
65+
if (num_iterations < 1) or (num_warmups < 0):
66+
raise ValueError("num_iterations must be >= 1 and num_warmups must be >= 0")
4867

4968
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
5069
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
@@ -58,7 +77,22 @@
5877
replace_module(model, SiLU)
5978

6079
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
61-
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
80+
81+
# Resolve the device once. Comparing the raw string against "cuda" would miss
# indexed devices such as "cuda:0", silently skipping synchronisation and
# under-reporting GPU time because CUDA kernels are launched asynchronously.
torch_device = torch.device(device)
needs_cuda_sync = torch_device.type == "cuda"

# Warmup runs: outputs are discarded; they only exist to absorb one-time costs
# before measurement starts.
for _ in range(num_warmups):
    model.generate(**inputs, max_new_tokens=max_new_tokens)
    if needs_cuda_sync:
        # Wait on the device the model actually lives on, not just the
        # current default CUDA device.
        torch.cuda.synchronize(torch_device)

# perf_counter() is monotonic and high-resolution; time.time() can jump when
# the system clock is adjusted, which would corrupt the measurement.
start_time = time.perf_counter()

for _ in range(num_iterations):
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    if needs_cuda_sync:
        torch.cuda.synchronize(torch_device)

end_time = time.perf_counter()

# num_iterations is validated to be >= 1 earlier in the script, so the
# division is safe and `outputs` is always bound here.
total_time_ms = (end_time - start_time) * 1000
avg_time_ms = total_time_ms / num_iterations

strings = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(strings)
# Report the total under the "Total" label (previously the per-iteration
# average was printed under it) and the average on its own line. With the
# default of one iteration both values are identical.
print(f"\nTotal Inference time: {total_time_ms:.4f} ms")
print(f"Average Inference time per iteration: {avg_time_ms:.4f} ms")

0 commit comments

Comments
 (0)