@@ -440,3 +440,100 @@ def test_dataset_idx_matches_sampler_intent(self, two_mmap_dirs):
440440 # All indices in ds_b range → dataset_idx=1
441441 for i in range (len (ds_a ), len (mix )):
442442 assert mix [i ]["dataset_idx" ] == 1
443+
444+
445+ # ---------------------------------------------------------------------------
446+ # Zero-weight validation
447+ # ---------------------------------------------------------------------------
448+
449+
450+ class TestMixtureSamplerZeroWeights :
451+ """Both normalization branches must reject all-zero weights up-front.
452+
453+ Previously, ``temperature == 1.0`` crashed with ZeroDivisionError at the
454+ ``/`` inside __init__/update_weights, and ``temperature != 1.0`` silently
455+ degraded to uniform sampling via the ``max(w, 1e-12)`` clamp. Both are
456+ config errors and should raise the same ValueError with a clear message.
457+ """
458+
459+ def test_init_rejects_all_zero_weights (self ):
460+ with pytest .raises (ValueError , match = "sum to > 0" ):
461+ MixtureSampler (
462+ cumulative_sizes = [0 , 10 , 20 ],
463+ weights = [0.0 , 0.0 ],
464+ num_replicas = 1 ,
465+ rank = 0 ,
466+ )
467+
468+ def test_init_rejects_empty_weights (self ):
469+ with pytest .raises (ValueError , match = "empty" ):
470+ MixtureSampler (
471+ cumulative_sizes = [0 ],
472+ weights = [],
473+ num_replicas = 1 ,
474+ rank = 0 ,
475+ )
476+
477+ def test_init_rejects_negative_weights (self ):
478+ with pytest .raises (ValueError , match = "non-negative" ):
479+ MixtureSampler (
480+ cumulative_sizes = [0 , 10 , 20 ],
481+ weights = [1.0 , - 0.5 ],
482+ num_replicas = 1 ,
483+ rank = 0 ,
484+ )
485+
486+ def test_init_rejects_all_zero_with_temperature (self ):
487+ """Regression guard: temperature branch previously silently produced uniform."""
488+ with pytest .raises (ValueError , match = "sum to > 0" ):
489+ MixtureSampler (
490+ cumulative_sizes = [0 , 10 , 20 ],
491+ weights = [0.0 , 0.0 ],
492+ num_replicas = 1 ,
493+ rank = 0 ,
494+ temperature = 2.0 ,
495+ )
496+
497+ def test_update_weights_rejects_all_zero (self ):
498+ sampler = MixtureSampler (
499+ cumulative_sizes = [0 , 10 , 20 ],
500+ weights = [0.5 , 0.5 ],
501+ num_replicas = 1 ,
502+ rank = 0 ,
503+ )
504+ with pytest .raises (ValueError , match = "sum to > 0" ):
505+ sampler .update_weights ([0.0 , 0.0 ])
506+
507+ def test_update_weights_rejects_all_zero_with_temperature (self ):
508+ """Regression guard: the temperature != 1.0 branch must also reject
509+ all-zero weights."""
510+ sampler = MixtureSampler (
511+ cumulative_sizes = [0 , 10 , 20 ],
512+ weights = [0.5 , 0.5 ],
513+ num_replicas = 1 ,
514+ rank = 0 ,
515+ temperature = 2.0 ,
516+ )
517+ with pytest .raises (ValueError , match = "sum to > 0" ):
518+ sampler .update_weights ([0.0 , 0.0 ], temperature = 2.0 )
519+
520+ def test_update_weights_accepts_single_zero_with_nonzero_companion (self ):
521+ """A single dataset can be zeroed out as long as at least one stays positive."""
522+ sampler = MixtureSampler (
523+ cumulative_sizes = [0 , 10 , 20 ],
524+ weights = [0.5 , 0.5 ],
525+ num_replicas = 1 ,
526+ rank = 0 ,
527+ )
528+ sampler .update_weights ([0.0 , 1.0 ]) # must not raise
529+ assert sampler ._probs == [0.0 , 1.0 ]
530+
531+ def test_update_weights_rejects_negative (self ):
532+ sampler = MixtureSampler (
533+ cumulative_sizes = [0 , 10 , 20 ],
534+ weights = [0.5 , 0.5 ],
535+ num_replicas = 1 ,
536+ rank = 0 ,
537+ )
538+ with pytest .raises (ValueError , match = "non-negative" ):
539+ sampler .update_weights ([1.0 , - 0.2 ])
0 commit comments