MixtureSampler accepts all-zero weights inconsistently #53

@mmshad

Description

MixtureSampler disagrees with itself on all-zero input. The two
normalization branches (selected by temperature) behave differently:

  • temperature == 1.0: divides by sum(weights) and raises
    ZeroDivisionError mid-run.
  • temperature != 1.0: clamps each weight with max(w, 1e-12) and silently
    produces uniform sampling.

Both are reachable from the phase-annealing path in scripts/train.py.
TrainingPhase.dataset_weights validates >= 0, not > 0, to intentionally
allow dropping individual datasets. A user who zeroes out every dataset in
a phase transition gets a mid-run crash on one branch and silent
training-quality drift on the other.

Fix

Add _validate_weights(weights, context) in kempnerforge/data/sampler.py
and call it from both MixtureSampler.__init__ and
MixtureSampler.update_weights. The helper rejects:

  • empty weight lists,
  • any negative entry,
  • all-zero lists (sum <= 0),

with a ValueError that prints the offending list and the call site.

The TrainingPhase schema-level validator stays permissive because it
cannot know the effective mixture at config-parse time (phases only merge
at runtime). The sampler-level gate is where the hard check belongs.

Coverage

tests/unit/test_mixing.py::TestMixtureSamplerZeroWeights adds 8 tests
covering both constructors, both temperature branches, negative weights,
empty list, and a positive case where a single dataset is zeroed out but
at least one stays positive.
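Two of those cases, sketched in pytest style (the local `_validate_weights` stand-in is inlined so the snippet runs standalone; the real tests would import it from `kempnerforge.data.sampler`):

```python
import pytest

# Stand-in for the proposed helper, matching the checks described in the fix.
def _validate_weights(weights, context):
    if len(weights) == 0:
        raise ValueError(f"{context}: weight list is empty")
    if any(w < 0 for w in weights):
        raise ValueError(f"{context}: negative entry in weights {weights}")
    if sum(weights) <= 0:
        raise ValueError(f"{context}: all-zero weights {weights}")

def test_all_zero_weights_rejected():
    with pytest.raises(ValueError, match="all-zero"):
        _validate_weights([0.0, 0.0, 0.0], "MixtureSampler.__init__")

def test_single_zeroed_dataset_allowed():
    # Dropping one dataset is legal as long as another stays positive.
    _validate_weights([0.0, 0.7, 0.3], "MixtureSampler.update_weights")
```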
