2323 DataCollatorWithFlattening ,
2424 MLMDataCollatorWithFlattening ,
2525 TokenPackingDataset ,
26- split_sample_by_num_tokens ,
26+ _split_sample_by_num_tokens ,
2727)
2828
2929
@@ -494,36 +494,36 @@ def __iter__(self):
494494 assert sum (len (sample ["input_ids" ]) for sample in batches [0 ]) == 90
495495
496496
def test__split_sample_by_num_tokens_basic():
    """Check that a sample holding only ``input_ids`` is cut at the requested token count.

    A 7-token sample split at 3 must give a 3-token head and a 4-token tail.
    """
    sample = {"input_ids": [0, 5, 6, 7, 8, 9, 2]}
    head, tail = _split_sample_by_num_tokens(sample, 3)

    # Contents first: head is the leading three tokens, tail is everything after.
    assert head["input_ids"] == [0, 5, 6]
    assert tail["input_ids"] == [7, 8, 9, 2]

    # Lengths are asserted separately so a size mismatch is reported explicitly.
    assert len(head["input_ids"]) == 3
    assert len(tail["input_ids"]) == 4
506506
507507
def test__split_sample_by_num_tokens_with_labels():
    """Check that ``input_ids`` and ``labels`` are both cut at the same position."""
    sample = {"input_ids": [0, 5, 6, 7, 8, 2], "labels": [0, 5, 6, 7, 8, 2]}
    head, tail = _split_sample_by_num_tokens(sample, 3)

    # Expected per-field contents on each side of the split point.
    expected_head = {"input_ids": [0, 5, 6], "labels": [0, 5, 6]}
    expected_tail = {"input_ids": [7, 8, 2], "labels": [7, 8, 2]}

    for field, expected in expected_head.items():
        assert head[field] == expected
    for field, expected in expected_tail.items():
        assert tail[field] == expected
517517
518518
519- def test_split_sample_by_num_tokens_with_attention_mask ():
520- """Test split_sample_by_num_tokens with input_ids, attention_mask, and labels."""
519+ def test__split_sample_by_num_tokens_with_attention_mask ():
520+ """Test _split_sample_by_num_tokens with input_ids, attention_mask, and labels."""
521521 sample = {
522522 "input_ids" : [0 , 5 , 6 , 7 , 8 , 2 ],
523523 "attention_mask" : [1 , 1 , 1 , 1 , 1 , 1 ],
524524 "labels" : [0 , 5 , 6 , 7 , 8 , 2 ],
525525 }
526- first , remaining = split_sample_by_num_tokens (sample , 4 )
526+ first , remaining = _split_sample_by_num_tokens (sample , 4 )
527527
528528 assert first ["input_ids" ] == [0 , 5 , 6 , 7 ]
529529 assert first ["attention_mask" ] == [1 , 1 , 1 , 1 ]
@@ -533,14 +533,14 @@ def test_split_sample_by_num_tokens_with_attention_mask():
533533 assert remaining ["labels" ] == [8 , 2 ]
534534
535535
536- def test_split_sample_by_num_tokens_with_token_type_ids ():
537- """Test split_sample_by_num_tokens with token_type_ids."""
536+ def test__split_sample_by_num_tokens_with_token_type_ids ():
537+ """Test _split_sample_by_num_tokens with token_type_ids."""
538538 sample = {
539539 "input_ids" : [0 , 5 , 6 , 7 , 8 , 2 ],
540540 "token_type_ids" : [0 , 0 , 0 , 1 , 1 , 1 ],
541541 "labels" : [0 , 5 , 6 , 7 , 8 , 2 ],
542542 }
543- first , remaining = split_sample_by_num_tokens (sample , 3 )
543+ first , remaining = _split_sample_by_num_tokens (sample , 3 )
544544
545545 assert first ["input_ids" ] == [0 , 5 , 6 ]
546546 assert first ["token_type_ids" ] == [0 , 0 , 0 ]
@@ -550,14 +550,14 @@ def test_split_sample_by_num_tokens_with_token_type_ids():
550550 assert remaining ["labels" ] == [7 , 8 , 2 ]
551551
552552
553- def test_split_sample_by_num_tokens_with_token_type ():
554- """Test split_sample_by_num_tokens with token_type (alternative name)."""
553+ def test__split_sample_by_num_tokens_with_token_type ():
554+ """Test _split_sample_by_num_tokens with token_type (alternative name)."""
555555 sample = {
556556 "input_ids" : [0 , 5 , 6 , 7 , 8 , 2 ],
557557 "token_type" : [0 , 0 , 0 , 1 , 1 , 1 ],
558558 "labels" : [0 , 5 , 6 , 7 , 8 , 2 ],
559559 }
560- first , remaining = split_sample_by_num_tokens (sample , 3 )
560+ first , remaining = _split_sample_by_num_tokens (sample , 3 )
561561
562562 assert first ["input_ids" ] == [0 , 5 , 6 ]
563563 assert first ["token_type" ] == [0 , 0 , 0 ]
@@ -567,14 +567,14 @@ def test_split_sample_by_num_tokens_with_token_type():
567567 assert remaining ["labels" ] == [7 , 8 , 2 ]
568568
569569
570- def test_split_sample_by_num_tokens_with_tensors ():
571- """Test split_sample_by_num_tokens with torch tensors."""
570+ def test__split_sample_by_num_tokens_with_tensors ():
571+ """Test _split_sample_by_num_tokens with torch tensors."""
572572 sample = {
573573 "input_ids" : torch .tensor ([0 , 5 , 6 , 7 , 8 , 2 ]),
574574 "attention_mask" : torch .tensor ([1 , 1 , 1 , 1 , 1 , 1 ]),
575575 "labels" : torch .tensor ([0 , 5 , 6 , 7 , 8 , 2 ]),
576576 }
577- first , remaining = split_sample_by_num_tokens (sample , 3 )
577+ first , remaining = _split_sample_by_num_tokens (sample , 3 )
578578
579579 assert torch .equal (first ["input_ids" ], torch .tensor ([0 , 5 , 6 ]))
580580 assert torch .equal (first ["attention_mask" ], torch .tensor ([1 , 1 , 1 ]))
@@ -584,14 +584,14 @@ def test_split_sample_by_num_tokens_with_tensors():
584584 assert torch .equal (remaining ["labels" ], torch .tensor ([7 , 8 , 2 ]))
585585
586586
587- def test_split_sample_by_num_tokens_with_metadata ():
588- """Test split_sample_by_num_tokens preserves non-sequence fields."""
587+ def test__split_sample_by_num_tokens_with_metadata ():
588+ """Test _split_sample_by_num_tokens preserves non-sequence fields."""
589589 sample = {
590590 "input_ids" : [0 , 5 , 6 , 7 , 8 , 2 ],
591591 "labels" : [0 , 5 , 6 , 7 , 8 , 2 ],
592592 "metadata" : {"id" : 123 , "source" : "test" },
593593 }
594- first , remaining = split_sample_by_num_tokens (sample , 3 )
594+ first , remaining = _split_sample_by_num_tokens (sample , 3 )
595595
596596 # Sequence fields should be split
597597 assert first ["input_ids" ] == [0 , 5 , 6 ]
@@ -602,23 +602,23 @@ def test_split_sample_by_num_tokens_with_metadata():
602602 assert remaining ["metadata" ] == {"id" : 123 , "source" : "test" }
603603
604604
def test__split_sample_by_num_tokens_errors():
    """Check that out-of-range split points are rejected with ``ValueError``.

    The sample has 5 tokens. A split point equal to or beyond that length must
    be rejected, and so must zero or negative split points.
    """
    sample = {"input_ids": [0, 5, 6, 7, 2]}

    # num_tokens >= sample length: equal to the length, then well past it.
    for too_large in (5, 10):
        with pytest.raises(ValueError, match="num_tokens.*must be less than sample length"):
            _split_sample_by_num_tokens(sample, too_large)

    # num_tokens <= 0: zero, then a negative value.
    for non_positive in (0, -1):
        with pytest.raises(ValueError, match="num_tokens.*must be positive"):
            _split_sample_by_num_tokens(sample, non_positive)
622622
623623
624624def test_token_packing_dataset_with_split_samples ():
0 commit comments