
Commit 2007f9c

./simple_train.py
1 parent 265556c commit 2007f9c

File tree

1 file changed (+90, -0)


llm/simple_train.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
#!/usr/bin/env python3

print("initializing")

import os

# Disable wandb logging, point DeepSpeed at the CPU, and turn off
# datasets caching so repeated runs start clean.
os.environ["WANDB_DISABLED"] = "true"
os.environ["DS_ACCELERATOR"] = "cpu"
os.environ["HF_DATASETS_DISABLE_CACHING"] = "1"

from datasets import DatasetDict, disable_progress_bar
from datasets.arrow_dataset import Dataset
from transformers.utils import logging

# Silence transformers so only this script's prints are visible.
logging.set_verbosity(logging.ERROR)
length = 4
# The whole training corpus is the single string "0 1 2 3"; the check at
# the end of the script tests whether the model memorizes it.
input_text = " ".join(str(i) for i in range(length))
print("input:", input_text)
input_list = input_text.split()

disable_progress_bar()
import pyarrow as pa

# Wrap one-row Arrow tables directly in `datasets.Dataset` objects; train
# and valid both hold the same single example.
train_data = pa.Table.from_pydict({"text": [input_text]})
valid_data = pa.Table.from_pydict({"text": [input_text]})
ds = DatasetDict({
    "train": Dataset(train_data, fingerprint="train"),
    "valid": Dataset(valid_data, fingerprint="valid"),
})

ds.save_to_disk("test_dataset")
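# The on-disk copy is not read back below, but it could be reloaded in a
# separate run with, for example:
#   from datasets import load_from_disk
#   ds = load_from_disk("test_dataset")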
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 has no pad token; reuse EOS so the collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize the dataset.
def tokenize_function(examples):
    # truncation=True is required for max_length to take effect.
    return tokenizer(examples["text"], max_length=length, truncation=True)

tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text"])
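# Each split now holds a single tokenized example of the form
# {"input_ids": [...], "attention_mask": [...]}; "0 1 2 3" should come out
# as four GPT-2 BPE tokens, one per digit. A quick sanity check, if wanted:
#   print(tokenized_datasets["train"][0])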
from transformers import GPT2Config, GPT2LMHeadModel

# https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config
# A deliberately tiny GPT-2: 8 layers, 4 heads, 256-dim embeddings.
config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=128,
    n_ctx=128,
    n_embd=256,
    n_layer=2 * length,
    n_head=length,
)

print("loading model")

# Randomly initialized weights; no pretrained checkpoint is loaded.
model = GPT2LMHeadModel(config)
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# mlm=False gives causal-LM batches: the labels are the input_ids
# themselves, with padding positions masked out.
dc = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(
    run_name="test",
    num_train_epochs=20,
    use_cpu=True,
    output_dir="./results",
    overwrite_output_dir=True,
    eval_strategy="steps",
)
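# Note: with eval_strategy="steps", eval_steps falls back to logging_steps
# (default 500), and this run only takes about 20 optimizer steps, so no
# evaluation actually fires during training; something like eval_steps=5
# would be needed for that.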
from transformers import pipeline

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=dc,
)

print("training")
t = trainer.train()
print(t)
#eval_results = trainer.evaluate()
#print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")  # perplexity = exp(loss); needs `import math`
print("inference")

gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Prompt with the first token ("0") and generate the remaining three; the
# model passes if it reproduces the training string exactly.
output = gen(input_list[0], max_new_tokens=length - 1, num_return_sequences=1)[0]['generated_text']
print("output:", output)
if input_text == output:
    print('pass')
else:
    print('fail')
    exit(1)
