Skip to content

Commit 55de363

Browse files
Refactor TrainingConfig to improve initialization and cache directory handling
1 parent 84cf93f commit 55de363

File tree

4 files changed

+22
-27
lines changed

4 files changed

+22
-27
lines changed

v4n1_Generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
dataset_ranges = [10, 100, 1000, 5000, 10000, 17500, 25000]
4444

4545
for dr in dataset_ranges:
46-
dataset_path = f"{cfg.DATA_CACHE_DIR}/dataset_{dr}.pt"
46+
dataset_path = f"{cfg.DATASET_CACHE_DIR}/dataset_{dr}.pt"
4747

4848
# Skip if already exists
4949
if os.path.exists(dataset_path):
@@ -57,7 +57,7 @@
5757
smaller_existing.sort(reverse=True)
5858

5959
for sr in smaller_existing:
60-
candidate_path = f"{cfg.DATA_CACHE_DIR}/dataset_{sr}.pt"
60+
candidate_path = f"{cfg.DATASET_CACHE_DIR}/dataset_{sr}.pt"
6161
if os.path.exists(candidate_path):
6262
data = torch.load(candidate_path, map_location="cpu")
6363
base_texts, base_labels = data["texts"], data["labels"]

v4n1_Trainer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def train(config: TrainingConfig, resources: dict):
3838
generate = DataGen(cfg=config)
3939

4040
# Generate dataset
41-
dataset_path = f"{config.DATA_CACHE_DIR}/dataset_{config.DATASET_SIZE}.pt"
41+
dataset_path = f"{config.DATASET_CACHE_DIR}/dataset_{config.DATASET_SIZE}.pt"
4242
if os.path.exists(dataset_path):
4343
data = torch.load(dataset_path)
4444
texts, labels = data["texts"], data["labels"]
@@ -86,9 +86,6 @@ def train(config: TrainingConfig, resources: dict):
8686
# ---------------- CONFIG ----------------
8787
cfg = TrainingConfig()
8888
cfg.update({
89-
# Model / caching / logging
90-
"MODEL_NAME": "Model_Sense.4n1", # Name of the model for identification and caching
91-
9289
# Training parameters
9390
"BATCH_SIZE": 32, # Number of samples per training batch
9491
"MAX_EPOCHS": 35, # Maximum number of training epochs
@@ -101,8 +98,6 @@ def train(config: TrainingConfig, resources: dict):
10198
"LR_DECAY": 0.9, # Factor to multiply learning rate after decay
10299
"AUTO_CONTINUE": False, # Whether to automatically continue training and ignore EARLY_STOPPING_PATIENCE
103100

104-
# Dataset / data generation
105-
"DATASET_SIZE": 25000,
106101
# Number of samples to generate for training (not the same as for the training rounds themselves)
107102
"TEXT_MAX_LEN": 128, # Maximum length of generated text samples
108103
"TEXT_MAX_LEN_JUMP_RANGE": 10, # Range for random variation in text length

vulnscan/config.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,25 @@
44

55

66
class TrainingConfig:
7-
def __init__(self, model_name: str = "Model_Sense.4n1"):
7+
def __init__(self):
8+
"""
9+
Configuration class for training settings and hyperparameters.
10+
11+
You must call the update method and set MODEL_NAME before using this configuration, since MODEL_NAME now defaults to None and the log/cache paths derived from it are only built inside update.
12+
"""
13+
814
# Model / caching / logging
9-
self.MODEL_NAME = model_name
15+
self.MODEL_NAME = None
16+
self.writer = None
17+
self.LOG_FILE = None
18+
self.EMBED_CACHE_DIR = None
19+
1020
self.CACHE_DIR = os.path.join(os.getcwd(), "cache")
21+
self.DATASET_CACHE_DIR = f"{self.CACHE_DIR}/dataset"
1122

1223
existing_rounds = self.__get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders
1324
self.MODEL_ROUND = max(existing_rounds) + 1 if existing_rounds else 1
1425

15-
self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log"
16-
self.EMBED_CACHE_DIR = f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/embeddings"
17-
self.DATA_CACHE_DIR = f"{self.CACHE_DIR}/dataset"
18-
19-
# TensorBoard
20-
self.writer = SummaryWriter(log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs")
21-
2226
# Training parameters
2327
self.BATCH_SIZE: int = 16
2428
self.MAX_EPOCHS: int = 35
@@ -53,11 +57,6 @@ def __init__(self, model_name: str = "Model_Sense.4n1"):
5357
self.DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
5458
self.RAM_THRESHOLD: float = 0.85
5559

56-
# Create necessary folders
57-
os.makedirs(self.CACHE_DIR, exist_ok=True)
58-
os.makedirs(self.EMBED_CACHE_DIR, exist_ok=True)
59-
os.makedirs(self.DATA_CACHE_DIR, exist_ok=True)
60-
6160
@staticmethod
6261
def __get_existing_rounds(cache_dir: str) -> list[int]:
6362
"""
@@ -95,8 +94,9 @@ def update(self, updates):
9594
if 'MODEL_NAME' in dict(items) or 'CACHE_DIR' in dict(items) or 'MODEL_ROUND' in dict(items):
9695
self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log"
9796
self.EMBED_CACHE_DIR = f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/embeddings"
98-
self.writer = SummaryWriter(log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs")
97+
self.writer = SummaryWriter(
98+
log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs")
9999
os.makedirs(self.EMBED_CACHE_DIR, exist_ok=True)
100-
if 'CACHE_DIR' in dict(items):
101-
self.DATA_CACHE_DIR = f"{self.CACHE_DIR}/dataset"
102-
os.makedirs(self.DATA_CACHE_DIR, exist_ok=True)
100+
if 'DATASET_CACHE_DIR' in dict(items):
101+
self.DATASET_CACHE_DIR = f"{self.CACHE_DIR}/dataset"
102+
os.makedirs(self.DATASET_CACHE_DIR, exist_ok=True)

vulnscan/genData.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def dataset(self, gpt_tokenizer: PreTrainedTokenizerFast, gpt_model: PreTrainedM
8484
labels.append(int(sensitive))
8585
except KeyboardInterrupt:
8686
sys.exit(f"\nDataset generation interrupted by user early. Premature dataset exit.")
87-
torch.save({"texts": dataset, "labels": labels}, f"{self.cfg.DATA_CACHE_DIR}/dataset_{self.cfg.DATASET_SIZE}.pt")
87+
torch.save({"texts": dataset, "labels": labels}, f"{self.cfg.DATASET_CACHE_DIR}/dataset_{self.cfg.DATASET_SIZE}.pt")
8888
return dataset, labels
8989

9090
# ---------------- EMBEDDINGS ----------------

0 commit comments

Comments
 (0)