Skip to content

Commit 8149064

Browse files
committed
refactor(main): redesign training loop to log per-step and sample during evaluation
- Replaced the periodic block evaluation layout with standard, per-step logging metrics (`loss`, `ms`, and `tok/s`). - Shifted initial validation loss calculation out of the iteration cycle to establish a zero-state baseline. - Restructured token streaming so that generations are triggered conditionally inside the training loop post-evaluation windows. - Streamlined architecture parameter reporting and consolidated command-line configuration visual prints.
1 parent 5e05bec commit 8149064

1 file changed

Lines changed: 77 additions & 116 deletions

File tree

main.cpp

Lines changed: 77 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,22 @@ static std::string choose_output_path(const std::string &requested_path,
103103
return exe_relative;
104104
}
105105

106+
// sample N tokens from the model and print them
107+
static void sample_tokens(GPTLanguageModel &model,
108+
DataLoader &dl,
109+
int n_tokens)
110+
{
111+
std::vector<int> ctx = {0};
112+
for (int i = 0; i < n_tokens; ++i)
113+
{
114+
ctx = model.generate(ctx, 1);
115+
std::cout << dl.decode({ctx.back()}) << std::flush;
116+
if ((int)ctx.size() > BLOCK_SIZE)
117+
ctx = std::vector<int>(ctx.end() - BLOCK_SIZE, ctx.end());
118+
}
119+
std::cout << "\n";
120+
}
121+
106122
// estimate loss — no gradients, training=false
107123
static float estimate_loss(GPTLanguageModel &model,
108124
DataLoader &dl,
@@ -184,10 +200,7 @@ int main(int argc, char *argv[])
184200
std::signal(SIGINT, sig_handler);
185201

186202
// Banner
187-
std::cout << std::string(60, '=') << "\n";
188203
std::cout << " Quadtrix v1.0 (C++)\n";
189-
std::cout << std::string(60, '=') << "\n";
190-
std::cout << "\n[INFO] Starting at: " << now_str() << "\n";
191204

192205
std::string data_path = DEFAULT_CLEANED_PATH;
193206
const char *env_data_path = std::getenv(DATA_PATH_ENV_VAR.c_str());
@@ -219,17 +232,6 @@ int main(int argc, char *argv[])
219232
data_path = choose_existing_path(data_path, argv[0]);
220233
model_path = choose_output_path(model_path, argv[0]);
221234

222-
// Config print
223-
std::cout << "\n[CONFIG] Hyperparameters:\n";
224-
std::cout << " batch_size=" << BATCH_SIZE
225-
<< " block_size=" << BLOCK_SIZE << "\n";
226-
std::cout << " max_iters=" << MAX_ITERS
227-
<< " learning_rate=" << LEARNING_RATE << "\n";
228-
std::cout << " n_embd=" << N_EMBD
229-
<< " n_head=" << N_HEAD
230-
<< " n_layer=" << N_LAYER
231-
<< " dropout=" << DROPOUT << "\n";
232-
233235
// Data
234236
DataLoader dl;
235237
try
@@ -247,13 +249,12 @@ int main(int argc, char *argv[])
247249
GPTLanguageModel model(dl.vocab_size, N_EMBD, N_HEAD, N_LAYER, BLOCK_SIZE, SEED);
248250

249251
long n_params = model.num_params();
250-
std::cout << "[MODEL] Parameters : "
251-
<< std::fixed << std::setprecision(2)
252-
<< n_params / 1.0e6f << " M (" << n_params << " total)\n";
253-
std::cout << "[MODEL] Architecture: "
254-
<< N_LAYER << " layers x "
255-
<< N_HEAD << " heads x "
256-
<< N_EMBD << " embedding dim\n";
252+
std::cout << "max_seq_len: " << BLOCK_SIZE << "\n";
253+
std::cout << "vocab_size: " << dl.vocab_size << "\n";
254+
std::cout << "num_layers: " << N_LAYER << "\n";
255+
std::cout << "num_heads: " << N_HEAD << "\n";
256+
std::cout << "channels: " << N_EMBD << "\n";
257+
std::cout << "num_parameters: " << n_params << "\n";
257258

258259
// chat mode
259260
if (chat_mode)
@@ -268,9 +269,8 @@ int main(int argc, char *argv[])
268269
}
269270

270271
model.load(model_path);
271-
std::cout << "[CHAT] Weights loaded from " << model_path << "\n";
272-
std::cout << "[CHAT] Max tokens per reply: " << chat_tokens
273-
<< " (override with --chat-tokens N)\n";
272+
std::cout << "weights: " << model_path << "\n";
273+
std::cout << "max_tokens: " << chat_tokens << "\n";
274274

275275
run_chat(model, dl, chat_tokens);
276276
return 0;
@@ -289,10 +289,7 @@ int main(int argc, char *argv[])
289289
}
290290

291291
model.load(model_path);
292-
std::cout << "\n"
293-
<< std::string(60, '-') << "\n";
294-
std::cout << " Quadtrix OUTPUT (Ctrl+C to stop)\n";
295-
std::cout << std::string(60, '-') << "\n\n";
292+
std::cout << "\ngenerating:\n";
296293
std::vector<int> ctx = {0};
297294
while (!g_interrupted)
298295
{
@@ -301,7 +298,7 @@ int main(int argc, char *argv[])
301298
if ((int)ctx.size() > BLOCK_SIZE)
302299
ctx = std::vector<int>(ctx.end() - BLOCK_SIZE, ctx.end());
303300
}
304-
std::cout << "\n\n[Stopped by user]\n";
301+
std::cout << "\n";
305302
return 0;
306303
}
307304

@@ -312,114 +309,78 @@ int main(int argc, char *argv[])
312309
std::mt19937 rng(SEED);
313310

314311
// training loop
315-
std::cout << "\n"
316-
<< std::string(60, '-') << "\n";
317-
std::cout << " TRAINING ("
318-
<< MAX_ITERS << " iters, eval every "
319-
<< EVAL_INTERVAL << ")\n";
320-
std::cout << std::string(60, '-') << "\n";
321312

322313
float best_val_loss = 1e30f;
314+
float last_val_loss = 0.0f;
323315
double train_start = wall_secs();
324-
double last_eval_time = train_start; // ← tracks time of previous eval
325316

326-
for (int iter = 0; iter <= MAX_ITERS && !g_interrupted; ++iter)
317+
// compute initial val loss before training
327318
{
319+
std::mt19937 init_rng(SEED);
320+
last_val_loss = estimate_loss(model, dl, "val", init_rng);
321+
}
328322

329-
// Periodic eval checkpoint
330-
if (iter % EVAL_INTERVAL == 0 || iter == MAX_ITERS)
331-
{
332-
double now = wall_secs();
333-
double elapsed = now - train_start;
334-
335-
// ms per training step since the last eval window
336-
double window_secs = now - last_eval_time;
337-
int steps_in_win = (iter == 0) ? 1 : EVAL_INTERVAL;
338-
double ms_per_step = window_secs * 1000.0 / steps_in_win;
339-
340-
// tokens processed per second
341-
long toks_in_win = (long)BATCH_SIZE * BLOCK_SIZE * steps_in_win;
342-
int tok_per_sec = (window_secs > 0.0)
343-
? (int)(toks_in_win / window_secs)
344-
: 0;
345-
346-
last_eval_time = now; // reset window
347-
348-
float tl = estimate_loss(model, dl, "train", rng);
349-
float vl = estimate_loss(model, dl, "val", rng);
350-
351-
bool better = vl < best_val_loss;
352-
if (better)
353-
{
354-
best_val_loss = vl;
355-
model.save(model_path);
356-
}
357-
358-
// ── new log line ─────────────────────────────────────────────
359-
std::cout
360-
<< "step "
361-
<< std::setw(5) << iter << "/" << MAX_ITERS
362-
<< " | loss "
363-
<< std::fixed << std::setprecision(6) << tl
364-
<< " | val "
365-
<< std::fixed << std::setprecision(6) << vl
366-
<< " | lr "
367-
<< std::scientific << std::setprecision(2) << (float)LEARNING_RATE
368-
<< " | "
369-
<< std::fixed << std::setprecision(2) << ms_per_step << " ms"
370-
<< " | " << tok_per_sec << " tok/s"
371-
<< (better ? " *best*" : "")
372-
<< "\n";
373-
std::cout.flush();
374-
375-
if (iter == MAX_ITERS)
376-
break;
377-
}
323+
for (int iter = 1; iter <= MAX_ITERS && !g_interrupted; ++iter)
324+
{
325+
double step_start = wall_secs();
378326

379-
// Sample training batch
327+
// train step
380328
std::pair<std::vector<int>, std::vector<int>> batch =
381329
dl.get_batch("train", BATCH_SIZE, BLOCK_SIZE, rng);
382330

383-
// Forward — saves all intermediate activations
384331
SavedForward saved = forward_save(model,
385332
batch.first, BATCH_SIZE, BLOCK_SIZE,
386333
batch.second, /*training=*/true);
387334

388-
// Backward — exact analytical gradients
389-
Grads grads = backward(model, saved);
335+
float batch_loss = model.forward(batch.first, BATCH_SIZE, BLOCK_SIZE,
336+
batch.second, false)
337+
.second;
390338

391-
// AdamW parameter update
339+
Grads grads = backward(model, saved);
392340
apply_grads(model, grads, opt);
393-
}
394341

395-
double total = wall_secs() - train_start;
396-
std::cout << "\n[DONE] Training finished in "
397-
<< std::fixed << std::setprecision(1) << total << "s ("
398-
<< total / 60.0 << " min) | Best val loss: "
399-
<< std::setprecision(4) << best_val_loss << "\n";
400-
std::cout << "[SAVE] Best weights saved to " << model_path << "\n";
342+
double step_ms = (wall_secs() - step_start) * 1000.0;
343+
int tok_per_sec = (step_ms > 0.0)
344+
? (int)((long)BATCH_SIZE * BLOCK_SIZE / (step_ms / 1000.0))
345+
: 0;
401346

402-
// Continuous generation
403-
std::cout << "\n"
404-
<< std::string(60, '-') << "\n";
405-
std::cout << " MODEL OUTPUT (Ctrl+C to stop)\n";
406-
std::cout << std::string(60, '-') << "\n\n";
347+
// every EVAL_INTERVAL steps: compute val, save if best, sample
348+
bool better = false;
349+
if (iter % EVAL_INTERVAL == 0 || iter == MAX_ITERS)
350+
{
351+
last_val_loss = estimate_loss(model, dl, "val", rng);
352+
if (last_val_loss < best_val_loss)
353+
{
354+
best_val_loss = last_val_loss;
355+
model.save(model_path);
356+
better = true;
357+
}
358+
}
407359

408-
model.load(model_path);
409-
model.rng = std::mt19937(SEED + 42);
360+
// print every step
361+
std::cout
362+
<< "step"
363+
<< std::setw(5) << iter << "/" << MAX_ITERS
364+
<< " | loss "
365+
<< std::fixed << std::setprecision(6) << batch_loss
366+
<< " | val "
367+
<< std::fixed << std::setprecision(6) << last_val_loss
368+
<< " | lr "
369+
<< std::scientific << std::setprecision(2) << (float)LEARNING_RATE
370+
<< " | "
371+
<< std::fixed << std::setprecision(2) << step_ms << " ms"
372+
<< " | " << tok_per_sec << " tok/s"
373+
<< (better ? " *best*" : "")
374+
<< "\n";
375+
std::cout.flush();
410376

411-
std::vector<int> ctx = {0};
412-
while (!g_interrupted)
413-
{
414-
ctx = model.generate(ctx, 1);
415-
std::cout << dl.decode({ctx.back()}) << std::flush;
416-
if ((int)ctx.size() > BLOCK_SIZE)
417-
ctx = std::vector<int>(ctx.end() - BLOCK_SIZE, ctx.end());
377+
// sample after every eval window
378+
if (iter % EVAL_INTERVAL == 0 || iter == MAX_ITERS)
379+
{
380+
std::cout << "generating:\n";
381+
sample_tokens(model, dl, iter == MAX_ITERS ? 10000 : 150);
382+
}
418383
}
419384

420-
std::cout << "\n\n[Stopped by user]\n";
421-
std::cout << "[TOTAL] Wall-clock: "
422-
<< std::fixed << std::setprecision(1)
423-
<< (wall_secs() - train_start) << "s\n";
424385
return 0;
425386
}

0 commit comments

Comments
 (0)