Commit 265556c

fix ./train-transformers.py
1 parent dcb28fb · commit 265556c

File tree: 1 file changed, +15 -10 lines


llm/train-transformers.py

Lines changed: 15 additions & 10 deletions
@@ -3,19 +3,20 @@
 print("initializing")
 
 import os
+os.environ["WANDB_DISABLED"] = "true"
+os.environ["DS_ACCELERATOR"] = "cpu"
+os.environ["HF_DATASETS_DISABLE_CACHING"] = "1"
+
 from datasets import DatasetDict, disable_progress_bar
 from datasets.arrow_dataset import Dataset
 from transformers.utils import logging
 from transformers import GPT2Config, GPT2LMHeadModel
 from transformers import AutoTokenizer
-# https://huggingface.co/docs/transformers/en/main_classes/data_collator
 from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
 from transformers import pipeline
 from difflib import SequenceMatcher
 import torch
 
-os.environ["WANDB_DISABLED"] = "true"
-os.environ["DS_ACCELERATOR"] = "cpu"
 logging.set_verbosity(logging.ERROR)
 
 length = 4
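
Note: the commit moves the environment variables up, between `import os` and the datasets/transformers imports. A plausible reading (my assumption; the commit message doesn't say) is that this guarantees the settings are in place before any of those libraries first consults them. A minimal sketch of the pattern:

    # Set the process environment before the heavyweight imports, so the
    # values are visible however early each library reads them.
    import os
    os.environ["WANDB_DISABLED"] = "true"            # opt out of Weights & Biases logging
    os.environ["DS_ACCELERATOR"] = "cpu"             # point DeepSpeed at its CPU accelerator
    os.environ["HF_DATASETS_DISABLE_CACHING"] = "1"  # caching opt-out as set in this commit

    from transformers import Trainer  # imported only after the env is configured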
@@ -24,13 +25,16 @@
 input_list = input.split()
 
 disable_progress_bar()
-ds = DatasetDict({ "train": Dataset.from_list([{"text":input}]),
-                   #Dataset.from_generator(gen),
-                   "valid": Dataset.from_list([{"text":input}])
-                 })
+import pyarrow as pa
+train_data = pa.Table.from_pydict({"text": [input]})
+valid_data = pa.Table.from_pydict({"text": [input]})
+ds = DatasetDict({
+    "train": Dataset(train_data, fingerprint="train"),
+    "valid": Dataset(valid_data, fingerprint="valid")
+})
 
 tokenizer= AutoTokenizer.from_pretrained("gpt2")
-
+
 tokenizer.pad_token = tokenizer.eos_token
 
 # Tokenize the ds
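
Note: `Dataset` accepts a pyarrow table directly, and passing an explicit `fingerprint` skips the hashing that `Dataset.from_list` would otherwise do to compute one, which I take to be the point of this change (an assumption; the commit message doesn't say). A standalone sketch, with "some text" standing in for the script's input string:

    import pyarrow as pa
    from datasets import DatasetDict
    from datasets.arrow_dataset import Dataset

    table = pa.Table.from_pydict({"text": ["some text"]})
    ds = DatasetDict({
        "train": Dataset(table, fingerprint="train"),  # explicit fingerprint,
        "valid": Dataset(table, fingerprint="valid"),  # no hashing required
    })
    print(ds["train"][0])  # {'text': 'some text'}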
@@ -52,10 +56,11 @@ def tokenize_function(examples):
 # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
 training_args = TrainingArguments(
     run_name="test",
-    num_train_epochs=4,
+    num_train_epochs=20,
     output_dir="./results",
     overwrite_output_dir=True,
     eval_strategy="steps",
+    use_cpu=True,
 )
 
 trainer = Trainer(
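
Note: `use_cpu=True` is the `TrainingArguments` flag that forces CPU training (it superseded the older `no_cuda` flag). A quick sketch for checking which device the arguments resolve to:

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="./results", use_cpu=True)
    print(args.device)  # expected: device(type='cpu')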
@@ -71,7 +76,7 @@ def tokenize_function(examples):
 
 print("inference")
 gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
-output = gen(input_list[0], max_length=length, num_return_sequences=1)[0]['generated_text']
+output = gen(input_list[0], max_new_tokens=length - 1, num_return_sequences=1)[0]['generated_text']
 print(SequenceMatcher(None, input, output).ratio(), output)
 tokenizer.save_pretrained('trained_model')
 
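
Note: `max_length` counts the prompt's tokens toward the limit, so the number of newly generated tokens depends on how the prompt happens to tokenize; `max_new_tokens` fixes the count of generated tokens directly. Assuming the one-word prompt is a single token, `max_new_tokens=length - 1` gives roughly `length` tokens in total. A hedged sketch, with "the" standing in for `input_list[0]`:

    from transformers import pipeline

    gen = pipeline("text-generation", model="gpt2")
    length = 4
    out = gen("the", max_new_tokens=length - 1, num_return_sequences=1)
    print(out[0]["generated_text"])  # prompt plus up to 3 generated tokens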
