Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions kempnerforge/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@
from torch.utils.data import Dataset, Sampler


def _validate_weights(weights: list[float], context: str) -> None:
"""Fail fast on empty, negative, or all-zero weight lists.

The two normalization branches in ``MixtureSampler`` disagree on all-zero
input: the ``temperature == 1.0`` branch divides by ``sum(weights)`` and
raises ``ZeroDivisionError``; the ``temperature != 1.0`` branch clamps via
``max(w, 1e-12)`` and silently produces uniform sampling. Reject both
cases up-front with a clear error so misconfigured phase transitions
surface immediately instead of crashing mid-run or drifting silently.
"""
if not weights:
raise ValueError(f"{context}: weights list is empty")
if any(w < 0 for w in weights):
raise ValueError(f"{context}: weights must be non-negative, got {weights}")
if sum(weights) <= 0:
raise ValueError(
f"{context}: weights must sum to > 0 (at least one dataset must have "
f"weight > 0), got {weights}"
)


class DistributedSampler(Sampler[int]):
"""Deterministic distributed sampler with skip-ahead support.

Expand Down Expand Up @@ -159,6 +180,8 @@ def __init__(
self._dataset_sizes = [cumulative_sizes[i + 1] - cumulative_sizes[i] for i in range(n)]
self._offsets = list(cumulative_sizes[:n])

_validate_weights(weights, "MixtureSampler(weights=...)")

# Apply temperature scaling and normalize
if temperature != 1.0:
import math as _math
Expand Down Expand Up @@ -285,6 +308,8 @@ def update_weights(self, weights: list[float], temperature: float = 1.0) -> None
if len(weights) != n:
raise ValueError(f"Expected {n} weights, got {len(weights)}")

_validate_weights(weights, "MixtureSampler.update_weights")

# Apply temperature scaling and normalize (same logic as __init__)
if temperature != 1.0:
import math as _math
Expand Down
97 changes: 97 additions & 0 deletions tests/unit/test_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,100 @@ def test_dataset_idx_matches_sampler_intent(self, two_mmap_dirs):
# All indices in ds_b range → dataset_idx=1
for i in range(len(ds_a), len(mix)):
assert mix[i]["dataset_idx"] == 1


# ---------------------------------------------------------------------------
# Zero-weight validation
# ---------------------------------------------------------------------------


class TestMixtureSamplerZeroWeights:
    """Weight validation must behave the same in both normalization branches.

    Before validation was added, an all-zero weight list with
    ``temperature == 1.0`` blew up with ZeroDivisionError at the ``/`` in
    __init__/update_weights, while ``temperature != 1.0`` quietly fell back
    to uniform sampling through the ``max(w, 1e-12)`` clamp. Both situations
    are configuration mistakes and are expected to raise one ValueError with
    a clear message.
    """

    @staticmethod
    def _two_dataset_sampler(weights, **extra):
        """Build a single-replica sampler over two datasets of ten items each."""
        return MixtureSampler(
            cumulative_sizes=[0, 10, 20],
            weights=weights,
            num_replicas=1,
            rank=0,
            **extra,
        )

    def test_init_rejects_all_zero_weights(self):
        """Constructor refuses weights that are all zero (default temperature)."""
        with pytest.raises(ValueError, match="sum to > 0"):
            self._two_dataset_sampler([0.0, 0.0])

    def test_init_rejects_empty_weights(self):
        """Constructor refuses an empty weight list."""
        with pytest.raises(ValueError, match="empty"):
            MixtureSampler(
                cumulative_sizes=[0],
                weights=[],
                num_replicas=1,
                rank=0,
            )

    def test_init_rejects_negative_weights(self):
        """Constructor refuses any negative weight."""
        with pytest.raises(ValueError, match="non-negative"):
            self._two_dataset_sampler([1.0, -0.5])

    def test_init_rejects_all_zero_with_temperature(self):
        """Regression guard: temperature branch previously silently produced uniform."""
        with pytest.raises(ValueError, match="sum to > 0"):
            self._two_dataset_sampler([0.0, 0.0], temperature=2.0)

    def test_update_weights_rejects_all_zero(self):
        """update_weights refuses all-zero weights (default temperature)."""
        balanced = self._two_dataset_sampler([0.5, 0.5])
        with pytest.raises(ValueError, match="sum to > 0"):
            balanced.update_weights([0.0, 0.0])

    def test_update_weights_rejects_all_zero_with_temperature(self):
        """Regression guard: the temperature != 1.0 branch must also reject
        all-zero weights."""
        tempered = self._two_dataset_sampler([0.5, 0.5], temperature=2.0)
        with pytest.raises(ValueError, match="sum to > 0"):
            tempered.update_weights([0.0, 0.0], temperature=2.0)

    def test_update_weights_accepts_single_zero_with_nonzero_companion(self):
        """A single dataset can be zeroed out as long as at least one stays positive."""
        balanced = self._two_dataset_sampler([0.5, 0.5])
        balanced.update_weights([0.0, 1.0])  # must not raise
        assert balanced._probs == [0.0, 1.0]

    def test_update_weights_rejects_negative(self):
        """update_weights refuses any negative weight."""
        balanced = self._two_dataset_sampler([0.5, 0.5])
        with pytest.raises(ValueError, match="non-negative"):
            balanced.update_weights([1.0, -0.2])
Loading