diff --git a/kempnerforge/data/sampler.py b/kempnerforge/data/sampler.py
index 6b7281b..d8e6fc3 100644
--- a/kempnerforge/data/sampler.py
+++ b/kempnerforge/data/sampler.py
@@ -15,6 +15,27 @@
 from torch.utils.data import Dataset, Sampler
 
 
+def _validate_weights(weights: list[float], context: str) -> None:
+    """Fail fast on empty, negative, or all-zero weight lists.
+
+    The two normalization branches in ``MixtureSampler`` disagree on all-zero
+    input: the ``temperature == 1.0`` branch divides by ``sum(weights)`` and
+    raises ``ZeroDivisionError``; the ``temperature != 1.0`` branch clamps via
+    ``max(w, 1e-12)`` and silently produces uniform sampling. Reject both
+    cases up-front with a clear error so misconfigured phase transitions
+    surface immediately instead of crashing mid-run or drifting silently.
+    """
+    if not weights:
+        raise ValueError(f"{context}: weights list is empty")
+    if any(w < 0 for w in weights):
+        raise ValueError(f"{context}: weights must be non-negative, got {weights}")
+    if sum(weights) <= 0:
+        raise ValueError(
+            f"{context}: weights must sum to > 0 (at least one dataset must have "
+            f"weight > 0), got {weights}"
+        )
+
+
 class DistributedSampler(Sampler[int]):
     """Deterministic distributed sampler with skip-ahead support.
 
@@ -159,6 +180,8 @@ def __init__(
         self._dataset_sizes = [cumulative_sizes[i + 1] - cumulative_sizes[i] for i in range(n)]
         self._offsets = list(cumulative_sizes[:n])
 
+        _validate_weights(weights, "MixtureSampler(weights=...)")
+
         # Apply temperature scaling and normalize
         if temperature != 1.0:
             import math as _math
@@ -285,6 +308,8 @@ def update_weights(self, weights: list[float], temperature: float = 1.0) -> None
         if len(weights) != n:
             raise ValueError(f"Expected {n} weights, got {len(weights)}")
 
+        _validate_weights(weights, "MixtureSampler.update_weights")
+
         # Apply temperature scaling and normalize (same logic as __init__)
         if temperature != 1.0:
             import math as _math
diff --git a/tests/unit/test_mixing.py b/tests/unit/test_mixing.py
index 98fb677..b9e12f8 100644
--- a/tests/unit/test_mixing.py
+++ b/tests/unit/test_mixing.py
@@ -440,3 +440,100 @@ def test_dataset_idx_matches_sampler_intent(self, two_mmap_dirs):
         # All indices in ds_b range → dataset_idx=1
         for i in range(len(ds_a), len(mix)):
             assert mix[i]["dataset_idx"] == 1
+
+
+# ---------------------------------------------------------------------------
+# Zero-weight validation
+# ---------------------------------------------------------------------------
+
+
+class TestMixtureSamplerZeroWeights:
+    """Both normalization branches must reject all-zero weights up-front.
+
+    Previously, ``temperature == 1.0`` crashed with ZeroDivisionError at the
+    ``/`` inside __init__/update_weights, and ``temperature != 1.0`` silently
+    degraded to uniform sampling via the ``max(w, 1e-12)`` clamp. Both are
+    config errors and should raise the same ValueError with a clear message.
+    """
+
+    def test_init_rejects_all_zero_weights(self):
+        with pytest.raises(ValueError, match="sum to > 0"):
+            MixtureSampler(
+                cumulative_sizes=[0, 10, 20],
+                weights=[0.0, 0.0],
+                num_replicas=1,
+                rank=0,
+            )
+
+    def test_init_rejects_empty_weights(self):
+        with pytest.raises(ValueError, match="empty"):
+            MixtureSampler(
+                cumulative_sizes=[0],
+                weights=[],
+                num_replicas=1,
+                rank=0,
+            )
+
+    def test_init_rejects_negative_weights(self):
+        with pytest.raises(ValueError, match="non-negative"):
+            MixtureSampler(
+                cumulative_sizes=[0, 10, 20],
+                weights=[1.0, -0.5],
+                num_replicas=1,
+                rank=0,
+            )
+
+    def test_init_rejects_all_zero_with_temperature(self):
+        """Regression guard: temperature branch previously silently produced uniform."""
+        with pytest.raises(ValueError, match="sum to > 0"):
+            MixtureSampler(
+                cumulative_sizes=[0, 10, 20],
+                weights=[0.0, 0.0],
+                num_replicas=1,
+                rank=0,
+                temperature=2.0,
+            )
+
+    def test_update_weights_rejects_all_zero(self):
+        sampler = MixtureSampler(
+            cumulative_sizes=[0, 10, 20],
+            weights=[0.5, 0.5],
+            num_replicas=1,
+            rank=0,
+        )
+        with pytest.raises(ValueError, match="sum to > 0"):
+            sampler.update_weights([0.0, 0.0])
+
+    def test_update_weights_rejects_all_zero_with_temperature(self):
+        """Regression guard: the temperature != 1.0 branch must also reject
+        all-zero weights."""
+        sampler = MixtureSampler(
+            cumulative_sizes=[0, 10, 20],
+            weights=[0.5, 0.5],
+            num_replicas=1,
+            rank=0,
+            temperature=2.0,
+        )
+        with pytest.raises(ValueError, match="sum to > 0"):
+            sampler.update_weights([0.0, 0.0], temperature=2.0)
+
+    def test_update_weights_accepts_single_zero_with_nonzero_companion(self):
+        """A single dataset can be zeroed out as long as at least one stays positive."""
+        sampler = MixtureSampler(
+            cumulative_sizes=[0, 10, 20],
+            weights=[0.5, 0.5],
+            num_replicas=1,
+            rank=0,
+        )
+        sampler.update_weights([0.0, 1.0])  # must not raise
+        assert sampler._probs == [0.0, 1.0]
+
+    def test_update_weights_rejects_negative(self):
+        sampler = MixtureSampler(
+            cumulative_sizes=[0, 10, 20],
+            weights=[0.5, 0.5],
+            num_replicas=1,
+            rank=0,
+        )
+        with pytest.raises(ValueError, match="non-negative"):
+            sampler.update_weights([1.0, -0.2])