Skip to content

Commit 7b0dd09

Browse files
authored
Stop over-tensoring input lines
1 parent 2a3b5b0 commit 7b0dd09

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

fms_fsdp/utils/dataloader_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def causal_lm(data_seq, prompt_len=1):
2424
Perform causal language modeling by right-shifting the input sequence.
2525
Sets first prompt_len tokens to be ignored by the loss.
2626
"""
27-
data_seq = torch.tensor(data_seq, dtype=torch.int)
27+
data_seq = data_seq.int()
2828
t = data_seq.clone()[1:]
2929
data_seq = data_seq[:-1]
3030
t[:prompt_len] = -100

0 commit comments

Comments
 (0)