Skip to content

Commit d9baee1

Browse files
Refactor config and trainer scripts to expose get_existing_rounds method and streamline model round initialization
1 parent 8d85cb0 commit d9baee1

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

v4n1_Trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,13 @@ def train(config: TrainingConfig, resources: dict):
133133
name = "Sense"
134134
else:
135135
name = "SenseMacro"
136-
136+
existing_rounds = cfg.get_existing_rounds(cfg.CACHE_DIR) # Auto-increment round based on existing folders
137+
model_round = max(existing_rounds) + 1 if existing_rounds else 1
137138
cfg.update({
138139
# Model / caching / logging
139140
"MODEL_NAME": f"Model_{name}.4n1", # Name of the model for identification and caching
140-
"DATASET_SIZE": dataset,
141-
# Number of samples to generate for training (not the same as for the training rounds themselves)
141+
"DATASET_SIZE": dataset, # Number of samples to generate for training (not the same as for the training rounds themselves)
142+
"MODEL_ROUND": model_round # Current training round (auto-incremented)
142143
})
143-
log(message=f"Training {name} with {dataset} dataset...", cfg=cfg)
144+
log(message=f"Training 'Model_{name}.4n1/round_{model_round}/' with {dataset} dataset...", cfg=cfg)
144145
train(config=cfg, resources=train_init)

vulnscan/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self, model_name: str = "Model_Sense.4n1"):
99
self.MODEL_NAME = model_name
1010
self.CACHE_DIR = os.path.join(os.getcwd(), "cache")
1111

12-
existing_rounds = self.__get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders
12+
existing_rounds = self.get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders
1313
self.MODEL_ROUND = max(existing_rounds) + 1 if existing_rounds else 1
1414

1515
self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log"
@@ -59,7 +59,7 @@ def __init__(self, model_name: str = "Model_Sense.4n1"):
5959
os.makedirs(self.DATA_CACHE_DIR, exist_ok=True)
6060

6161
@staticmethod
62-
def __get_existing_rounds(cache_dir: str) -> list[int]:
62+
def get_existing_rounds(cache_dir: str) -> list[int]:
6363
"""
6464
Returns a list of round numbers based on existing folders in the cache directory.
6565
"""

0 commit comments

Comments
 (0)