Skip to content

Commit 922c2ea

Browse files
committed
test: add unit tests for OfflineSupervisedDataset and EagleOfflineDataCollator
Cover dataset loading, label shifting, collator truncation/padding, and multi-sample batching to improve code coverage for the new offline speculative decoding classes. Signed-off-by: Ye Yu <yey@nvidia.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent ce1b339 commit 922c2ea

1 file changed

Lines changed: 102 additions & 0 deletions

File tree

tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import modelopt.torch.speculative as mtsp
3636
from modelopt.torch.speculative.eagle.default_config import default_eagle_config
37+
from modelopt.torch.speculative.eagle.utils import EagleOfflineDataCollator, OfflineSupervisedDataset
3738

3839
_mock_scripts = types.ModuleType("scripts")
3940
_mock_ar = types.ModuleType("scripts.ar_validate")
@@ -161,3 +162,104 @@ def test_get_dummy_inputs_offline(eagle_model):
161162
hidden_size = eagle_model.config.hidden_size
162163
assert dummy["base_model_outputs"]["base_model_hidden_states"].shape[-1] == hidden_size
163164
assert dummy["base_model_outputs"]["base_model_input_embeds"].shape[-1] == hidden_size
165+
166+
167+
# ---------------------------------------------------------------------------
168+
# OfflineSupervisedDataset tests
169+
# ---------------------------------------------------------------------------
170+
171+
# Shared fixture dimensions for the synthetic offline samples.
SEQ_LEN = 16
HIDDEN_SIZE = 8


def _make_offline_pt(path, seq_len=SEQ_LEN, hidden_size=HIDDEN_SIZE):
    """Write a synthetic offline sample to *path* in the .pt format expected
    by OfflineSupervisedDataset.

    The saved dict holds integer token ids of shape ``(seq_len,)`` plus three
    float tensors of shape ``(seq_len, hidden_size)``.  The in-memory dict is
    returned so callers can compare against what was persisted.
    """
    sample = {"input_ids": torch.randint(0, 100, (seq_len,))}
    # Same key order as the on-disk format: hidden states after the token ids.
    for key in ("hidden_states", "aux_hidden_states", "base_model_input_embeds"):
        sample[key] = torch.randn(seq_len, hidden_size)
    torch.save(sample, path)
    return sample
185+
186+
187+
def test_offline_dataset_len_and_getitem(tmp_path):
    """The dataset exposes one entry per .pt file with the expected schema."""
    paths = [tmp_path / f"sample_{idx}.pt" for idx in range(3)]
    for path in paths:
        _make_offline_pt(path)

    ds = OfflineSupervisedDataset([str(path) for path in paths])
    assert len(ds) == 3

    item = ds[0]
    expected_keys = {
        "input_ids",
        "base_model_hidden_states",
        "aux_hidden_states",
        "attention_mask",
        "loss_mask",
        "labels",
    }
    assert set(item.keys()) == expected_keys
    # All per-token fields keep the original (untruncated) sequence length.
    for key in ("input_ids", "attention_mask", "labels"):
        assert item[key].shape == (SEQ_LEN,)
211+
212+
213+
def test_offline_dataset_labels_shift(tmp_path):
    """labels must be input_ids shifted left by one position."""
    pt_file = tmp_path / "sample.pt"
    raw = _make_offline_pt(pt_file)
    item = OfflineSupervisedDataset([str(pt_file)])[0]
    # Every label except the last should line up with the following input token.
    assert torch.equal(item["labels"][:-1], raw["input_ids"][1:])
221+
222+
223+
# ---------------------------------------------------------------------------
224+
# EagleOfflineDataCollator tests
225+
# ---------------------------------------------------------------------------
226+
227+
228+
def test_collator_truncates(tmp_path):
    """Sequences longer than train_len are cut down to train_len."""
    short_len = 8  # shorter than SEQ_LEN, so the collator must truncate
    pt_file = tmp_path / "sample.pt"
    _make_offline_pt(pt_file, seq_len=SEQ_LEN)
    sample = OfflineSupervisedDataset([str(pt_file)])[0]

    batch = EagleOfflineDataCollator(train_len=short_len)([sample])
    assert batch["input_ids"].shape == (1, short_len)
    assert batch["base_model_outputs"]["base_model_hidden_states"].shape[1] == short_len
238+
239+
240+
def test_collator_pads(tmp_path):
    """Sequences shorter than train_len are zero-padded up to train_len."""
    long_len = 32  # longer than SEQ_LEN, so the collator must pad
    pt_file = tmp_path / "sample.pt"
    _make_offline_pt(pt_file, seq_len=SEQ_LEN)
    sample = OfflineSupervisedDataset([str(pt_file)])[0]

    batch = EagleOfflineDataCollator(train_len=long_len)([sample])
    assert batch["input_ids"].shape == (1, long_len)
    # Everything beyond the real sequence must be zero padding.
    assert (batch["input_ids"][0, SEQ_LEN:] == 0).all()
251+
252+
253+
def test_collator_batches_multiple(tmp_path):
    """Several dataset items should be stacked along the batch dimension."""
    batch_size = 4
    paths = []
    for idx in range(batch_size):
        pt_file = tmp_path / f"sample_{idx}.pt"
        _make_offline_pt(pt_file)
        paths.append(str(pt_file))

    ds = OfflineSupervisedDataset(paths)
    collator = EagleOfflineDataCollator(train_len=SEQ_LEN)
    batch = collator([ds[idx] for idx in range(batch_size)])

    assert batch["input_ids"].shape == (batch_size, SEQ_LEN)
    hidden = batch["base_model_outputs"]["base_model_hidden_states"]
    assert hidden.shape == (batch_size, SEQ_LEN, HIDDEN_SIZE)

0 commit comments

Comments
 (0)