Skip to content

Commit b84d70e

Browse files
authored
Merge pull request #54 from KempnerInstitute/mixture-sampler-zero-weights
Reject all-zero weights in MixtureSampler
2 parents 6b3df7f + 489cdd0 commit b84d70e

2 files changed

Lines changed: 122 additions & 0 deletions

File tree

kempnerforge/data/sampler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,27 @@
1515
from torch.utils.data import Dataset, Sampler
1616

1717

18+
def _validate_weights(weights: list[float], context: str) -> None:
19+
"""Fail fast on empty, negative, or all-zero weight lists.
20+
21+
The two normalization branches in ``MixtureSampler`` disagree on all-zero
22+
input: the ``temperature == 1.0`` branch divides by ``sum(weights)`` and
23+
raises ``ZeroDivisionError``; the ``temperature != 1.0`` branch clamps via
24+
``max(w, 1e-12)`` and silently produces uniform sampling. Reject both
25+
cases up-front with a clear error so misconfigured phase transitions
26+
surface immediately instead of crashing mid-run or drifting silently.
27+
"""
28+
if not weights:
29+
raise ValueError(f"{context}: weights list is empty")
30+
if any(w < 0 for w in weights):
31+
raise ValueError(f"{context}: weights must be non-negative, got {weights}")
32+
if sum(weights) <= 0:
33+
raise ValueError(
34+
f"{context}: weights must sum to > 0 (at least one dataset must have "
35+
f"weight > 0), got {weights}"
36+
)
37+
38+
1839
class DistributedSampler(Sampler[int]):
1940
"""Deterministic distributed sampler with skip-ahead support.
2041
@@ -159,6 +180,8 @@ def __init__(
159180
self._dataset_sizes = [cumulative_sizes[i + 1] - cumulative_sizes[i] for i in range(n)]
160181
self._offsets = list(cumulative_sizes[:n])
161182

183+
_validate_weights(weights, "MixtureSampler(weights=...)")
184+
162185
# Apply temperature scaling and normalize
163186
if temperature != 1.0:
164187
import math as _math
@@ -285,6 +308,8 @@ def update_weights(self, weights: list[float], temperature: float = 1.0) -> None
285308
if len(weights) != n:
286309
raise ValueError(f"Expected {n} weights, got {len(weights)}")
287310

311+
_validate_weights(weights, "MixtureSampler.update_weights")
312+
288313
# Apply temperature scaling and normalize (same logic as __init__)
289314
if temperature != 1.0:
290315
import math as _math

tests/unit/test_mixing.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,3 +440,100 @@ def test_dataset_idx_matches_sampler_intent(self, two_mmap_dirs):
440440
# All indices in ds_b range → dataset_idx=1
441441
for i in range(len(ds_a), len(mix)):
442442
assert mix[i]["dataset_idx"] == 1
443+
444+
445+
# ---------------------------------------------------------------------------
446+
# Zero-weight validation
447+
# ---------------------------------------------------------------------------
448+
449+
450+
class TestMixtureSamplerZeroWeights:
451+
"""Both normalization branches must reject all-zero weights up-front.
452+
453+
Previously, ``temperature == 1.0`` crashed with ZeroDivisionError at the
454+
``/`` inside __init__/update_weights, and ``temperature != 1.0`` silently
455+
degraded to uniform sampling via the ``max(w, 1e-12)`` clamp. Both are
456+
config errors and should raise the same ValueError with a clear message.
457+
"""
458+
459+
def test_init_rejects_all_zero_weights(self):
460+
with pytest.raises(ValueError, match="sum to > 0"):
461+
MixtureSampler(
462+
cumulative_sizes=[0, 10, 20],
463+
weights=[0.0, 0.0],
464+
num_replicas=1,
465+
rank=0,
466+
)
467+
468+
def test_init_rejects_empty_weights(self):
469+
with pytest.raises(ValueError, match="empty"):
470+
MixtureSampler(
471+
cumulative_sizes=[0],
472+
weights=[],
473+
num_replicas=1,
474+
rank=0,
475+
)
476+
477+
def test_init_rejects_negative_weights(self):
478+
with pytest.raises(ValueError, match="non-negative"):
479+
MixtureSampler(
480+
cumulative_sizes=[0, 10, 20],
481+
weights=[1.0, -0.5],
482+
num_replicas=1,
483+
rank=0,
484+
)
485+
486+
def test_init_rejects_all_zero_with_temperature(self):
487+
"""Regression guard: temperature branch previously silently produced uniform."""
488+
with pytest.raises(ValueError, match="sum to > 0"):
489+
MixtureSampler(
490+
cumulative_sizes=[0, 10, 20],
491+
weights=[0.0, 0.0],
492+
num_replicas=1,
493+
rank=0,
494+
temperature=2.0,
495+
)
496+
497+
def test_update_weights_rejects_all_zero(self):
498+
sampler = MixtureSampler(
499+
cumulative_sizes=[0, 10, 20],
500+
weights=[0.5, 0.5],
501+
num_replicas=1,
502+
rank=0,
503+
)
504+
with pytest.raises(ValueError, match="sum to > 0"):
505+
sampler.update_weights([0.0, 0.0])
506+
507+
def test_update_weights_rejects_all_zero_with_temperature(self):
508+
"""Regression guard: the temperature != 1.0 branch must also reject
509+
all-zero weights."""
510+
sampler = MixtureSampler(
511+
cumulative_sizes=[0, 10, 20],
512+
weights=[0.5, 0.5],
513+
num_replicas=1,
514+
rank=0,
515+
temperature=2.0,
516+
)
517+
with pytest.raises(ValueError, match="sum to > 0"):
518+
sampler.update_weights([0.0, 0.0], temperature=2.0)
519+
520+
def test_update_weights_accepts_single_zero_with_nonzero_companion(self):
521+
"""A single dataset can be zeroed out as long as at least one stays positive."""
522+
sampler = MixtureSampler(
523+
cumulative_sizes=[0, 10, 20],
524+
weights=[0.5, 0.5],
525+
num_replicas=1,
526+
rank=0,
527+
)
528+
sampler.update_weights([0.0, 1.0]) # must not raise
529+
assert sampler._probs == [0.0, 1.0]
530+
531+
def test_update_weights_rejects_negative(self):
532+
sampler = MixtureSampler(
533+
cumulative_sizes=[0, 10, 20],
534+
weights=[0.5, 0.5],
535+
num_replicas=1,
536+
rank=0,
537+
)
538+
with pytest.raises(ValueError, match="non-negative"):
539+
sampler.update_weights([1.0, -0.2])

0 commit comments

Comments
 (0)