|
4 | 4 |
|
5 | 5 |
|
6 | 6 | class TrainingConfig: |
7 | | - def __init__(self, model_name: str = "Model_Sense.4n1"): |
| 7 | + def __init__(self): |
| 8 | + """ |
| 9 | + Configuration class for training settings and hyperparameters. |
| 10 | +
|
| 11 | + You must call the update method and set MODEL_NAME. |
| 12 | + """ |
| 13 | + |
8 | 14 | # Model / caching / logging |
9 | | - self.MODEL_NAME = model_name |
| 15 | + self.MODEL_NAME = None |
| 16 | + self.writer = None |
| 17 | + self.LOG_FILE = None |
| 18 | + self.EMBED_CACHE_DIR = None |
| 19 | + |
10 | 20 | self.CACHE_DIR = os.path.join(os.getcwd(), "cache") |
| 21 | + self.DATASET_CACHE_DIR = f"{self.CACHE_DIR}/dataset" |
11 | 22 |
|
12 | 23 | existing_rounds = self.__get_existing_rounds(self.CACHE_DIR) # Auto-increment round based on existing folders |
13 | 24 | self.MODEL_ROUND = max(existing_rounds) + 1 if existing_rounds else 1 |
14 | 25 |
|
15 | | - self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log" |
16 | | - self.EMBED_CACHE_DIR = f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/embeddings" |
17 | | - self.DATA_CACHE_DIR = f"{self.CACHE_DIR}/dataset" |
18 | | - |
19 | | - # TensorBoard |
20 | | - self.writer = SummaryWriter(log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs") |
21 | | - |
22 | 26 | # Training parameters |
23 | 27 | self.BATCH_SIZE: int = 16 |
24 | 28 | self.MAX_EPOCHS: int = 35 |
@@ -53,11 +57,6 @@ def __init__(self, model_name: str = "Model_Sense.4n1"): |
53 | 57 | self.DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" |
54 | 58 | self.RAM_THRESHOLD: float = 0.85 |
55 | 59 |
|
56 | | - # Create necessary folders |
57 | | - os.makedirs(self.CACHE_DIR, exist_ok=True) |
58 | | - os.makedirs(self.EMBED_CACHE_DIR, exist_ok=True) |
59 | | - os.makedirs(self.DATA_CACHE_DIR, exist_ok=True) |
60 | | - |
61 | 60 | @staticmethod |
62 | 61 | def __get_existing_rounds(cache_dir: str) -> list[int]: |
63 | 62 | """ |
@@ -95,8 +94,9 @@ def update(self, updates): |
95 | 94 | if 'MODEL_NAME' in dict(items) or 'CACHE_DIR' in dict(items) or 'MODEL_ROUND' in dict(items): |
96 | 95 | self.LOG_FILE = f"{self.CACHE_DIR}/{self.MODEL_NAME}/training.log" |
97 | 96 | self.EMBED_CACHE_DIR = f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/embeddings" |
98 | | - self.writer = SummaryWriter(log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs") |
| 97 | + self.writer = SummaryWriter( |
| 98 | + log_dir=f"{self.CACHE_DIR}/{self.MODEL_NAME}/round_{self.MODEL_ROUND}/tensorboard_logs") |
99 | 99 | os.makedirs(self.EMBED_CACHE_DIR, exist_ok=True) |
100 | | - if 'CACHE_DIR' in dict(items): |
101 | | - self.DATA_CACHE_DIR = f"{self.CACHE_DIR}/dataset" |
102 | | - os.makedirs(self.DATA_CACHE_DIR, exist_ok=True) |
| 100 | + if 'DATASET_CACHE_DIR' in dict(items): |
| 101 | + self.DATASET_CACHE_DIR = f"{self.CACHE_DIR}/dataset" |
| 102 | + os.makedirs(self.DATASET_CACHE_DIR, exist_ok=True) |
0 commit comments