We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2a3b5b0 commit 7b0dd09Copy full SHA for 7b0dd09
1 file changed
fms_fsdp/utils/dataloader_utils.py
@@ -24,7 +24,7 @@ def causal_lm(data_seq, prompt_len=1):
24
Perform causal language modeling by right-shifting the input sequence.
25
Sets first prompt_len tokens to be ignored by the loss.
26
"""
27
- data_seq = torch.tensor(data_seq, dtype=torch.int)
+ data_seq = data_seq.int()
28
t = data_seq.clone()[1:]
29
data_seq = data_seq[:-1]
30
t[:prompt_len] = -100
0 commit comments