print("initializing")

import os

# These environment variables must be set BEFORE the transformers/datasets
# imports below, so the libraries see them at import/initialization time:
# disable Weights & Biases reporting, force the DeepSpeed accelerator to CPU,
# and turn off the datasets on-disk cache.
os.environ["WANDB_DISABLED"] = "true"
os.environ["DS_ACCELERATOR"] = "cpu"
os.environ["HF_DATASETS_DISABLE_CACHING"] = "1"

from datasets import DatasetDict, disable_progress_bar
from datasets.arrow_dataset import Dataset
from transformers.utils import logging
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import AutoTokenizer
# https://huggingface.co/docs/transformers/en/main_classes/data_collator
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers import pipeline
from difflib import SequenceMatcher
import torch

# Keep transformers quiet except for actual errors.
logging.set_verbosity(logging.ERROR)
2021
# Target length (in tokens) of the generated sample, prompt included.
length = 4
# NOTE(review): `input` is used as a plain string variable here (it shadows
# the builtin) — presumably assigned in an elided part of this script;
# confirm upstream before relying on it.
input_list = input.split()

# Suppress datasets' per-operation progress bars.
disable_progress_bar()
import pyarrow as pa


def _one_row_split(text, tag):
    # Wrap the whole training text as a single-example Arrow table. Passing a
    # fixed fingerprint sidesteps datasets' content hashing for this split.
    return Dataset(pa.Table.from_pydict({"text": [text]}), fingerprint=tag)


# Train and validate on the same single text — this script deliberately
# overfits one sample.
ds = DatasetDict({
    "train": _one_row_split(input, "train"),
    "valid": _one_row_split(input, "valid"),
})
3135
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token; reuse EOS so the data collator
# can pad batches.
tokenizer.pad_token = tokenizer.eos_token
3539
# Tokenize the dataset splits
@@ -52,10 +56,11 @@ def tokenize_function(examples):
# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(
    run_name="test",
    num_train_epochs=20,
    output_dir="./results",
    overwrite_output_dir=True,
    eval_strategy="steps",
    use_cpu=True,  # train on CPU only
)
6065
6166trainer = Trainer (
@@ -71,7 +76,7 @@ def tokenize_function(examples):
7176
print("inference")
gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
# The prompt is the first word of the training text; max_new_tokens counts
# only generated tokens, so length - 1 keeps the total near `length`.
result = gen(input_list[0], max_new_tokens=length - 1, num_return_sequences=1)
output = result[0]['generated_text']
# Similarity ratio between the training text and what the model reproduces.
print(SequenceMatcher(None, input, output).ratio(), output)
tokenizer.save_pretrained('trained_model')
7782
0 commit comments