99from vulnscan import log , Train , plot_training , SimpleNN , EmbeddingDataset , TrainingConfig , DataGen
1010
1111
# ---------------- INIT ----------------
def init(config: TrainingConfig) -> dict:
    """Initialize static, config-free resources (only once).

    Loads the GPT-Neo tokenizer/model (text generation) and the MiniLM
    sentence-embedding model. The GPT model is deliberately NOT moved to a
    device here — ``train`` attaches it to ``config.DEVICE`` — so the same
    resources dict can be reused across multiple training runs that each
    carry their own config.

    Args:
        config: Training configuration; only used here for log routing.

    Returns:
        dict with keys ``"gpt_tokenizer"``, ``"gpt_model"``, ``"embed_model"``.
    """
    log("Loading GPT-Neo tokenizer/model (static init)...", cfg=config)
    gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
    gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
    # GPT-Neo ships without a pad token; reuse EOS so batched generation
    # with padding does not fail downstream.
    if gpt_tokenizer.pad_token is None:
        gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

    log("Loading MiniLM for embeddings (static init) ...", cfg=config)
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    return {
        "gpt_tokenizer": gpt_tokenizer,
        "gpt_model": gpt_model,
        "embed_model": embed_model,
    }
31+ # ---------------- TRAIN ----------------
32+ def train (config : TrainingConfig , resources : dict ):
33+ gpt_tokenizer = resources ["gpt_tokenizer" ]
34+ gpt_model = resources ["gpt_model" ].to (config .DEVICE ) # attach to device here
35+ embed_model = resources ["embed_model" ]
36+
37+ log ("Init DataGen with config..." , cfg = config , silent = True )
2438 generate = DataGen (cfg = config )
2539
2640 # Generate dataset
@@ -39,11 +53,11 @@ def train(config: TrainingConfig):
3953 val_texts , val_labels = texts [train_split :val_split ], labels [train_split :val_split ]
4054 test_texts , test_labels = texts [val_split :], labels [val_split :]
4155
42- log (message = "Generating test embeddings..." , cfg = config )
56+ log ("Generating test embeddings..." , cfg = config )
4357 generate .embeddings (embed_model = embed_model , texts = test_texts , labels = test_labels , split = "test" )
44- log (message = "Generating train embeddings..." , cfg = config )
58+ log ("Generating train embeddings..." , cfg = config )
4559 generate .embeddings (embed_model = embed_model , texts = train_texts , labels = train_labels , split = "train" )
46- log (message = "Generating validation embeddings..." , cfg = config )
60+ log ("Generating validation embeddings..." , cfg = config )
4761 generate .embeddings (embed_model = embed_model , texts = val_texts , labels = val_labels , split = "validation" )
4862
4963 train_dataset = EmbeddingDataset (config .EMBED_CACHE_DIR )
@@ -59,14 +73,17 @@ def train(config: TrainingConfig):
5973 # Plot + save history for each loop
6074 for i , history in enumerate (history_loops ):
6175 plot_training (cfg = config , history_loops = history_loops )
62- with open (f"{ config .CACHE_DIR } /{ config .MODEL_NAME } /round_{ config .MODEL_ROUND } /training_history_loop{ i + 1 } .json" , "w" ) as f :
76+ with open (
77+ f"{ config .CACHE_DIR } /{ config .MODEL_NAME } /round_{ config .MODEL_ROUND } /training_history_loop{ i + 1 } .json" ,
78+ "w" ) as f :
6379 json .dump (history , f )
6480
65- log (message = "Training complete. All data, plots, and model saved." , cfg = config )
81+ log ("Training complete. All data, plots, and model saved." , cfg = config )
6682
6783
6884if __name__ == "__main__" :
6985 # noinspection DuplicatedCode
86+ # ---------------- CONFIG ----------------
7087 cfg = TrainingConfig ()
7188 cfg .update ({
7289 # Model / caching / logging
@@ -85,7 +102,8 @@ def train(config: TrainingConfig):
85102 "AUTO_CONTINUE" : False , # Whether to automatically continue training and ignore EARLY_STOPPING_PATIENCE
86103
87104 # Dataset / data generation
88- "DATASET_SIZE" : 25000 , # Number of samples to generate for training (not the same as for the training rounds themselves)
105+ "DATASET_SIZE" : 25000 ,
106+ # Number of samples to generate for training (not the same as for the training rounds themselves)
89107 "TEXT_MAX_LEN" : 128 , # Maximum length of generated text samples
90108 "TEXT_MAX_LEN_JUMP_RANGE" : 10 , # Range for random variation in text length
91109 "VAL_SPLIT" : 0.85 , # Fraction of dataset used for training + validation (rest for testing)
@@ -102,7 +120,9 @@ def train(config: TrainingConfig):
102120 # Device / system
103121 "RAM_THRESHOLD" : 0.85 # Maximum allowed fraction of RAM usage before halting generation and offloading
104122 })
123+ train_init = init (cfg )
105124
125+ # ----------------- RUN ------------------
106126 available_dataset = [10 , 100 , 1000 , 5000 , 10000 , 17500 , 25000 ]
107127 for dataset in available_dataset :
108128 if dataset <= 1000 :
@@ -117,7 +137,8 @@ def train(config: TrainingConfig):
117137 cfg .update ({
118138 # Model / caching / logging
119139 "MODEL_NAME" : f"Model_{ name } .4n1" , # Name of the model for identification and caching
120- "DATASET_SIZE" : dataset , # Number of samples to generate for training (not the same as for the training rounds themselves)
140+ "DATASET_SIZE" : dataset ,
141+ # Number of samples to generate for training (not the same as for the training rounds themselves)
121142 })
122143 log (message = f"Training { name } with { dataset } dataset..." , cfg = cfg )
123- train (config = cfg )
144+ train (config = cfg , resources = train_init )
0 commit comments