Skip to content

Commit d931f3f

Browse files
authored
Add namespace when loading wikitext dataset (#2446)
Add the `Salesforce/` namespace to the dataset identifier when loading the wikitext dataset (`"wikitext"` becomes `"Salesforce/wikitext"`).
1 parent: 4035973; commit: d931f3f

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

optimum/gptq/data.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -122,9 +122,9 @@ def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "trai
122122
raise ImportError(DATASETS_IMPORT_ERROR.format("get_wikitext2"))
123123

124124
if split == "train":
125-
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
125+
data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
126126
elif split == "validation":
127-
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
127+
data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
128128
# length of 288059 should be enough
129129
text = "".join([" \n" if s == "" else s for s in data["text"][:1000]])
130130

optimum/gptq/eval.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ def _perplexity(nlls, n_samples, seqlen):
99
return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))
1010

1111
# load and prepare dataset
12-
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
12+
data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
1313
data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
1414
data = data.input_ids.to(model.device)
1515

0 commit comments

Comments
 (0)