@@ -302,17 +302,26 @@ def run_comparison():
 # -------------------------------------------------------------------------
 # Helper: Pure Functional wrappers to avoid Flax/JAX Version mismatch issues
 # -------------------------------------------------------------------------
-def create_jitted_train_step(model):
+def create_jitted_train_step(model, input_shape):
     graphdef, params = nnx.split(model)
 
+    # 1. Create a static, deterministic projection tensor.
+    #    We generate this outside the JIT so it is a frozen constant in the graph.
+    proj_key = jax.random.PRNGKey(99)
+    projection = jax.random.normal(proj_key, input_shape)
+
     @jax.jit
     def pure_train_step(params, x):
         m = nnx.merge(graphdef, params)
         def loss_fn(m_inner):
             y = m_inner(x)
-            return jnp.mean(y)
-        loss, grads = nnx.value_and_grad(loss_fn)(m)
-        return loss, grads
+            # 2. Position-aware loss:
+            #    every element is scaled by a unique, random value before averaging.
+            loss = jnp.mean(y * projection.astype(y.dtype))
+            return loss, y
+
+        (loss, y), grads = nnx.value_and_grad(loss_fn, has_aux=True)(m)
+        return loss, y, grads
 
     return pure_train_step, params
 
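The hunk above replaces the plain `jnp.mean(y)` loss with a projection-weighted mean so the check is sensitive to *where* values land, not just to their average. A minimal standalone sketch of that idea (all names below are illustrative, not part of the patch):

```python
import jax
import jax.numpy as jnp

# Two outputs holding the same values in different positions: a plain mean cannot
# tell them apart, but a fixed random projection makes the loss position-aware.
key = jax.random.PRNGKey(0)
y_a = jax.random.normal(key, (4, 8))
y_b = y_a[:, ::-1]                                   # same elements, permuted columns

projection = jax.random.normal(jax.random.PRNGKey(99), y_a.shape)

print(jnp.mean(y_a) - jnp.mean(y_b))                 # ~0: the mismatch is invisible
print(jnp.mean(y_a * projection) - jnp.mean(y_b * projection))  # clearly non-zero
```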
@@ -331,19 +340,26 @@ def pure_forward(params, x):
 # ==============================================================================
 print("\n--- Checking Logical Correctness ---")
 
-# Create safe functional wrappers
-jit_train_base, params_base = create_jitted_train_step(baseline_model)
-jit_train_opt, params_opt = create_jitted_train_step(optimized_model)
+# Pass the input shape so the random projection matches the output dimensions
+jit_train_base, params_base = create_jitted_train_step(baseline_model, inputs.shape)
+jit_train_opt, params_opt = create_jitted_train_step(optimized_model, inputs.shape)
 
-loss_base, grads_base = jit_train_base(params_base, inputs)
-jax.block_until_ready((loss_base, grads_base))
+# Unpack loss, the raw output tensor, and gradients
+loss_base, out_base, grads_base = jit_train_base(params_base, inputs)
+jax.block_until_ready((loss_base, out_base, grads_base))
 
-loss_opt, grads_opt = jit_train_opt(params_opt, inputs)
-jax.block_until_ready((loss_opt, grads_opt))
+loss_opt, out_opt, grads_opt = jit_train_opt(params_opt, inputs)
+jax.block_until_ready((loss_opt, out_opt, grads_opt))
 
+# 1. Compare the Forward Pass Output Tensors (Element-by-Element)
+max_out_diff = float(jnp.max(jnp.abs(out_base - out_opt)))
+print(f"Forward Pass Max Output Diff: {max_out_diff:.2e}")
+
+# 2. Compare the Loss
 diff_loss = jnp.abs(loss_base - loss_opt)
-print(f"Forward Pass Loss Diff: {float(diff_loss):.2e}")
+print(f"Loss Scalar Diff: {float(diff_loss):.2e}")
 
+# 3. Compare the Gradients (Element-by-Element)
 flat_grads_base, _ = jax.tree_util.tree_flatten(grads_base)
 flat_grads_opt, _ = jax.tree_util.tree_flatten(grads_opt)
 
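As a side note on the flatten-and-loop comparison that follows, an equivalent pytree-level version is sketched below (illustrative only, not part of the patch; `max_tree_diff` is a hypothetical helper):

```python
import jax
import jax.numpy as jnp

def max_tree_diff(tree_a, tree_b):
    # Per-leaf maximum absolute difference, then the maximum over all leaves.
    diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)), tree_a, tree_b)
    return max(float(d) for d in jax.tree_util.tree_leaves(diffs))

# Usage: max_tree_diff(grads_base, grads_opt) should agree with the max_grad_diff
# computed by the explicit loop in the next hunk.
```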
@@ -353,10 +369,13 @@ def pure_forward(params, x):
     d = jnp.max(jnp.abs(g1 - g2))
     max_grad_diff = max(max_grad_diff, float(d))
 
-print(f"Backward Pass Grad Diff: {max_grad_diff:.2e}")
+print(f"Backward Pass Max Grad Diff: {max_grad_diff:.2e}")
+
+# Define a strict tolerance
+TOLERANCE = 1e-3 if DTYPE == jnp.bfloat16 else 1e-5
 
-if max_grad_diff > 1e-2:
-    print("WARNING: Significant divergence in gradients!")
+if max_out_diff > TOLERANCE or max_grad_diff > TOLERANCE:
+    print("❌ WARNING: Significant divergence detected!")
 else:
     print("✅ Outputs & Gradients match within tolerance.")
 
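The tolerance split in the last hunk (1e-3 for bfloat16, 1e-5 otherwise) follows from the much coarser precision of bfloat16. A quick, illustrative check of the two machine epsilons (assuming `DTYPE` is the script's global dtype flag):

```python
import jax.numpy as jnp

# bfloat16 keeps only ~8 bits of mantissa, so element-wise differences around 1e-3
# between two numerically equivalent implementations are expected rounding noise.
print(jnp.finfo(jnp.bfloat16).eps)   # 0.0078125
print(jnp.finfo(jnp.float32).eps)    # ~1.19e-07
```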