Skip to content

Commit 65b36e7

Browse files
committed
refactor: implement dedicated chat interface for .pt model weights
Refactored model loading logic to utilize exported .pt files. Improved weight initialization and device mapping (CPU/CUDA).
1 parent 1dfbbd5 commit 65b36e7

1 file changed

Lines changed: 286 additions & 0 deletions

File tree

engine/inference.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
import argparse
2+
from pathlib import Path
3+
import time
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.nn import functional as F
8+
import tiktoken
9+
10+
11+
# --- Console layout ---------------------------------------------------------
# Width, in characters, of the banner line and the horizontal rules.
W = 78
DOUBLE = W * "="  # heavy rule
SINGLE = W * "-"  # light rule
ARROW = "->"  # separator between a speaker label and its text

# --- Model hyperparameters --------------------------------------------------
# NOTE(review): these presumably have to match the values used when the
# checkpoint was trained, because the state dict is loaded into a model built
# from them -- confirm against the training script.
block_size = 32  # number of rows in the position embedding / causal mask size
n_embd = 64  # embedding width
n_head = 4  # attention heads per transformer block
n_layer = 4  # number of transformer blocks
dropout = 0.1  # probability passed to every nn.Dropout in the model

# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
23+
24+
25+
def header(title, subtitle=""):
    """Print a banner framed by two heavy rules.

    The output is an empty line, the ``DOUBLE`` rule, the title, the subtitle
    (only when it is non-empty) and a closing ``DOUBLE`` rule. Title and
    subtitle are each indented by one space.
    """
    lines = ["", DOUBLE, f" {title}"]
    if subtitle:
        lines.append(f" {subtitle}")
    lines.append(DOUBLE)
    print("\n".join(lines))
31+
32+
33+
def row(label, value="", unit="", note=""):
    """Print one aligned ``label  value  unit  note`` line.

    Args:
        label: Left-aligned in a 28-character column after a one-space indent.
        value: Converted with ``str()`` and left-aligned in 20 characters.
        unit: Left-aligned in 8 characters.
        note: Appended after a single space; omitted entirely when empty.
    """
    cells = [
        " " + format(label, "<28"),
        format(str(value), "<20"),
        format(unit, "<8"),
    ]
    if note:
        cells.append(f" {note}")
    print("".join(cells))
39+
40+
41+
def rule():
    """Print the light ``SINGLE`` rule, indented by one space."""
    print(" " + SINGLE)
43+
44+
45+
def blank():
    """Print a single empty line."""
    print("")
47+
48+
49+
def get_tokenizer(encoding_name="gpt2"):
    """Look up a tiktoken encoding by name.

    Returns:
        A ``(tokenizer, vocab_size)`` pair, where ``vocab_size`` is the
        encoding's ``n_vocab`` attribute.
    """
    enc = tiktoken.get_encoding(encoding_name)
    vocab = enc.n_vocab
    return enc, vocab
52+
53+
54+
def encode(text, tokenizer):
    """Convert ``text`` to a list of token ids using ``tokenizer``.

    Special-token markers that appear literally in ``text`` (for example
    ``"<|endoftext|>"``) are encoded as ordinary text. With tiktoken's default
    settings such input raises ``ValueError``, which would abort an
    interactive session on arbitrary user input.

    Args:
        text: The string to tokenize.
        tokenizer: A tiktoken-style encoding whose ``encode`` accepts the
            ``disallowed_special`` keyword.

    Returns:
        Whatever ``tokenizer.encode`` returns (a list of ints for tiktoken).
    """
    # An empty ``disallowed_special`` turns off the special-token check, so
    # marker strings are tokenized like any other characters.
    return tokenizer.encode(text, disallowed_special=())
56+
57+
58+
def decode(tokens, tokenizer):
    """Turn a sequence of token ids back into text via ``tokenizer.decode``."""
    text = tokenizer.decode(tokens)
    return text
60+
61+
62+
# Module-level GPT-2 tokenizer shared by the generation helpers; vocab_size
# sizes the model's token embedding table and output projection.
tokenizer, vocab_size = get_tokenizer("gpt2")
63+
64+
65+
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Inputs of width ``n_embd`` are projected to keys, queries and values of
    width ``head_size``. A lower-triangular ``tril`` buffer of size
    ``block_size`` x ``block_size`` masks out attention to later positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Layers are created in key / query / value order, as before, so
        # random initialisation draws from the RNG in the same sequence.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attention-weighted values for a 3-D input ``x``."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)

        # Scaled dot-product scores; the scale is head_size ** -0.5.
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Positions after the current one get -inf, i.e. zero weight after
        # the softmax.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)
83+
84+
85+
class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run side by side, then projected to ``n_embd``.

    The per-head outputs are concatenated along the last dimension, passed
    through a linear projection and finally through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x`` and merge the results."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
95+
96+
97+
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as ``n_embd``; the output has the
    same width as the input.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.net(x)
109+
110+
111+
class Block(nn.Module):
    """Transformer block: pre-LayerNorm attention and MLP, each with a residual.

    Each attention head is ``n_embd // n_head`` wide.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Add the attention output, then the feed-forward output, to ``x``."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
124+
125+
126+
class GPTLanguageModel(nn.Module):
    """Small decoder-only transformer language model.

    The architecture is fixed by the module-level hyperparameters
    (``vocab_size``, ``n_embd``, ``block_size``, ``n_head``, ``n_layer``):
    token and position embeddings, ``n_layer`` ``Block`` modules, a final
    LayerNorm and a linear head that produces one logit per vocabulary entry.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the loss.

        Args:
            idx: 2-D tensor of token ids, unpacked as ``(B, T)``.
            targets: Optional tensor of target ids holding ``B * T`` elements.

        Returns:
            ``(logits, loss)``. Without ``targets``, ``logits`` is
            ``(B, T, C)`` and ``loss`` is ``None``. With ``targets``,
            ``logits`` is flattened to ``(B * T, C)`` and ``loss`` is the
            cross-entropy between the flattened logits and targets.
        """
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Create the position indices on the input's own device rather than
        # the module-level default, so the model still works after it has
        # been moved elsewhere with ``.to(...)``.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: 2-D tensor of token ids used as the starting context.
            max_new_tokens: Number of tokens to sample.
            temperature: Divisor applied to the last-position logits; values
                below ``1e-6`` are clamped to ``1e-6`` to avoid dividing by
                zero.
            top_k: When not ``None``, logits smaller than the k-th largest
                are set to ``-inf`` before sampling. Values larger than the
                vocabulary are capped to the vocabulary size.

        Returns:
            ``idx`` with the sampled tokens concatenated along dimension 1.

        Raises:
            ValueError: If ``top_k`` is given but smaller than 1.
        """
        # A non-positive k used to fail deep inside the loop with an opaque
        # indexing error; reject it up front with a clear message instead.
        if top_k is not None and top_k < 1:
            raise ValueError(
                f"top_k must be a positive integer or None, got {top_k}"
            )

        for _ in range(max_new_tokens):
            # The position embedding only has block_size rows, so feed at
            # most the last block_size tokens.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)

            if top_k is not None:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # values[:, [-1]] is the k-th largest logit of each row.
                logits[logits < values[:, [-1]]] = float("-inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
167+
168+
169+
def default_checkpoint_path():
    """Pick the default location of ``best_model.pt``.

    Checks, in order, the directory containing this file, the current working
    directory, and ``engine/`` under the current working directory. The first
    path that exists is returned; if none exists, the path next to this file
    is returned even though it is missing.
    """
    here = Path(__file__).resolve().parent
    cwd = Path.cwd()
    search_order = (
        here / "best_model.pt",
        cwd / "best_model.pt",
        cwd / "engine" / "best_model.pt",
    )
    # The fallback is the first search entry (the path next to this file).
    return next((p for p in search_order if p.exists()), search_order[0])
180+
181+
182+
def load_model(checkpoint_path):
    """Build a ``GPTLanguageModel`` and load weights from ``checkpoint_path``.

    Args:
        checkpoint_path: Path (``str`` or ``Path``) to a file holding a state
            dict compatible with ``GPTLanguageModel``.

    Returns:
        The model, placed on the module-level ``device`` and switched to
        evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path}\n"
            "Train first with engine/main.py, or pass --checkpoint path/to/best_model.pt"
        )

    model = GPTLanguageModel().to(device)
    # The path can come straight from the command line. weights_only=True
    # restricts unpickling to tensors and plain containers, which is all a
    # state dict needs, so a tampered file cannot execute arbitrary code.
    state_dict = torch.load(
        checkpoint_path, map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
195+
196+
197+
def generate_response(model, prompt, max_new_tokens, temperature, top_k):
    """Generate a continuation of ``prompt`` and return it as stripped text.

    The prompt is tokenized with the module-level ``tokenizer``, wrapped in a
    batch of one on ``device`` and passed to ``model.generate`` with gradient
    tracking disabled. Only the tokens after the prompt are decoded.
    """
    prompt_ids = encode(prompt, tokenizer)
    prompt_len = len(prompt_ids)
    batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    sampling = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
    }
    with torch.no_grad():
        generated = model.generate(batch, **sampling)

    # Drop the echoed prompt; keep only the newly sampled ids.
    completion_ids = generated[0][prompt_len:].tolist()
    completion = decode(completion_ids, tokenizer)
    return completion.strip()
211+
212+
213+
def chat(model, args):
    """Run an interactive prompt/response loop on the terminal.

    Each non-empty line read from standard input is passed to
    ``generate_response`` and the reply is printed. The session ends when the
    user types ``quit``, ``exit`` or ``q`` (case-insensitive), when standard
    input is closed (Ctrl-D), or when the user presses Ctrl-C at the prompt.

    Args:
        model: Model handed to ``generate_response``.
        args: Object providing ``max_new_tokens``, ``temperature`` and
            ``top_k`` attributes.
    """
    header("INFERENCE", "quit / exit / q -> end session")
    blank()

    while True:
        try:
            prompt = input(f" user {ARROW} ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: end the session the same
            # way as an explicit quit command instead of dying with a
            # traceback.
            prompt = None

        if prompt is None or prompt.lower() in ("quit", "exit", "q"):
            blank()
            print(" Session ended.")
            break
        if not prompt:
            continue

        response = generate_response(
            model,
            prompt,
            args.max_new_tokens,
            args.temperature,
            args.top_k,
        )
        blank()
        print(f" Model {ARROW} {response}")
        blank()
236+
237+
238+
def parse_args():
    """Parse the command-line options of the inference script.

    Options: ``--checkpoint`` (a ``Path``, defaulting to
    ``default_checkpoint_path()``), ``--prompt`` (one-shot prompt, default
    ``None``), ``--max-new-tokens`` (int, 200), ``--temperature`` (float, 1.0)
    and ``--top-k`` (int, default ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Run inference from an engine trained .pt checkpoint."
    )

    checkpoint_default = default_checkpoint_path()
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=checkpoint_default,
        help="Path to the .pt file generated by engine/main.py.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Generate once from this prompt.",
    )

    # Plain sampling knobs: (flag, type, default), registered in this order.
    sampling_options = (
        ("--max-new-tokens", int, 200),
        ("--temperature", float, 1.0),
        ("--top-k", int, None),
    )
    for flag, kind, fallback in sampling_options:
        parser.add_argument(flag, type=kind, default=fallback)

    return parser.parse_args()
251+
252+
253+
def main():
    """Entry point: print a status banner, load the model, then generate.

    With a non-empty ``--prompt`` a single response is printed; otherwise the
    interactive ``chat`` loop runs. The elapsed wall-clock time since argument
    parsing is reported at the end.
    """
    args = parse_args()
    started_at = time.time()

    print(format("Quadtrix-v1.0", f"^{W}"))
    blank()
    status = (
        ("Started", time.strftime("%Y-%m-%d %H:%M:%S")),
        ("Device", str(device)),
        ("PyTorch", torch.__version__),
        ("Checkpoint", args.checkpoint),
    )
    for label, value in status:
        row(label, value)
    rule()

    model = load_model(args.checkpoint)

    if not args.prompt:
        chat(model, args)
    else:
        response = generate_response(
            model,
            args.prompt,
            args.max_new_tokens,
            args.temperature,
            args.top_k,
        )
        blank()
        print(response)

    blank()
    elapsed = time.time() - started_at
    row("Total", f"{elapsed:.2f}s")
    print(DOUBLE)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)