4646def _validate_and_clamp_priors (noise_model : Any , dtype : str ) -> list [float ]:
4747 """Validate noise priors and clamp them into ``[eps, 1 - eps]``.
4848
49- The fused cross-entropy reduction in
50- :meth:`NMOptimizer.cross_entropy_loss` has no ``log`` guard, so a
51- prior of exactly ``0.0`` or ``1.0`` makes the contraction emit a
52- zero whose log is ``-inf`` and whose gradient is ``NaN``; training
53- silently diverges. Stim DEMs occasionally emit ``p=1.0``
54- (deterministic detectors) or ``p<1e-15`` ( underflow), so we
55- intercept here rather than force every caller to clamp.
49+ The cross-entropy reduction floors log inputs so roundoff-induced
50+ zero or negative values do not create non-finite losses. Priors at
51+ exactly ``0.0`` or ``1.0`` are still clamped because they can
52+ saturate loss terms and make gradients uninformative. Stim DEMs
53+ occasionally emit ``p=1.0`` (deterministic detectors) or ``p<1e-15``
54+ (underflow), so we intercept here rather than force every caller to
55+ clamp.
5656
5757 Behaviour mirrors :class:`torch.nn.BCELoss`-style stable wrappers:
5858
@@ -92,15 +92,20 @@ def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]:
9292 warnings .warn (
9393 f"Clamped { int (out_of_range .sum ())} /{ len (arr )} NMOptimizer "
9494 f"priors into [{ eps } , { 1.0 - eps } ] for numerical stability; "
95- f"values at or outside the (0, 1) boundary produce -inf "
96- f"cross-entropy loss and NaN gradients in the fused codegen ." ,
95+ f"values at or outside the (0, 1) boundary can saturate "
96+ f"cross-entropy terms and make gradients uninformative ." ,
9797 UserWarning ,
9898 stacklevel = 3 ,
9999 )
100100 arr = np .clip (arr , eps , 1.0 - eps )
101101 return arr .tolist ()
102102
103103
104+ def _clamp_log_input (x : torch .Tensor ) -> torch .Tensor :
105+ """Floor log inputs after roundoff-induced non-positive values."""
106+ return x .clamp_min (torch .finfo (x .dtype ).tiny )
107+
108+
104109def remap_eq_to_ascii (eq : str ) -> str :
105110 """Rewrite an einsum equation so every label is in ``[a-zA-Z]``.
106111
@@ -158,8 +163,9 @@ class NMOptimizer(TensorNetworkDecoder):
158163
159164 Priors are clamped into ``[eps, 1 - eps]`` only at construction;
160165 an unconstrained optimiser step on :attr:`noise_params` can push
161- them past the boundary, after which :meth:`cross_entropy_loss`
162- returns ``NaN`` gradients. Prefer logit-space training via
166+ them outside the probability interval. The loss is floored for
167+ finiteness, but probability-space training can then saturate or
168+ optimise invalid probabilities. Prefer logit-space training via
163169 :func:`make_compiled_step` (shown below), or clamp the tensor
164170 under :func:`torch.no_grad` after each step.
165171
@@ -370,9 +376,10 @@ def noise_params(self) -> list[torch.Tensor]:
370376 """Trainable noise probabilities, ready for ``torch.optim``.
371377
372378 Clamped to ``[eps, 1 - eps]`` only at construction; an
373- unconstrained step can push past the boundary and produce
374- ``NaN`` gradients on the next :meth:`cross_entropy_loss`.
375- See the class warning for safe training patterns.
379+ unconstrained step can push outside the probability interval.
380+ The next :meth:`cross_entropy_loss` remains finite, but training
381+ can saturate or optimise invalid probabilities. See the class
382+ warning for safe training patterns.
376383 """
377384 return [self ._noise_probs ]
378385
@@ -665,24 +672,24 @@ def _build_loss_wrapped(self):
665672
666673 def _loss_from_probs (noise_probs , syndromes ):
667674 p = predict_fn (noise_probs , syndromes )
668- return (- torch .log (p [obs_t , 1 ]).sum () -
669- torch .log (p [obs_f , 0 ]).sum ())
675+ return (- torch .log (_clamp_log_input ( p [obs_t , 1 ]) ).sum () -
676+ torch .log (_clamp_log_input ( p [obs_f , 0 ]) ).sum ())
670677
671678 def _loss_from_logits (logits , syndromes ):
672679 p = predict_fn (torch .sigmoid (logits ), syndromes )
673- return (- torch .log (p [obs_t , 1 ]).sum () -
674- torch .log (p [obs_f , 0 ]).sum ())
680+ return (- torch .log (_clamp_log_input ( p [obs_t , 1 ]) ).sum () -
681+ torch .log (_clamp_log_input ( p [obs_f , 0 ]) ).sum ())
675682 else :
676683
677684 def _loss_from_probs (noise_probs , syndromes = ()):
678685 p = predict_fn (noise_probs , ())
679- return (- torch .log (p [obs_t , 1 ]).sum () -
680- torch .log (p [obs_f , 0 ]).sum ())
686+ return (- torch .log (_clamp_log_input ( p [obs_t , 1 ]) ).sum () -
687+ torch .log (_clamp_log_input ( p [obs_f , 0 ]) ).sum ())
681688
682689 def _loss_from_logits (logits , syndromes = ()):
683690 p = predict_fn (torch .sigmoid (logits ), ())
684- return (- torch .log (p [obs_t , 1 ]).sum () -
685- torch .log (p [obs_f , 0 ]).sum ())
691+ return (- torch .log (_clamp_log_input ( p [obs_t , 1 ]) ).sum () -
692+ torch .log (_clamp_log_input ( p [obs_f , 0 ]) ).sum ())
686693
687694 return _loss_from_logits , _loss_from_probs
688695
@@ -947,8 +954,10 @@ def _build_codegen_loss(cls,
947954 normed = final_value / final_value .sum (dim = 1 , keepdim = True )
948955 # Compute the loss eagerly; we can't fold it because
949956 # autograd needs a path back to noise_probs.
950- ce = (- torch .log (normed [obs_idx_true , 1 ]).sum () -
951- torch .log (normed [obs_idx_false , 0 ]).sum ())
957+ ce = (
958+ - torch .log (_clamp_log_input (normed [obs_idx_true , 1 ])).sum ()
959+ -
960+ torch .log (_clamp_log_input (normed [obs_idx_false , 0 ])).sum ())
952961 closure_vars ["_LOSS" ] = ce
953962 body .append (" return _LOSS + 0.0 * noise_probs.sum()" )
954963 runtime_lines = []
@@ -972,9 +981,11 @@ def _build_codegen_loss(cls,
972981 body .append (f" _out = { final_name } " )
973982 body .append (" _z0 = _out[:, 0]" )
974983 body .append (" _z1 = _out[:, 1]" )
975- body .append (" return (torch.log(_z0 + _z1).sum() "
976- "- torch.log(_z1[_OBS_T]).sum() "
977- "- torch.log(_z0[_OBS_F]).sum())" )
984+ body .append (" _eps = torch.finfo(_z0.dtype).tiny" )
985+ body .append (
986+ " return (torch.log((_z0 + _z1).clamp_min(_eps)).sum() "
987+ "- torch.log(_z1[_OBS_T].clamp_min(_eps)).sum() "
988+ "- torch.log(_z0[_OBS_F].clamp_min(_eps)).sum())" )
978989
979990 return cls ._compile_codegen_source (body , closure_vars , n_folded ,
980991 len (runtime_lines ), "loss" )
@@ -1002,9 +1013,9 @@ def cross_entropy_loss(self) -> torch.Tensor:
10021013 """Cross-entropy loss over the syndrome batch.
10031014
10041015 Returns a differentiable scalar; call ``.backward()`` to obtain
1005- gradients w.r.t. :attr:`noise_params`. The fused codegen omits
1006- the ``log`` guard, so a prior at ``0`` or ``1`` yields ``NaN``
1007- gradients — see :attr:`noise_params` for safe training patterns .
1016+ gradients w.r.t. :attr:`noise_params`. Log inputs are floored to
1017+ avoid non-finite values from roundoff; use the safe training
1018+ patterns in :attr:`noise_params` to keep probabilities in range.
10081019 """
10091020 return self ._compiled_loss_from_probs (self ._noise_probs ,
10101021 self ._syndrome_tuple )
0 commit comments