Skip to content

Commit cd675d2

Browse files
committed
refactor(main.cpp): clean up code and add progress tracking removed redundant code and improved readability.
1 parent 9ba7a91 commit cd675d2

1 file changed

Lines changed: 39 additions & 47 deletions

File tree

main.cpp

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,6 @@ static float estimate_loss(GPTLanguageModel &model,
122122
}
123123

124124
// Chat window
125-
// Encodes a user prompt string into token indices using the
126-
// DataLoader's vocabulary, then feeds them into model.generate().
127-
// Only touches the public encode/decode/generate interface —
128-
// zero changes to math or training logic.
129125
static void run_chat(GPTLanguageModel &model,
130126
DataLoader &dl,
131127
int max_new_tokens)
@@ -139,19 +135,17 @@ static void run_chat(GPTLanguageModel &model,
139135

140136
while (!g_interrupted)
141137
{
142-
// input
143138
std::cout << "\033[1;32mYou>\033[0m ";
144139
std::cout.flush();
145140

146141
std::string prompt;
147142
if (!std::getline(std::cin, prompt))
148-
break; // EOF (piped input ended)
143+
break;
149144

150-
// Trim leading/trailing whitespace
151145
size_t s = prompt.find_first_not_of(" \t\r\n");
152146
size_t e = prompt.find_last_not_of(" \t\r\n");
153147
if (s == std::string::npos)
154-
continue; // blank line — ask again
148+
continue;
155149
prompt = prompt.substr(s, e - s + 1);
156150

157151
if (prompt == "quit" || prompt == "exit")
@@ -160,23 +154,15 @@ static void run_chat(GPTLanguageModel &model,
160154
break;
161155
}
162156

163-
// Encode prompt
164-
// dl.encode() maps the raw string through the same
165-
// char-level (or BPE) vocab built during data loading.
166157
std::vector<int> ctx = dl.encode(prompt);
167158
if (ctx.empty())
168159
{
169-
// If the vocab doesn't cover some characters,
170-
// fall back to the BOS token so generation can
171-
// still start rather than crashing.
172160
ctx = {0};
173161
}
174162

175-
// Clamp context to BLOCK_SIZE (model's max sequence length)
176163
if ((int)ctx.size() > BLOCK_SIZE)
177164
ctx = std::vector<int>(ctx.end() - BLOCK_SIZE, ctx.end());
178165

179-
// Stream model response token-by-token
180166
std::cout << "\033[1;36mQuadtrix>\033[0m ";
181167
std::cout.flush();
182168

@@ -185,7 +171,6 @@ static void run_chat(GPTLanguageModel &model,
185171
ctx = model.generate(ctx, 1);
186172
std::cout << dl.decode({ctx.back()}) << std::flush;
187173

188-
// Keep context within BLOCK_SIZE window
189174
if ((int)ctx.size() > BLOCK_SIZE)
190175
ctx = std::vector<int>(ctx.end() - BLOCK_SIZE, ctx.end());
191176
}
@@ -215,17 +200,17 @@ int main(int argc, char *argv[])
215200
model_path = env_model_path;
216201

217202
bool gen_mode = false;
218-
bool chat_mode = false; // ← NEW flag
219-
int chat_tokens = 200; // default tokens per reply
203+
bool chat_mode = false;
204+
int chat_tokens = 200;
220205

221206
for (int i = 1; i < argc; ++i)
222207
{
223208
std::string a = argv[i];
224209
if (a == "--generate")
225210
gen_mode = true;
226-
else if (a == "--chat") // ← NEW
211+
else if (a == "--chat")
227212
chat_mode = true;
228-
else if (a == "--chat-tokens" && i + 1 < argc) // ← NEW (optional)
213+
else if (a == "--chat-tokens" && i + 1 < argc)
229214
chat_tokens = std::atoi(argv[++i]);
230215
else
231216
data_path = a;
@@ -271,7 +256,7 @@ int main(int argc, char *argv[])
271256
<< N_EMBD << " embedding dim\n";
272257

273258
// chat mode
274-
if (chat_mode) // NEW block
259+
if (chat_mode)
275260
{
276261
if (!file_exists(model_path))
277262
{
@@ -336,48 +321,55 @@ int main(int argc, char *argv[])
336321

337322
float best_val_loss = 1e30f;
338323
double train_start = wall_secs();
324+
double last_eval_time = train_start; // ← tracks time of previous eval
339325

340326
for (int iter = 0; iter <= MAX_ITERS && !g_interrupted; ++iter)
341327
{
342328

343329
// Periodic eval checkpoint
344330
if (iter % EVAL_INTERVAL == 0 || iter == MAX_ITERS)
345331
{
346-
if (iter == 0)
347-
{
348-
std::cout << "[INFO] Running initial loss estimate (" << EVAL_ITERS
349-
<< " train batches + " << EVAL_ITERS
350-
<< " val batches). This can take a while on CPU...\n";
351-
}
352-
else
353-
{
354-
std::cout << "[INFO] Evaluating checkpoint at iter " << iter
355-
<< "/" << MAX_ITERS << "...\n";
356-
}
357-
std::cout.flush();
332+
double now = wall_secs();
333+
double elapsed = now - train_start;
334+
335+
// ms per training step since the last eval window
336+
double window_secs = now - last_eval_time;
337+
int steps_in_win = (iter == 0) ? 1 : EVAL_INTERVAL;
338+
double ms_per_step = window_secs * 1000.0 / steps_in_win;
339+
340+
// tokens processed per second
341+
long toks_in_win = (long)BATCH_SIZE * BLOCK_SIZE * steps_in_win;
342+
int tok_per_sec = (window_secs > 0.0)
343+
? (int)(toks_in_win / window_secs)
344+
: 0;
345+
346+
last_eval_time = now; // reset window
358347

359348
float tl = estimate_loss(model, dl, "train", rng);
360349
float vl = estimate_loss(model, dl, "val", rng);
361350

362-
double elapsed = wall_secs() - train_start;
363-
double eta = (iter > 0) ? elapsed / iter * (MAX_ITERS - iter) : 0.0;
364-
float pct = 100.0f * iter / MAX_ITERS;
365351
bool better = vl < best_val_loss;
366-
367352
if (better)
368353
{
369354
best_val_loss = vl;
370355
model.save(model_path);
371356
}
372357

373-
std::cout << "[" << std::setw(5) << iter << "/" << MAX_ITERS << "] "
374-
<< std::fixed << std::setprecision(1) << pct << "% "
375-
<< "train=" << std::setprecision(4) << tl
376-
<< " val=" << vl
377-
<< " elapsed=" << std::setprecision(0) << elapsed << "s"
378-
<< " ETA=" << eta << "s"
379-
<< (better ? " << best!" : "")
380-
<< "\n";
358+
// ── new log line ─────────────────────────────────────────────
359+
std::cout
360+
<< "step "
361+
<< std::setw(5) << iter << "/" << MAX_ITERS
362+
<< " | loss "
363+
<< std::fixed << std::setprecision(6) << tl
364+
<< " | val "
365+
<< std::fixed << std::setprecision(6) << vl
366+
<< " | lr "
367+
<< std::scientific << std::setprecision(2) << (float)LEARNING_RATE
368+
<< " | "
369+
<< std::fixed << std::setprecision(2) << ms_per_step << " ms"
370+
<< " | " << tok_per_sec << " tok/s"
371+
<< (better ? " *best*" : "")
372+
<< "\n";
381373
std::cout.flush();
382374

383375
if (iter == MAX_ITERS)
@@ -407,7 +399,7 @@ int main(int argc, char *argv[])
407399
<< std::setprecision(4) << best_val_loss << "\n";
408400
std::cout << "[SAVE] Best weights saved to " << model_path << "\n";
409401

410-
// Continuous generation (mirrors Python's while True loop)
402+
// Continuous generation
411403
std::cout << "\n"
412404
<< std::string(60, '-') << "\n";
413405
std::cout << " MODEL OUTPUT (Ctrl+C to stop)\n";

0 commit comments

Comments
 (0)