diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b1e05c91e..dd96d0d18 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -191,6 +191,7 @@ def __init__( hidden_states_dtype: The dtype of the hidden states. """ self.data = load_from_disk(datapath) + self.start_file_idx = 0 if split_ratio == 1.0: pass elif 1.0 > split_ratio > 0: diff --git a/tests/unit/train/test_data.py b/tests/unit/train/test_data.py index 4142568d1..c86bc49ca 100644 --- a/tests/unit/train/test_data.py +++ b/tests/unit/train/test_data.py @@ -4,9 +4,11 @@ from pathlib import Path import torch +from datasets import Dataset from speculators.models.eagle3.data import shift_batch from speculators.train.data import ( + ArrowDataset, SampleFileDataset, create_collate_fn, standardize_data_v1, @@ -434,3 +436,26 @@ def test_dataset_fallback_when_sample_lengths_json_malformed(tmp_path: Path): file_list = sorted([str(f) for f in tmp_path.glob("data_*.pt")]) dataset = SampleFileDataset(max_len=50, file_list=file_list) assert len(dataset.approx_lengths) == 2 + + +def test_arrow_dataset_default_split_ratio_does_not_crash(tmp_path: Path): + """ArrowDataset with default split_ratio=1.0 should support indexing.""" + ds = Dataset.from_dict( + { + "input_ids": [[1, 2, 3]], + "loss_mask": [[1, 1, 1]], + "seq_len": [3], + } + ) + ds.save_to_disk(str(tmp_path / "data")) + (tmp_path / "data" / "hidden_states").mkdir() + + arrow_ds = ArrowDataset( + max_len=128, + datapath=str(tmp_path / "data"), + on_missing="skip", + ) + + # Should not raise AttributeError + assert arrow_ds._map_to_file_idx(0) == 0 + assert arrow_ds._map_to_file_idx(5) == 5