Skip to content

Commit cbdbe40

Browse files
committed
feat(main): add DirectML backend support enables execution on integrated GPUs (iGPUs)
1 parent 2f3dd15 commit cbdbe40

1 file changed

Lines changed: 298 additions & 0 deletions

File tree

engine/iGPU/inference.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
import argparse
2+
from pathlib import Path
3+
import time
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.nn import functional as F
8+
import tiktoken
9+
10+
try:
11+
import torch_directml
12+
except ImportError:
13+
torch_directml = None
14+
15+
16+
W = 78
17+
DOUBLE = "=" * W
18+
SINGLE = "-" * W
19+
ARROW = "->"
20+
21+
block_size = 32
22+
n_embd = 64
23+
n_head = 4
24+
n_layer = 4
25+
dropout = 0.1
26+
27+
28+
def header(title, subtitle=""):
29+
print(f"\n{DOUBLE}")
30+
print(f" {title}")
31+
if subtitle:
32+
print(f" {subtitle}")
33+
print(DOUBLE)
34+
35+
36+
def row(label, value="", unit="", note=""):
37+
label_col = f" {label:<28}"
38+
value_col = f"{str(value):<20}"
39+
unit_col = f"{unit:<8}"
40+
note_col = f" {note}" if note else ""
41+
print(f"{label_col}{value_col}{unit_col}{note_col}")
42+
43+
44+
def rule():
45+
print(f" {SINGLE}")
46+
47+
48+
def blank():
49+
print()
50+
51+
52+
def get_device():
53+
if torch_directml is not None:
54+
return torch_directml.device()
55+
if torch.cuda.is_available():
56+
return torch.device("cuda")
57+
return torch.device("cpu")
58+
59+
60+
def get_tokenizer(encoding_name="gpt2"):
61+
tokenizer = tiktoken.get_encoding(encoding_name)
62+
return tokenizer, tokenizer.n_vocab
63+
64+
65+
def encode(text, tokenizer):
66+
return tokenizer.encode(text)
67+
68+
69+
def decode(tokens, tokenizer):
70+
return tokenizer.decode(tokens)
71+
72+
73+
tokenizer, vocab_size = get_tokenizer("gpt2")
74+
device = get_device()
75+
76+
77+
class Head(nn.Module):
78+
def __init__(self, head_size):
79+
super().__init__()
80+
self.key = nn.Linear(n_embd, head_size, bias=False)
81+
self.query = nn.Linear(n_embd, head_size, bias=False)
82+
self.value = nn.Linear(n_embd, head_size, bias=False)
83+
self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
84+
self.dropout = nn.Dropout(dropout)
85+
86+
def forward(self, x):
87+
_, T, _ = x.shape
88+
k = self.key(x)
89+
q = self.query(x)
90+
wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
91+
wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
92+
wei = F.softmax(wei, dim=-1)
93+
wei = self.dropout(wei)
94+
return wei @ self.value(x)
95+
96+
97+
class MultiHeadAttention(nn.Module):
98+
def __init__(self, num_heads, head_size):
99+
super().__init__()
100+
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
101+
self.proj = nn.Linear(head_size * num_heads, n_embd)
102+
self.dropout = nn.Dropout(dropout)
103+
104+
def forward(self, x):
105+
out = torch.cat([h(x) for h in self.heads], dim=-1)
106+
return self.dropout(self.proj(out))
107+
108+
109+
class FeedForward(nn.Module):
110+
def __init__(self, n_embd):
111+
super().__init__()
112+
self.net = nn.Sequential(
113+
nn.Linear(n_embd, 4 * n_embd),
114+
nn.ReLU(),
115+
nn.Linear(4 * n_embd, n_embd),
116+
nn.Dropout(dropout),
117+
)
118+
119+
def forward(self, x):
120+
return self.net(x)
121+
122+
123+
class Block(nn.Module):
124+
def __init__(self, n_embd, n_head):
125+
super().__init__()
126+
head_size = n_embd // n_head
127+
self.sa = MultiHeadAttention(n_head, head_size)
128+
self.ffwd = FeedForward(n_embd)
129+
self.ln1 = nn.LayerNorm(n_embd)
130+
self.ln2 = nn.LayerNorm(n_embd)
131+
132+
def forward(self, x):
133+
x = x + self.sa(self.ln1(x))
134+
x = x + self.ffwd(self.ln2(x))
135+
return x
136+
137+
138+
class GPTLanguageModel(nn.Module):
139+
def __init__(self):
140+
super().__init__()
141+
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
142+
self.position_embedding_table = nn.Embedding(block_size, n_embd)
143+
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
144+
self.ln_f = nn.LayerNorm(n_embd)
145+
self.lm_head = nn.Linear(n_embd, vocab_size)
146+
147+
def forward(self, idx, targets=None):
148+
B, T = idx.shape
149+
tok_emb = self.token_embedding_table(idx)
150+
pos_emb = self.position_embedding_table(torch.arange(T, device=device))
151+
x = tok_emb + pos_emb
152+
x = self.blocks(x)
153+
x = self.ln_f(x)
154+
logits = self.lm_head(x)
155+
156+
if targets is None:
157+
loss = None
158+
else:
159+
B, T, C = logits.shape
160+
logits = logits.view(B * T, C)
161+
targets = targets.view(B * T)
162+
loss = F.cross_entropy(logits, targets)
163+
return logits, loss
164+
165+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
166+
for _ in range(max_new_tokens):
167+
idx_cond = idx[:, -block_size:]
168+
logits, _ = self(idx_cond)
169+
logits = logits[:, -1, :] / max(temperature, 1e-6)
170+
171+
if top_k is not None:
172+
values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
173+
logits[logits < values[:, [-1]]] = float("-inf")
174+
175+
probs = F.softmax(logits, dim=-1)
176+
idx_next = torch.multinomial(probs, num_samples=1)
177+
idx = torch.cat((idx, idx_next), dim=1)
178+
return idx
179+
180+
181+
def default_checkpoint_path():
182+
script_dir = Path(__file__).resolve().parent
183+
candidates = [
184+
script_dir / "best_model.pt",
185+
Path.cwd() / "best_model.pt",
186+
Path.cwd() / "iGPU" / "best_model.pt",
187+
]
188+
for candidate in candidates:
189+
if candidate.exists():
190+
return candidate
191+
return script_dir / "best_model.pt"
192+
193+
194+
def load_model(checkpoint_path):
195+
checkpoint_path = Path(checkpoint_path)
196+
if not checkpoint_path.exists():
197+
raise FileNotFoundError(
198+
f"Checkpoint not found: {checkpoint_path}\n"
199+
"Train first with iGPU/main.py, or pass --checkpoint path/to/best_model.pt"
200+
)
201+
202+
model = GPTLanguageModel().to(device)
203+
state_dict = torch.load(checkpoint_path, map_location=device)
204+
model.load_state_dict(state_dict)
205+
model.eval()
206+
return model
207+
208+
209+
def generate_response(model, prompt, max_new_tokens, temperature, top_k):
210+
encoded_prompt = encode(prompt, tokenizer)
211+
context = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
212+
213+
with torch.no_grad():
214+
output_ids = model.generate(
215+
context,
216+
max_new_tokens=max_new_tokens,
217+
temperature=temperature,
218+
top_k=top_k,
219+
)
220+
221+
new_tokens = output_ids[0][len(encoded_prompt):].tolist()
222+
return decode(new_tokens, tokenizer).strip()
223+
224+
225+
def chat(model, args):
226+
header("INFERENCE", "quit / exit / q -> end session")
227+
blank()
228+
229+
while True:
230+
prompt = input(f" user {ARROW} ").strip()
231+
if prompt.lower() in ("quit", "exit", "q"):
232+
blank()
233+
print(" Session ended.")
234+
break
235+
if not prompt:
236+
continue
237+
238+
response = generate_response(
239+
model,
240+
prompt,
241+
args.max_new_tokens,
242+
args.temperature,
243+
args.top_k,
244+
)
245+
blank()
246+
print(f" Model {ARROW} {response}")
247+
blank()
248+
249+
250+
def parse_args():
251+
parser = argparse.ArgumentParser(description="Run inference from an iGPU trained .pt checkpoint.")
252+
parser.add_argument(
253+
"--checkpoint",
254+
type=Path,
255+
default=default_checkpoint_path(),
256+
help="Path to the .pt file generated by iGPU/main.py.",
257+
)
258+
parser.add_argument("--prompt", type=str, default=None, help="Generate once from this prompt.")
259+
parser.add_argument("--max-new-tokens", type=int, default=200)
260+
parser.add_argument("--temperature", type=float, default=1.0)
261+
parser.add_argument("--top-k", type=int, default=None)
262+
return parser.parse_args()
263+
264+
265+
def main():
266+
args = parse_args()
267+
start = time.time()
268+
269+
print(f"{'Quadtrix-v1.0':^{W}}")
270+
blank()
271+
row("Started", time.strftime("%Y-%m-%d %H:%M:%S"))
272+
row("Device", str(device))
273+
row("PyTorch", torch.__version__)
274+
row("Checkpoint", args.checkpoint)
275+
rule()
276+
277+
model = load_model(args.checkpoint)
278+
279+
if args.prompt:
280+
response = generate_response(
281+
model,
282+
args.prompt,
283+
args.max_new_tokens,
284+
args.temperature,
285+
args.top_k,
286+
)
287+
blank()
288+
print(response)
289+
else:
290+
chat(model, args)
291+
292+
blank()
293+
row("Total", f"{time.time() - start:.2f}s")
294+
print(DOUBLE)
295+
296+
297+
if __name__ == "__main__":
298+
main()

0 commit comments

Comments
 (0)