Skip to content

Commit 403e8d7

Browse files
committed
feat(main): add DirectML backend support enables execution on integrated iGPUs
1 parent 1e20c5a commit 403e8d7

1 file changed

Lines changed: 360 additions & 0 deletions

File tree

engine/iGPU/main.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
import torch
2+
import torch_directml
3+
import torch.nn as nn
4+
from torch.nn import functional as F
5+
import time
6+
import sys
7+
from pathlib import Path
8+
import tiktoken
9+
10+
# LOGGING UTILITIES
11+
W = 78
12+
DOUBLE = "=" * W
13+
SINGLE = "-" * W
14+
TICK = "best"
15+
ARROW = ">"
16+
17+
LOG_DIR = Path(__file__).resolve().parent / "logs"
18+
LOG_DIR.mkdir(parents=True, exist_ok=True)
19+
LOG_PATH = LOG_DIR / f"run_{time.strftime('%Y%m%d_%H%M%S')}.txt"
20+
21+
def log(message=""):
22+
line = "" if message == "" else f"{time.strftime('%Y-%m-%d %H:%M:%S')} | {message}"
23+
print(line)
24+
with open(LOG_PATH, "a", encoding="utf-8") as f:
25+
f.write(f"{line}\n")
26+
27+
def header(title, subtitle=""):
28+
log()
29+
log(DOUBLE)
30+
log(f" {title}")
31+
if subtitle:
32+
log(f" {subtitle}")
33+
log(DOUBLE)
34+
35+
def row(label, value="", unit="", note=""):
36+
label_col = f" {label:<28}"
37+
value_col = f"{str(value):<20}"
38+
unit_col = f"{unit:<8}"
39+
note_col = f" {note}" if note else ""
40+
log(f"{label_col}{value_col}{unit_col}{note_col}")
41+
42+
def rule(): log(f" {SINGLE}")
43+
def blank(): log()
44+
def info(msg): log(f" {ARROW} {msg}")
45+
def success(msg): log(f" ok {msg}")
46+
47+
48+
# SESSION
49+
50+
51+
52+
log(f"{'Quadtrix-v1.0':^{W}}")
53+
blank()
54+
row("Started", time.strftime('%Y-%m-%d %H:%M:%S'))
55+
row("Device", str(torch_directml.device()))
56+
row("PyTorch", torch.__version__)
57+
row("Log file", str(LOG_PATH))
58+
59+
start = time.time()
60+
61+
# CONFIGURATION
62+
63+
64+
cleaned_path = "engine\data\cleaned.txt"
65+
train_split = 0.9
66+
seed = 1337
67+
68+
batch_size = 16
69+
block_size = 32
70+
max_iters = 3000
71+
eval_interval = 100
72+
learning_rate = 1e-3
73+
device = torch_directml.device()
74+
eval_iters = 20
75+
n_embd = 64
76+
n_head = 4
77+
n_layer = 4
78+
dropout = 0.1
79+
80+
torch.manual_seed(seed)
81+
82+
83+
# tokenizer
84+
85+
def get_tokenizer(encoding_name="gpt2"):
86+
tokenizer = tiktoken.get_encoding(encoding_name)
87+
vocab_size = tokenizer.n_vocab
88+
return tokenizer, vocab_size
89+
90+
def encode(text, tokenizer): return tokenizer.encode(text)
91+
def decode(tokens, tokenizer): return tokenizer.decode(tokens)
92+
93+
94+
95+
# DATA
96+
97+
with open(cleaned_path, 'r', encoding='utf-8') as f:
98+
text = f.read()
99+
100+
tokenizer, vocab_size = get_tokenizer("gpt2")
101+
encoded_data = encode(text, tokenizer)
102+
103+
data = torch.tensor(encoded_data, dtype=torch.long)
104+
n = int(train_split * len(data))
105+
train_data = data[:n]
106+
val_data = data[n:]
107+
108+
# Batch and LOSS
109+
110+
def get_batch(split):
111+
data_split = train_data if split == 'train' else val_data
112+
ix = torch.randint(len(data_split) - block_size, (batch_size,))
113+
x = torch.stack([data_split[i:i + block_size] for i in ix])
114+
y = torch.stack([data_split[i + 1:i + block_size + 1] for i in ix])
115+
x, y = x.to(device), y.to(device)
116+
return x, y
117+
118+
@torch.no_grad()
119+
def estimate_loss():
120+
out = {}
121+
model.eval()
122+
for split in ['train', 'val']:
123+
losses = torch.zeros(eval_iters)
124+
for k in range(eval_iters):
125+
X, Y = get_batch(split)
126+
_, loss = model(X, Y)
127+
losses[k] = loss.item()
128+
out[split] = losses.mean()
129+
model.train()
130+
return out
131+
132+
# model
133+
134+
class Head(nn.Module):
135+
def __init__(self, head_size):
136+
super().__init__()
137+
self.key = nn.Linear(n_embd, head_size, bias=False)
138+
self.query = nn.Linear(n_embd, head_size, bias=False)
139+
self.value = nn.Linear(n_embd, head_size, bias=False)
140+
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
141+
self.dropout = nn.Dropout(dropout)
142+
143+
def forward(self, x):
144+
B, T, C = x.shape
145+
k = self.key(x)
146+
q = self.query(x)
147+
wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
148+
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
149+
wei = F.softmax(wei, dim=-1)
150+
wei = self.dropout(wei)
151+
return wei @ self.value(x)
152+
153+
class MultiHeadAttention(nn.Module):
154+
def __init__(self, num_heads, head_size):
155+
super().__init__()
156+
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
157+
self.proj = nn.Linear(head_size * num_heads, n_embd)
158+
self.dropout = nn.Dropout(dropout)
159+
160+
def forward(self, x):
161+
out = torch.cat([h(x) for h in self.heads], dim=-1)
162+
return self.dropout(self.proj(out))
163+
164+
class FeedFoward(nn.Module):
165+
def __init__(self, n_embd):
166+
super().__init__()
167+
self.net = nn.Sequential(
168+
nn.Linear(n_embd, 4 * n_embd),
169+
nn.ReLU(),
170+
nn.Linear(4 * n_embd, n_embd),
171+
nn.Dropout(dropout),
172+
)
173+
174+
def forward(self, x):
175+
return self.net(x)
176+
177+
class Block(nn.Module):
178+
def __init__(self, n_embd, n_head):
179+
super().__init__()
180+
head_size = n_embd // n_head
181+
self.sa = MultiHeadAttention(n_head, head_size)
182+
self.ffwd = FeedFoward(n_embd)
183+
self.ln1 = nn.LayerNorm(n_embd)
184+
self.ln2 = nn.LayerNorm(n_embd)
185+
186+
def forward(self, x):
187+
x = x + self.sa(self.ln1(x))
188+
x = x + self.ffwd(self.ln2(x))
189+
return x
190+
191+
class GPTLanguageModel(nn.Module):
192+
def __init__(self):
193+
super().__init__()
194+
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
195+
self.position_embedding_table = nn.Embedding(block_size, n_embd)
196+
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
197+
self.ln_f = nn.LayerNorm(n_embd)
198+
self.lm_head = nn.Linear(n_embd, vocab_size)
199+
self.apply(self._init_weights)
200+
201+
def _init_weights(self, module):
202+
if isinstance(module, nn.Linear):
203+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
204+
if module.bias is not None:
205+
torch.nn.init.zeros_(module.bias)
206+
elif isinstance(module, nn.Embedding):
207+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
208+
209+
def forward(self, idx, targets=None):
210+
B, T = idx.shape
211+
tok_emb = self.token_embedding_table(idx)
212+
pos_emb = self.position_embedding_table(torch.arange(T, device=device))
213+
x = tok_emb + pos_emb
214+
x = self.blocks(x)
215+
x = self.ln_f(x)
216+
logits = self.lm_head(x)
217+
218+
if targets is None:
219+
loss = None
220+
else:
221+
B, T, C = logits.shape
222+
logits = logits.view(B * T, C)
223+
targets = targets.view(B * T)
224+
loss = F.cross_entropy(logits, targets)
225+
return logits, loss
226+
227+
def generate(self, idx, max_new_tokens):
228+
for _ in range(max_new_tokens):
229+
idx_cond = idx[:, -block_size:]
230+
logits, _ = self(idx_cond)
231+
logits = logits[:, -1, :]
232+
probs = F.softmax(logits, dim=-1)
233+
idx_next = torch.multinomial(probs, num_samples=1)
234+
idx = torch.cat((idx, idx_next), dim=1)
235+
return idx
236+
237+
238+
239+
# INITIALISE
240+
241+
model = GPTLanguageModel().to(device)
242+
n_params = sum(p.numel() for p in model.parameters())
243+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
244+
245+
header("CONFIG")
246+
row("Seed", seed)
247+
row("Batch size", batch_size)
248+
row("Block size", block_size)
249+
row("Learning rate", learning_rate)
250+
row("Layers", n_layer)
251+
row("Heads", n_head)
252+
row("Embedding dim", n_embd)
253+
row("Dropout", dropout)
254+
row("Parameters", f"{n_params:,}")
255+
row("Train tokens", f"{len(train_data):,}")
256+
row("Val tokens", f"{len(val_data):,}")
257+
258+
259+
#training
260+
header("TRAINING", f"{max_iters:,} steps | eval every {eval_interval} | checkpoint on improvement")
261+
blank()
262+
263+
log(" training loop started")
264+
265+
best_val_loss = float('inf')
266+
train_start = time.time()
267+
268+
for iter in range(max_iters):
269+
270+
if iter % eval_interval == 0 or iter == max_iters - 1:
271+
losses = estimate_loss()
272+
elapsed = time.time() - train_start
273+
pct = 100 * iter / max_iters
274+
eta_secs = (elapsed / (iter + 1)) * (max_iters - iter - 1) if iter > 0 else 0
275+
is_best = losses['val'] < best_val_loss
276+
status = f"{TICK} saved" if is_best else "-"
277+
elapsed_fmt = f"{int(elapsed // 60)}m {int(elapsed % 60):02d}s"
278+
eta_fmt = f"{int(eta_secs // 60)}m {int(eta_secs % 60):02d}s"
279+
280+
if is_best:
281+
best_val_loss = losses['val']
282+
torch.save(model.state_dict(), 'best_model.pt')
283+
log(f" ckpt path=best_model.pt val_loss={best_val_loss:.4f} step={iter}")
284+
285+
log(
286+
f" train step={iter}/{max_iters} pct={pct:.1f}% "
287+
f"loss_train={losses['train']:.4f} loss_val={losses['val']:.4f} "
288+
f"elapsed={elapsed_fmt} eta={eta_fmt} status={status}"
289+
)
290+
sys.stdout.flush()
291+
292+
xb, yb = get_batch('train')
293+
logits, loss = model(xb, yb)
294+
optimizer.zero_grad(set_to_none=True)
295+
loss.backward()
296+
optimizer.step()
297+
298+
total_time = time.time() - train_start
299+
blank()
300+
rule()
301+
row("Duration", f"{int(total_time // 60)}m {int(total_time % 60):02d}s")
302+
row("Best val loss", f"{best_val_loss:.4f}", "", TICK)
303+
row("Checkpoint", "best_model.pt", "", TICK)
304+
rule()
305+
306+
307+
308+
# RESTORE CHECKPOIN
309+
blank()
310+
model.load_state_dict(torch.load('best_model.pt', map_location=device))
311+
model.eval()
312+
success(f"Restored best_model.pt | val loss {best_val_loss:.4f}")
313+
314+
# INFERENCE
315+
316+
317+
header("INFERENCE", "quit / exit / q -> end session")
318+
blank()
319+
320+
try:
321+
while True:
322+
prompt = input(f" user {ARROW} ").strip()
323+
log(f" user {ARROW} {prompt}")
324+
325+
if prompt.lower() in ("quit", "exit", "q"):
326+
blank()
327+
success("Session ended.")
328+
break
329+
330+
if not prompt:
331+
continue
332+
333+
encoded_prompt = encode(prompt, tokenizer)
334+
context = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
335+
336+
with torch.no_grad():
337+
output_ids = model.generate(context, max_new_tokens=200)
338+
339+
new_tokens = output_ids[0][len(encoded_prompt):].tolist()
340+
response = decode(new_tokens, tokenizer).strip()
341+
342+
blank()
343+
log(f" Model {ARROW} {response}")
344+
blank()
345+
346+
except KeyboardInterrupt:
347+
blank()
348+
success("Interrupted.")
349+
350+
351+
end = time.time()
352+
wall_clock = end - start
353+
354+
blank()
355+
rule()
356+
row("Training", f"{int(total_time // 60)}m {int(total_time % 60):02d}s")
357+
row("Total", f"{int(wall_clock // 60)}m {int(wall_clock % 60):02d}s", "", TICK)
358+
rule()
359+
blank()
360+
log(f"{DOUBLE}\n")

0 commit comments

Comments
 (0)