Skip to content

Commit b857105

Browse files
Refactor config and trainer scripts to rename get_existing_rounds method to __get_existing_rounds and update model round initialization logic
1 parent d9baee1 commit b857105

2 files changed

Lines changed: 4 additions & 5 deletions

File tree

v4n1_Trainer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def train(config: TrainingConfig, resources: dict):
124124

125125
# ----------------- RUN ------------------
126126
available_dataset = [10, 100, 1000, 5000, 10000, 17500, 25000]
127-
for dataset in available_dataset:
127+
for loop_idx, dataset in enumerate(available_dataset, start=1):
128128
if dataset <= 1000:
129129
name = "SenseNano"
130130
elif 1000 < dataset <= 5000:
@@ -133,8 +133,7 @@ def train(config: TrainingConfig, resources: dict):
133133
name = "Sense"
134134
else:
135135
name = "SenseMacro"
136-
existing_rounds = cfg.get_existing_rounds(cfg.CACHE_DIR) # Auto-increment round based on existing folders
137-
model_round = max(existing_rounds) + 1 if existing_rounds else 1
136+
model_round = loop_idx
138137
cfg.update({
139138
# Model / caching / logging
140139
"MODEL_NAME": f"Model_{name}.4n1", # Name of the model for identification and caching

vulnscan/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self, model_name: str = "Model_Sense.4n1"):
99
self.MODEL_NAME = model_name
1010
self.CACHE_DIR = os.path.join(os.getcwd(), "cache")
1111

12-
existing_rounds = self.get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders
12+
existing_rounds = self.__get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders
1313
self.MODEL_ROUND = max(existing_rounds) + 1 if existing_rounds else 1
1414

1515
self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log"
@@ -59,7 +59,7 @@ def __init__(self, model_name: str = "Model_Sense.4n1"):
5959
os.makedirs(self.DATA_CACHE_DIR, exist_ok=True)
6060

6161
@staticmethod
62-
def get_existing_rounds(cache_dir: str) -> list[int]:
62+
def __get_existing_rounds(cache_dir: str) -> list[int]:
6363
"""
6464
Returns a list of round numbers based on existing folders in the cache directory.
6565
"""

0 commit comments

Comments
 (0)