Skip to content

Commit 84a9fa5

Browse files
committed
Fix BN backward: apply gamma factor to dX's dbeta/dgamma terms
The canonical BatchNorm backward formula is dX = (gamma * inv_std / N) * (N * dY - dbeta - x_hat * dgamma) The gamma factor must multiply ALL three terms inside the parentheses, not just the N*dY term. Two sites had the same error: 1. PULP_BNGradNormalize_fp32 (split BN backward, second pass) 2. PULP_BatchNormGrad_fp32 (monolithic BN backward) Fix: pull gamma out into the scale factor scale = gamma * inv_std / N_total so that dX = scale * (N * dY - dbeta - x_hat * dgamma) applies gamma uniformly. Impact on MobileNetV1 training (4 steps, random-init): before fix: step 3 loss diff 0.017 (fail) after fix : step 3 loss diff 0.003 (pass at TOL=0.01) The bug was masked at step 0 because gamma is initialized to 1, so gamma × anything = anything. Visible only after the optimizer starts updating gamma. Verification: instrumented PULP_BatchNormGrad_fp32 with a per-call signature print and compared against PyTorch's autograd dgamma/dbeta across all 27 BN layers at step 0 — bit-exact within FP32 rounding (max 1% rel diff on ~1e-8 magnitude grads, <0.1% on all larger grads).
1 parent 27b3bdc commit 84a9fa5

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

TargetLibraries/PULPOpen/src/BatchNorm.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,18 @@ void PULP_BNGradNormalize_fp32(const float32_t *dY, const float32_t *X,
234234
float32_t g = gamma[c];
235235
float32_t dg = dgamma[c];
236236
float32_t db = dbeta[c];
237-
float32_t scale = inv_std * N_total_inv;
237+
/* scale = gamma * inv_std / N_total; gamma applies to all three terms of
238+
the canonical BN backward formula
239+
dX = (g * inv_std / N) * (N*dY - dbeta - x_hat * dgamma) */
240+
float32_t scale = g * inv_std * N_total_inv;
238241

239242
for (uint32_t n = 0; n < N; n++) {
240243
const float32_t *x_nc = X + (n * C + c) * N_hw;
241244
const float32_t *dy_nc = dY + (n * C + c) * N_hw;
242245
float32_t *dx_nc = dX + (n * C + c) * N_hw;
243246
for (uint32_t hw = 0; hw < N_hw; hw++) {
244247
float32_t x_hat = (x_nc[hw] - mean) * inv_std;
245-
float32_t dx_hat = dy_nc[hw] * g;
246-
dx_nc[hw] = scale * (N_total_f * dx_hat - db - x_hat * dg);
248+
dx_nc[hw] = scale * (N_total_f * dy_nc[hw] - db - x_hat * dg);
247249
}
248250
}
249251
}
@@ -288,16 +290,17 @@ void PULP_BatchNormGrad_fp32(const float32_t *dY, const float32_t *X,
288290
dbeta[c] = sum_dbeta;
289291

290292
/* ── Second pass: compute dX ─────────────────────────────────────────── */
291-
float32_t scale = inv_std * inv_N;
293+
/* scale = gamma * inv_std / N_total; gamma applies to all three terms:
294+
dX = (g * inv_std / N) * (N*dY - dbeta - x_hat * dgamma) */
295+
float32_t scale = g * inv_std * inv_N;
292296

293297
for (uint32_t n = 0; n < N; n++) {
294298
const float32_t *x_nc = X + (n * C + c) * N_hw;
295299
const float32_t *dy_nc = dY + (n * C + c) * N_hw;
296300
float32_t *dx_nc = dX + (n * C + c) * N_hw;
297301
for (uint32_t hw = 0; hw < N_hw; hw++) {
298302
float32_t x_hat = (x_nc[hw] - mean) * inv_std;
299-
float32_t dx_hat = dy_nc[hw] * g;
300-
dx_nc[hw] = scale * ((float32_t)N_total * dx_hat - sum_dbeta - x_hat * sum_dgamma);
303+
dx_nc[hw] = scale * ((float32_t)N_total * dy_nc[hw] - sum_dbeta - x_hat * sum_dgamma);
301304
}
302305
}
303306
}

0 commit comments

Comments
 (0)