Skip to content

Commit a0f54ba

Browse files
authored
fix: To add clamp for torch.log(...) to cope with TF32 (#577)
This PR tries to fix [[B] 6250652](https://nvbugspro.nvidia.com/bug/6250652) . Thank you for looking at this PR. --------- Signed-off-by: Kaiqi Yan <kaiqiy@nvidia.com>
1 parent 1808cef commit a0f54ba

2 files changed

Lines changed: 70 additions & 30 deletions

File tree

libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@
4646
def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]:
4747
"""Validate noise priors and clamp them into ``[eps, 1 - eps]``.
4848
49-
The fused cross-entropy reduction in
50-
:meth:`NMOptimizer.cross_entropy_loss` has no ``log`` guard, so a
51-
prior of exactly ``0.0`` or ``1.0`` makes the contraction emit a
52-
zero whose log is ``-inf`` and whose gradient is ``NaN``; training
53-
silently diverges. Stim DEMs occasionally emit ``p=1.0``
54-
(deterministic detectors) or ``p<1e-15`` (underflow), so we
55-
intercept here rather than force every caller to clamp.
49+
The cross-entropy reduction floors log inputs so roundoff-induced
50+
zero or negative values do not create non-finite losses. Priors at
51+
exactly ``0.0`` or ``1.0`` are still clamped because they can
52+
saturate loss terms and make gradients uninformative. Stim DEMs
53+
occasionally emit ``p=1.0`` (deterministic detectors) or ``p<1e-15``
54+
(underflow), so we intercept here rather than force every caller to
55+
clamp.
5656
5757
Behaviour mirrors :class:`torch.nn.BCELoss`-style stable wrappers:
5858
@@ -92,15 +92,20 @@ def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]:
9292
warnings.warn(
9393
f"Clamped {int(out_of_range.sum())}/{len(arr)} NMOptimizer "
9494
f"priors into [{eps}, {1.0 - eps}] for numerical stability; "
95-
f"values at or outside the (0, 1) boundary produce -inf "
96-
f"cross-entropy loss and NaN gradients in the fused codegen.",
95+
f"values at or outside the (0, 1) boundary can saturate "
96+
f"cross-entropy terms and make gradients uninformative.",
9797
UserWarning,
9898
stacklevel=3,
9999
)
100100
arr = np.clip(arr, eps, 1.0 - eps)
101101
return arr.tolist()
102102

103103

104+
def _clamp_log_input(x: torch.Tensor) -> torch.Tensor:
105+
"""Floor log inputs after roundoff-induced non-positive values."""
106+
return x.clamp_min(torch.finfo(x.dtype).tiny)
107+
108+
104109
def remap_eq_to_ascii(eq: str) -> str:
105110
"""Rewrite an einsum equation so every label is in ``[a-zA-Z]``.
106111
@@ -158,8 +163,9 @@ class NMOptimizer(TensorNetworkDecoder):
158163
159164
Priors are clamped into ``[eps, 1 - eps]`` only at construction;
160165
an unconstrained optimiser step on :attr:`noise_params` can push
161-
them past the boundary, after which :meth:`cross_entropy_loss`
162-
returns ``NaN`` gradients. Prefer logit-space training via
166+
them outside the probability interval. The loss is floored for
167+
finiteness, but probability-space training can then saturate or
168+
optimise invalid probabilities. Prefer logit-space training via
163169
:func:`make_compiled_step` (shown below), or clamp the tensor
164170
under :func:`torch.no_grad` after each step.
165171
@@ -370,9 +376,10 @@ def noise_params(self) -> list[torch.Tensor]:
370376
"""Trainable noise probabilities, ready for ``torch.optim``.
371377
372378
Clamped to ``[eps, 1 - eps]`` only at construction; an
373-
unconstrained step can push past the boundary and produce
374-
``NaN`` gradients on the next :meth:`cross_entropy_loss`.
375-
See the class warning for safe training patterns.
379+
unconstrained step can push outside the probability interval.
380+
The next :meth:`cross_entropy_loss` remains finite, but training
381+
can saturate or optimise invalid probabilities. See the class
382+
warning for safe training patterns.
376383
"""
377384
return [self._noise_probs]
378385

@@ -665,24 +672,24 @@ def _build_loss_wrapped(self):
665672

666673
def _loss_from_probs(noise_probs, syndromes):
667674
p = predict_fn(noise_probs, syndromes)
668-
return (-torch.log(p[obs_t, 1]).sum() -
669-
torch.log(p[obs_f, 0]).sum())
675+
return (-torch.log(_clamp_log_input(p[obs_t, 1])).sum() -
676+
torch.log(_clamp_log_input(p[obs_f, 0])).sum())
670677

671678
def _loss_from_logits(logits, syndromes):
672679
p = predict_fn(torch.sigmoid(logits), syndromes)
673-
return (-torch.log(p[obs_t, 1]).sum() -
674-
torch.log(p[obs_f, 0]).sum())
680+
return (-torch.log(_clamp_log_input(p[obs_t, 1])).sum() -
681+
torch.log(_clamp_log_input(p[obs_f, 0])).sum())
675682
else:
676683

677684
def _loss_from_probs(noise_probs, syndromes=()):
678685
p = predict_fn(noise_probs, ())
679-
return (-torch.log(p[obs_t, 1]).sum() -
680-
torch.log(p[obs_f, 0]).sum())
686+
return (-torch.log(_clamp_log_input(p[obs_t, 1])).sum() -
687+
torch.log(_clamp_log_input(p[obs_f, 0])).sum())
681688

682689
def _loss_from_logits(logits, syndromes=()):
683690
p = predict_fn(torch.sigmoid(logits), ())
684-
return (-torch.log(p[obs_t, 1]).sum() -
685-
torch.log(p[obs_f, 0]).sum())
691+
return (-torch.log(_clamp_log_input(p[obs_t, 1])).sum() -
692+
torch.log(_clamp_log_input(p[obs_f, 0])).sum())
686693

687694
return _loss_from_logits, _loss_from_probs
688695

@@ -947,8 +954,10 @@ def _build_codegen_loss(cls,
947954
normed = final_value / final_value.sum(dim=1, keepdim=True)
948955
# Compute the loss eagerly; we can't fold it because
949956
# autograd needs a path back to noise_probs.
950-
ce = (-torch.log(normed[obs_idx_true, 1]).sum() -
951-
torch.log(normed[obs_idx_false, 0]).sum())
957+
ce = (
958+
-torch.log(_clamp_log_input(normed[obs_idx_true, 1])).sum()
959+
-
960+
torch.log(_clamp_log_input(normed[obs_idx_false, 0])).sum())
952961
closure_vars["_LOSS"] = ce
953962
body.append(" return _LOSS + 0.0 * noise_probs.sum()")
954963
runtime_lines = []
@@ -972,9 +981,11 @@ def _build_codegen_loss(cls,
972981
body.append(f" _out = {final_name}")
973982
body.append(" _z0 = _out[:, 0]")
974983
body.append(" _z1 = _out[:, 1]")
975-
body.append(" return (torch.log(_z0 + _z1).sum() "
976-
"- torch.log(_z1[_OBS_T]).sum() "
977-
"- torch.log(_z0[_OBS_F]).sum())")
984+
body.append(" _eps = torch.finfo(_z0.dtype).tiny")
985+
body.append(
986+
" return (torch.log((_z0 + _z1).clamp_min(_eps)).sum() "
987+
"- torch.log(_z1[_OBS_T].clamp_min(_eps)).sum() "
988+
"- torch.log(_z0[_OBS_F].clamp_min(_eps)).sum())")
978989

979990
return cls._compile_codegen_source(body, closure_vars, n_folded,
980991
len(runtime_lines), "loss")
@@ -1002,9 +1013,9 @@ def cross_entropy_loss(self) -> torch.Tensor:
10021013
"""Cross-entropy loss over the syndrome batch.
10031014
10041015
Returns a differentiable scalar; call ``.backward()`` to obtain
1005-
gradients w.r.t. :attr:`noise_params`. The fused codegen omits
1006-
the ``log`` guard, so a prior at ``0`` or ``1`` yields ``NaN``
1007-
gradients — see :attr:`noise_params` for safe training patterns.
1016+
gradients w.r.t. :attr:`noise_params`. Log inputs are floored to
1017+
avoid non-finite values from roundoff; use the safe training
1018+
patterns in :attr:`noise_params` to keep probabilities in range.
10081019
"""
10091020
return self._compiled_loss_from_probs(self._noise_probs,
10101021
self._syndrome_tuple)

libs/qec/python/tests/test_nm_optimizer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,35 @@ def test_boundary_priors_clamped_with_warning(device, dtype):
385385
assert torch.isfinite(loss)
386386

387387

388+
@pytest.mark.skipif(not _gpu_available(), reason="CUDA not available")
389+
def test_boundary_priors_finite_with_tf32_matmul_enabled():
390+
"""Regression: global TF32 matmul must not make boundary-prior loss NaN."""
391+
old_precision = torch.get_float32_matmul_precision()
392+
try:
393+
# Mirror solver imports that enable TF32 process-wide during full
394+
# pytest collection; the QEC test must be self-contained.
395+
torch.set_float32_matmul_precision("high")
396+
H, logical, _ = _simple_repetition_code()
397+
boundary_priors = [0.0, 0.5, 1.0]
398+
syn, flips = _sample_synthetic_dataset(H,
399+
logical, [0.1, 0.2, 0.3],
400+
num_shots=8,
401+
rng=np.random.default_rng(20))
402+
with pytest.warns(UserWarning, match=r"Clamped \d+/\d+"):
403+
opt = _make_opt(H,
404+
logical,
405+
boundary_priors,
406+
syn,
407+
flips,
408+
device="cuda",
409+
dtype="float32")
410+
411+
loss = opt.cross_entropy_loss()
412+
assert torch.isfinite(loss)
413+
finally:
414+
torch.set_float32_matmul_precision(old_precision)
415+
416+
388417
@pytest.mark.parametrize("device", _device_params())
389418
def test_non_finite_priors_raise(device):
390419
"""Non-finite priors are caller bugs, not stability concerns - raise."""

0 commit comments

Comments
 (0)