
Commit 2c2ec56

Refactor v4n1_Trainer.py to initialize resources separately, streamline training function, and enhance logging
1 parent 9c4d6f0 commit 2c2ec56

2 files changed

Lines changed: 36 additions & 14 deletions

.idea/dictionaries/project.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

v4n1_Trainer.py

Lines changed: 35 additions & 14 deletions
@@ -9,18 +9,32 @@
 from vulnscan import log, Train, plot_training, SimpleNN, EmbeddingDataset, TrainingConfig, DataGen
 
 
-# ---------------- MAIN ----------------
-def train(config: TrainingConfig):
-    log(message="Loading GPT-Neo model for text generation...", cfg=config)
+# ---------------- INIT ----------------
+def init(config: TrainingConfig) -> dict:
+    """Initialize static, config-free resources (only once)."""
+    log("Loading GPT-Neo tokenizer/model (static init)...", cfg=config)
     gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
-    gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(config.DEVICE)
+    gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
     if gpt_tokenizer.pad_token is None:
         gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
 
-    log(message="Loading MiniLM for embeddings...", cfg=config)
+    log("Loading MiniLM for embeddings (static init)...", cfg=config)
     embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
-    log(message="Starting advanced self-training sensitive data classifier...", cfg=config)
+    return {
+        "gpt_tokenizer": gpt_tokenizer,
+        "gpt_model": gpt_model,
+        "embed_model": embed_model,
+    }
+
+
+# ---------------- TRAIN ----------------
+def train(config: TrainingConfig, resources: dict):
+    gpt_tokenizer = resources["gpt_tokenizer"]
+    gpt_model = resources["gpt_model"].to(config.DEVICE)  # attach to device here
+    embed_model = resources["embed_model"]
+
+    log("Init DataGen with config...", cfg=config, silent=True)
     generate = DataGen(cfg=config)
 
     # Generate dataset
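The core of this hunk is that model loading now happens on the CPU in init(), and the model is only moved to config.DEVICE inside train(). A minimal runnable sketch of that load-once, move-per-run pattern, with a tiny nn.Linear standing in for the 1.3B GPT-Neo checkpoint so nothing needs downloading:

    # Load-once / move-per-run split, as introduced by this commit.
    # A small nn.Linear stands in for AutoModelForCausalLM.from_pretrained(...).
    import torch
    from torch import nn

    def init_resources() -> dict:
        """Expensive, config-independent loading; runs exactly once."""
        return {"model": nn.Linear(8, 2)}

    def run_once(resources: dict, device: str) -> None:
        model = resources["model"].to(device)  # nn.Module.to() moves parameters in place
        x = torch.randn(4, 8, device=device)
        print(model(x).shape)

    res = init_resources()  # one-time cost
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    for _ in range(3):      # many training runs, zero reloads
        run_once(res, dev)

Since nn.Module.to() mutates the parameters in place and returns self, every later call sees the model already on the device; reassigning the return value, as the diff does, is still the idiomatic form.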
@@ -39,11 +53,11 @@ def train(config: TrainingConfig):
     val_texts, val_labels = texts[train_split:val_split], labels[train_split:val_split]
     test_texts, test_labels = texts[val_split:], labels[val_split:]
 
-    log(message="Generating test embeddings...", cfg=config)
+    log("Generating test embeddings...", cfg=config)
     generate.embeddings(embed_model=embed_model, texts=test_texts, labels=test_labels, split="test")
-    log(message="Generating train embeddings...", cfg=config)
+    log("Generating train embeddings...", cfg=config)
     generate.embeddings(embed_model=embed_model, texts=train_texts, labels=train_labels, split="train")
-    log(message="Generating validation embeddings...", cfg=config)
+    log("Generating validation embeddings...", cfg=config)
     generate.embeddings(embed_model=embed_model, texts=val_texts, labels=val_labels, split="validation")
 
     train_dataset = EmbeddingDataset(config.EMBED_CACHE_DIR)
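generate.embeddings() is project code from the vulnscan package and its body is not part of this commit. Under the assumption that it encodes each split with MiniLM and caches the arrays to disk, here is a sketch of the equivalent using the public sentence-transformers API; cache_split and the .npz layout are invented for illustration:

    # Hypothetical per-split embedding cache; only SentenceTransformer.encode()
    # is a real API here, the rest is an illustrative guess at DataGen.embeddings.
    from pathlib import Path

    import numpy as np
    from sentence_transformers import SentenceTransformer

    def cache_split(model, texts, labels, split, cache_dir):
        embeddings = model.encode(texts, batch_size=64, show_progress_bar=False)
        out = Path(cache_dir) / f"{split}.npz"
        out.parent.mkdir(parents=True, exist_ok=True)
        np.savez(out, embeddings=embeddings, labels=np.asarray(labels))

    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    cache_split(embed_model, ["text one", "text two"], [1, 0], "test", "embed_cache")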
@@ -59,14 +73,17 @@ def train(config: TrainingConfig):
     # Plot + save history for each loop
     for i, history in enumerate(history_loops):
         plot_training(cfg=config, history_loops=history_loops)
-        with open(f"{config.CACHE_DIR}/{config.MODEL_NAME}/round_{config.MODEL_ROUND}/training_history_loop{i + 1}.json", "w") as f:
+        with open(
+                f"{config.CACHE_DIR}/{config.MODEL_NAME}/round_{config.MODEL_ROUND}/training_history_loop{i + 1}.json",
+                "w") as f:
             json.dump(history, f)
 
-    log(message="Training complete. All data, plots, and model saved.", cfg=config)
+    log("Training complete. All data, plots, and model saved.", cfg=config)
 
 
 if __name__ == "__main__":
     # noinspection DuplicatedCode
+    # ---------------- CONFIG ----------------
     cfg = TrainingConfig()
     cfg.update({
         # Model / caching / logging
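The with open(...) change above is pure line-length formatting, but the write path is assembled from three config values and its directory must exist before the dump succeeds. A self-contained sketch of the same loop; the makedirs call is an assumption about what the surrounding repo handles, not something this commit adds:

    import json
    import os

    cache_dir, model_name, model_round = "cache", "Model_demo.4n1", 1
    history_loops = [{"loss": [0.9, 0.4]}, {"loss": [0.5, 0.2]}]  # illustrative histories

    for i, history in enumerate(history_loops):
        path = f"{cache_dir}/{model_name}/round_{model_round}/training_history_loop{i + 1}.json"
        os.makedirs(os.path.dirname(path), exist_ok=True)  # assumed handled elsewhere
        with open(path, "w") as f:
            json.dump(history, f)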
@@ -85,7 +102,8 @@ def train(config: TrainingConfig):
         "AUTO_CONTINUE": False,  # Whether to automatically continue training and ignore EARLY_STOPPING_PATIENCE
 
         # Dataset / data generation
-        "DATASET_SIZE": 25000,  # Number of samples to generate for training (not the same as for the training rounds themselves)
+        "DATASET_SIZE": 25000,
+        # Number of samples to generate for training (not the same as for the training rounds themselves)
         "TEXT_MAX_LEN": 128,  # Maximum length of generated text samples
         "TEXT_MAX_LEN_JUMP_RANGE": 10,  # Range for random variation in text length
         "VAL_SPLIT": 0.85,  # Fraction of dataset used for training + validation (rest for testing)
@@ -102,7 +120,9 @@ def train(config: TrainingConfig):
         # Device / system
         "RAM_THRESHOLD": 0.85  # Maximum allowed fraction of RAM usage before halting generation and offloading
     })
+    train_init = init(cfg)
 
+    # ----------------- RUN ------------------
     available_dataset = [10, 100, 1000, 5000, 10000, 17500, 25000]
     for dataset in available_dataset:
         if dataset <= 1000:
@@ -117,7 +137,8 @@ def train(config: TrainingConfig):
         cfg.update({
             # Model / caching / logging
             "MODEL_NAME": f"Model_{name}.4n1",  # Name of the model for identification and caching
-            "DATASET_SIZE": dataset,  # Number of samples to generate for training (not the same as for the training rounds themselves)
+            "DATASET_SIZE": dataset,
+            # Number of samples to generate for training (not the same as for the training rounds themselves)
         })
         log(message=f"Training {name} with {dataset} dataset...", cfg=cfg)
-        train(config=cfg)
+        train(config=cfg, resources=train_init)
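The end result: init(cfg) runs once after the base config is built, and every iteration of the dataset-size loop passes the same resources dict into train(). A compressed sketch of that control flow, with all names and values here purely illustrative:

    # Shape of the new entry point: one init(), many train() calls.
    def init(cfg: dict) -> dict:
        return {"model": "loaded-once"}  # stand-in for the real tokenizer/model loading

    def train(cfg: dict, resources: dict) -> None:
        print(f"training Model_{cfg['size']} on {cfg['size']} samples with {resources['model']}")

    cfg: dict = {}
    train_init = init(cfg)                # expensive step, executed once

    for size in [10, 100, 1000]:          # mirrors available_dataset in the diff
        cfg.update({"size": size})
        train(cfg, resources=train_init)  # resources reused, never reloaded

One consequence worth flagging: because the same gpt_model object is shared across iterations, any state it accumulates in one train() call is still there in the next; if each dataset size is meant to start from pristine weights, the model would need re-initializing inside the loop.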

0 commit comments