diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 7e5fc0b43d..127e6676cd 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -122,9 +122,9 @@ def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "trai raise ImportError(DATASETS_IMPORT_ERROR.format("get_wikitext2")) if split == "train": - data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") elif split == "validation": - data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") # length of 288059 should be enough text = "".join([" \n" if s == "" else s for s in data["text"][:1000]]) diff --git a/optimum/gptq/eval.py b/optimum/gptq/eval.py index 3ae6e4d7bf..7b290ad2d9 100644 --- a/optimum/gptq/eval.py +++ b/optimum/gptq/eval.py @@ -9,7 +9,7 @@ def _perplexity(nlls, n_samples, seqlen): return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen)) # load and prepare dataset - data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") data = tokenizer("\n\n".join(data["text"]), return_tensors="pt") data = data.input_ids.to(model.device)