@@ -122,10 +122,6 @@ static float estimate_loss(GPTLanguageModel &model,
122122}
123123
124124// Chat window
125- // Encodes a user prompt string into token indices using the
126- // DataLoader's vocabulary, then feeds them into model.generate().
127- // Only touches the public encode/decode/generate interface —
128- // zero changes to math or training logic.
129125static void run_chat (GPTLanguageModel &model,
130126 DataLoader &dl,
131127 int max_new_tokens)
@@ -139,19 +135,17 @@ static void run_chat(GPTLanguageModel &model,
139135
140136 while (!g_interrupted)
141137 {
142- // input
143138 std::cout << " \033 [1;32mYou>\033 [0m " ;
144139 std::cout.flush ();
145140
146141 std::string prompt;
147142 if (!std::getline (std::cin, prompt))
148- break ; // EOF (piped input ended)
143+ break ;
149144
150- // Trim leading/trailing whitespace
151145 size_t s = prompt.find_first_not_of (" \t\r\n " );
152146 size_t e = prompt.find_last_not_of (" \t\r\n " );
153147 if (s == std::string::npos)
154- continue ; // blank line — ask again
148+ continue ;
155149 prompt = prompt.substr (s, e - s + 1 );
156150
157151 if (prompt == " quit" || prompt == " exit" )
@@ -160,23 +154,15 @@ static void run_chat(GPTLanguageModel &model,
160154 break ;
161155 }
162156
163- // Encode prompt
164- // dl.encode() maps the raw string through the same
165- // char-level (or BPE) vocab built during data loading.
166157 std::vector<int > ctx = dl.encode (prompt);
167158 if (ctx.empty ())
168159 {
169- // If the vocab doesn't cover some characters,
170- // fall back to the BOS token so generation can
171- // still start rather than crashing.
172160 ctx = {0 };
173161 }
174162
175- // Clamp context to BLOCK_SIZE (model's max sequence length)
176163 if ((int )ctx.size () > BLOCK_SIZE )
177164 ctx = std::vector<int >(ctx.end () - BLOCK_SIZE , ctx.end ());
178165
179- // Stream model response token-by-token
180166 std::cout << " \033 [1;36mQuadtrix>\033 [0m " ;
181167 std::cout.flush ();
182168
@@ -185,7 +171,6 @@ static void run_chat(GPTLanguageModel &model,
185171 ctx = model.generate (ctx, 1 );
186172 std::cout << dl.decode ({ctx.back ()}) << std::flush;
187173
188- // Keep context within BLOCK_SIZE window
189174 if ((int )ctx.size () > BLOCK_SIZE )
190175 ctx = std::vector<int >(ctx.end () - BLOCK_SIZE , ctx.end ());
191176 }
@@ -215,17 +200,17 @@ int main(int argc, char *argv[])
215200 model_path = env_model_path;
216201
217202 bool gen_mode = false ;
218- bool chat_mode = false ; // ← NEW flag
219- int chat_tokens = 200 ; // default tokens per reply
203+ bool chat_mode = false ;
204+ int chat_tokens = 200 ;
220205
221206 for (int i = 1 ; i < argc; ++i)
222207 {
223208 std::string a = argv[i];
224209 if (a == " --generate" )
225210 gen_mode = true ;
226- else if (a == " --chat" ) // ← NEW
211+ else if (a == " --chat" )
227212 chat_mode = true ;
228- else if (a == " --chat-tokens" && i + 1 < argc) // ← NEW (optional)
213+ else if (a == " --chat-tokens" && i + 1 < argc)
229214 chat_tokens = std::atoi (argv[++i]);
230215 else
231216 data_path = a;
@@ -271,7 +256,7 @@ int main(int argc, char *argv[])
271256 << N_EMBD << " embedding dim\n " ;
272257
273258 // chat mode
274- if (chat_mode) // NEW block
259+ if (chat_mode)
275260 {
276261 if (!file_exists (model_path))
277262 {
@@ -336,48 +321,55 @@ int main(int argc, char *argv[])
336321
337322 float best_val_loss = 1e30f;
338323 double train_start = wall_secs ();
324+ double last_eval_time = train_start; // ← tracks time of previous eval
339325
340326 for (int iter = 0 ; iter <= MAX_ITERS && !g_interrupted; ++iter)
341327 {
342328
343329 // Periodic eval checkpoint
344330 if (iter % EVAL_INTERVAL == 0 || iter == MAX_ITERS )
345331 {
346- if (iter == 0 )
347- {
348- std::cout << " [INFO] Running initial loss estimate (" << EVAL_ITERS
349- << " train batches + " << EVAL_ITERS
350- << " val batches). This can take a while on CPU...\n " ;
351- }
352- else
353- {
354- std::cout << " [INFO] Evaluating checkpoint at iter " << iter
355- << " /" << MAX_ITERS << " ...\n " ;
356- }
357- std::cout.flush ();
332+ double now = wall_secs ();
333+ double elapsed = now - train_start;
334+
335+ // ms per training step since the last eval window
336+ double window_secs = now - last_eval_time;
337+ int steps_in_win = (iter == 0 ) ? 1 : EVAL_INTERVAL ;
338+ double ms_per_step = window_secs * 1000.0 / steps_in_win;
339+
340+ // tokens processed per second
341+ long toks_in_win = (long )BATCH_SIZE * BLOCK_SIZE * steps_in_win;
342+ int tok_per_sec = (window_secs > 0.0 )
343+ ? (int )(toks_in_win / window_secs)
344+ : 0 ;
345+
346+ last_eval_time = now; // reset window
358347
359348 float tl = estimate_loss (model, dl, " train" , rng);
360349 float vl = estimate_loss (model, dl, " val" , rng);
361350
362- double elapsed = wall_secs () - train_start;
363- double eta = (iter > 0 ) ? elapsed / iter * (MAX_ITERS - iter) : 0.0 ;
364- float pct = 100 .0f * iter / MAX_ITERS ;
365351 bool better = vl < best_val_loss;
366-
367352 if (better)
368353 {
369354 best_val_loss = vl;
370355 model.save (model_path);
371356 }
372357
373- std::cout << " [" << std::setw (5 ) << iter << " /" << MAX_ITERS << " ] "
374- << std::fixed << std::setprecision (1 ) << pct << " % "
375- << " train=" << std::setprecision (4 ) << tl
376- << " val=" << vl
377- << " elapsed=" << std::setprecision (0 ) << elapsed << " s"
378- << " ETA=" << eta << " s"
379- << (better ? " << best!" : " " )
380- << " \n " ;
358+ // ── new log line ─────────────────────────────────────────────
359+ std::cout
360+ << " step "
361+ << std::setw (5 ) << iter << " /" << MAX_ITERS
362+ << " | loss "
363+ << std::fixed << std::setprecision (6 ) << tl
364+ << " | val "
365+ << std::fixed << std::setprecision (6 ) << vl
366+ << " | lr "
367+ << std::scientific << std::setprecision (2 ) << (float )LEARNING_RATE
368+ << " | "
369+ << std::fixed << std::setprecision (2 ) << ms_per_step << " ms"
370+ << " | " << tok_per_sec << " tok/s"
371+ << (better ? " *best*" : " " )
372+ << " \n " ;
381373 std::cout.flush ();
382374
383375 if (iter == MAX_ITERS )
@@ -407,7 +399,7 @@ int main(int argc, char *argv[])
407399 << std::setprecision (4 ) << best_val_loss << " \n " ;
408400 std::cout << " [SAVE] Best weights saved to " << model_path << " \n " ;
409401
410- // Continuous generation (mirrors Python's while True loop)
402+ // Continuous generation
411403 std::cout << " \n "
412404 << std::string (60 , ' -' ) << " \n " ;
413405 std::cout << " MODEL OUTPUT (Ctrl+C to stop)\n " ;
0 commit comments