
Commit acc1d56

update benchmarking script
1 parent db5b69e commit acc1d56

1 file changed

Lines changed: 34 additions & 15 deletions

src/maxtext/scratch_code/benchmark_gdn_optimization.py

@@ -302,17 +302,26 @@ def run_comparison():
 # -------------------------------------------------------------------------
 # Helper: Pure Functional wrappers to avoid Flax/JAX Version mismatch issues
 # -------------------------------------------------------------------------
-def create_jitted_train_step(model):
+def create_jitted_train_step(model, input_shape):
     graphdef, params = nnx.split(model)

+    # 1. Create a static, deterministic projection tensor
+    # We generate this outside the JIT so it is a frozen constant in the graph
+    proj_key = jax.random.PRNGKey(99)
+    projection = jax.random.normal(proj_key, input_shape)
+
     @jax.jit
     def pure_train_step(params, x):
         m = nnx.merge(graphdef, params)
         def loss_fn(m_inner):
             y = m_inner(x)
-            return jnp.mean(y)
-        loss, grads = nnx.value_and_grad(loss_fn)(m)
-        return loss, grads
+            # 2. Position-aware loss
+            # Every element is scaled by a unique, random value before averaging
+            loss = jnp.mean(y * projection.astype(y.dtype))
+            return loss, y
+
+        (loss, y), grads = nnx.value_and_grad(loss_fn, has_aux=True)(m)
+        return loss, y, grads

     return pure_train_step, params

@@ -331,19 +340,26 @@ def pure_forward(params, x):
 # ==============================================================================
 print("\n--- Checking Logical Correctness ---")

-# Create safe functional wrappers
-jit_train_base, params_base = create_jitted_train_step(baseline_model)
-jit_train_opt, params_opt = create_jitted_train_step(optimized_model)
+# Pass the input shape so the random projection matches the output dimensions
+jit_train_base, params_base = create_jitted_train_step(baseline_model, inputs.shape)
+jit_train_opt, params_opt = create_jitted_train_step(optimized_model, inputs.shape)

-loss_base, grads_base = jit_train_base(params_base, inputs)
-jax.block_until_ready((loss_base, grads_base))
+# Unpack loss, the raw output tensor, and gradients
+loss_base, out_base, grads_base = jit_train_base(params_base, inputs)
+jax.block_until_ready((loss_base, out_base, grads_base))

-loss_opt, grads_opt = jit_train_opt(params_opt, inputs)
-jax.block_until_ready((loss_opt, grads_opt))
+loss_opt, out_opt, grads_opt = jit_train_opt(params_opt, inputs)
+jax.block_until_ready((loss_opt, out_opt, grads_opt))

+# 1. Compare the Forward Pass Output Tensors (Element-by-Element)
+max_out_diff = float(jnp.max(jnp.abs(out_base - out_opt)))
+print(f"Forward Pass Max Output Diff: {max_out_diff:.2e}")
+
+# 2. Compare the Loss
 diff_loss = jnp.abs(loss_base - loss_opt)
-print(f"Forward Pass Loss Diff: {float(diff_loss):.2e}")
+print(f"Loss Scalar Diff: {float(diff_loss):.2e}")

+# 3. Compare the Gradients (Element-by-Element)
 flat_grads_base, _ = jax.tree_util.tree_flatten(grads_base)
 flat_grads_opt, _ = jax.tree_util.tree_flatten(grads_opt)

@@ -353,10 +369,13 @@ def pure_forward(params, x):
     d = jnp.max(jnp.abs(g1 - g2))
     max_grad_diff = max(max_grad_diff, float(d))

-print(f"Backward Pass Grad Diff: {max_grad_diff:.2e}")
+print(f"Backward Pass Max Grad Diff: {max_grad_diff:.2e}")
+
+# Define a strict tolerance
+TOLERANCE = 1e-3 if DTYPE == jnp.bfloat16 else 1e-5

-if max_grad_diff > 1e-2:
-    print("WARNING: Significant divergence in gradients!")
+if max_out_diff > TOLERANCE or max_grad_diff > TOLERANCE:
+    print("WARNING: Significant divergence detected!")
 else:
     print("✅ Outputs & Gradients match within tolerance.")
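
Not part of the commit, but a minimal standalone sketch of the verification pattern this change introduces: a frozen random projection baked into the jitted graph, a position-aware loss returned with has_aux=True so the raw output tensor comes back alongside the scalar loss, and a dtype-dependent tolerance. A toy nnx.Linear model stands in for the script's actual baseline/optimized GDN models; the model, shapes, and variable names below are illustrative assumptions.

# Illustrative only -- not part of the commit. A toy nnx.Linear stands in for
# the benchmark's baseline/optimized models; BATCH, DIM, and DTYPE are assumed.
import jax
import jax.numpy as jnp
from flax import nnx

BATCH, DIM = 4, 8
DTYPE = jnp.float32

model = nnx.Linear(DIM, DIM, rngs=nnx.Rngs(0))
inputs = jax.random.normal(jax.random.PRNGKey(1), (BATCH, DIM), dtype=DTYPE)

def create_jitted_train_step(model, input_shape):
    graphdef, params = nnx.split(model)

    # Frozen random projection, created outside jit so it is a constant in the graph.
    projection = jax.random.normal(jax.random.PRNGKey(99), input_shape)

    @jax.jit
    def pure_train_step(params, x):
        m = nnx.merge(graphdef, params)

        def loss_fn(m_inner):
            y = m_inner(x)
            # A plain jnp.mean(y) is invariant to permutations of y, so it can
            # hide ordering bugs; weighting each element by a unique random
            # value makes the loss (and its gradients) position-aware.
            loss = jnp.mean(y * projection.astype(y.dtype))
            return loss, y

        (loss, y), grads = nnx.value_and_grad(loss_fn, has_aux=True)(m)
        return loss, y, grads

    return pure_train_step, params

train_step, params = create_jitted_train_step(model, inputs.shape)
loss, out, grads = train_step(params, inputs)
jax.block_until_ready((loss, out, grads))

# In the benchmark, the same step is built for the baseline and optimized
# models and their outputs and gradients are compared element-wise against a
# dtype-dependent tolerance.
TOLERANCE = 1e-3 if DTYPE == jnp.bfloat16 else 1e-5
print(f"loss={float(loss):.4f}, output shape={out.shape}, tolerance={TOLERANCE}")

The has_aux=True flag is what lets the wrapper return the raw output tensor alongside the scalar loss, which in turn makes the element-by-element forward-pass comparison in the hunk above possible.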
