Skip to content

Commit c01a814

Browse files
Enhance error handling and logging during initialization and training processes
1 parent 5201571 commit c01a814

File tree

1 file changed

+115
-82
lines changed

1 file changed

+115
-82
lines changed

v4n1_Trainer.py

Lines changed: 115 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import sys
34

45
import torch
56
from sentence_transformers import SentenceTransformer
@@ -12,73 +13,99 @@
1213
# ---------------- INIT ----------------
1314
def init(config: TrainingConfig) -> dict:
    """Load the static models shared by every training round (call once).

    Loads the GPT-Neo 1.3B tokenizer and causal-LM model, and the MiniLM
    sentence-embedding model. ``config`` is only passed through to ``log``;
    nothing else is read from it here.

    Args:
        config: Training configuration, forwarded to ``log`` as ``cfg``.

    Returns:
        A dict with the keys ``"gpt_tokenizer"``, ``"gpt_model"`` and
        ``"embed_model"``.

    Raises:
        SystemExit: If loading is interrupted by the user or fails for any
            other reason. The exit message describes the failure, and for
            unexpected errors the full traceback is printed to stderr first.
    """
    import traceback  # local import: only needed on the failure path

    try:
        log("Loading GPT-Neo tokenizer/model (static init)...", cfg=config, only_console=True)
        gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
        gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
        # Fall back to the EOS token when the tokenizer defines no pad token.
        if gpt_tokenizer.pad_token is None:
            gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

        log("Loading MiniLM for embeddings (static init)...", cfg=config, only_console=True)
        embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    except KeyboardInterrupt:
        sys.exit("Interrupted by user in initialization.")
    except Exception as err:
        # sys.exit() alone would hide where the failure happened; keep the
        # stack trace visible before terminating with the short message.
        traceback.print_exc()
        sys.exit(f"Error during initialization: {err}")

    return {
        "gpt_tokenizer": gpt_tokenizer,
        "gpt_model": gpt_model,
        "embed_model": embed_model,
    }
2935

3036

3137
# ---------------- TRAIN ----------------
3238
def train(config: TrainingConfig, resources: dict):
    """Run one training round: dataset, embeddings, training, history dump.

    Steps, in order:
      1. Take the tokenizer/models from ``resources`` and move the GPT model
         to ``config.DEVICE``.
      2. Load the cached dataset for ``config.DATASET_SIZE`` if it exists,
         otherwise generate it with ``DataGen`` and cache it.
      3. Slice texts/labels into train / validation / test parts.
      4. Generate embeddings for each part.
      5. Train a ``SimpleNN`` via ``Train.model``.
      6. Plot the training history and write one JSON file per history entry.

    Args:
        config: Training configuration; forwarded to the helpers as ``cfg``.
        resources: Dict produced by ``init`` with the keys
            ``"gpt_tokenizer"``, ``"gpt_model"`` and ``"embed_model"``.

    Raises:
        SystemExit: If the run is interrupted by the user or any step fails.
            The exit message names the step, and for unexpected errors the
            full traceback is printed to stderr first.
    """
    import traceback  # local import: only needed on the failure path

    # Name of the step currently running; reported if that step fails.
    part = "???"
    try:
        # Load resources from init
        part = "init resources loading"
        gpt_tokenizer = resources["gpt_tokenizer"]
        gpt_model = resources["gpt_model"].to(config.DEVICE)  # attach to device here
        embed_model = resources["embed_model"]

        # Initialise DataGen
        part = "initialising DataGen"
        log("Initialising DataGen with config...", cfg=config, silent=True)
        generate = DataGen(cfg=config)

        # Generate dataset (or reuse the cached one for this dataset size)
        part = "generating/loading the dataset"
        dataset_path = f"{config.DATASET_CACHE_DIR}/dataset_{config.DATASET_SIZE}.pt"
        if os.path.exists(dataset_path):
            log("Loading existing dataset...", cfg=config)
            data = torch.load(dataset_path)
            texts, labels = data["texts"], data["labels"]
        else:
            log("Dataset not found, generating", cfg=config)
            texts, labels = generate.dataset(gpt_tokenizer=gpt_tokenizer, gpt_model=gpt_model)
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
            torch.save({"texts": texts, "labels": labels}, dataset_path)

        # Split dataset
        part = "splitting the dataset"
        # NOTE(review): both values are used as absolute slice indices, so
        # VAL_SPLIT presumably is the cumulative train+validation fraction
        # (e.g. 0.8 / 0.9) - confirm against TrainingConfig.
        train_split = int(len(texts) * config.TRAIN_VAL_SPLIT)
        val_split = int(len(texts) * config.VAL_SPLIT)

        train_texts, train_labels = texts[:train_split], labels[:train_split]
        val_texts, val_labels = texts[train_split:val_split], labels[train_split:val_split]
        test_texts, test_labels = texts[val_split:], labels[val_split:]

        # Generate embeddings for all splits
        part = "generating the embeddings"
        log("Generating test embeddings...", cfg=config)
        generate.embeddings(embed_model=embed_model, texts=test_texts, labels=test_labels, split="test")
        log("Generating train embeddings...", cfg=config)
        generate.embeddings(embed_model=embed_model, texts=train_texts, labels=train_labels, split="train")
        log("Generating validation embeddings...", cfg=config)
        generate.embeddings(embed_model=embed_model, texts=val_texts, labels=val_labels, split="validation")

        # Prepare datasets and dataloaders
        part = "preparing datasets and dataloaders"
        # NOTE(review): train and validation datasets are built from identical
        # arguments; verify that EmbeddingDataset can tell the splits apart.
        train_dataset = EmbeddingDataset(config.EMBED_CACHE_DIR)
        val_dataset = EmbeddingDataset(config.EMBED_CACHE_DIR)
        val_loader = DataLoader(dataset=val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

        train_ = Train(cfg=config)
        # NOTE(review): 384 is presumably the embedding width of the MiniLM
        # model loaded in init - confirm if the embedding model changes.
        model = SimpleNN(input_dim=384).to(config.DEVICE)

        # Run training (handles TRAIN_LOOPS internally)
        part = "training the model"
        history_loops = train_.model(model=model, train_dataset=train_dataset, val_loader=val_loader)

        # Plot once, then save one JSON history file per loop
        part = "plotting and saving training history"
        history_dir = f"{config.CACHE_DIR}/{config.MODEL_NAME}/round_{config.MODEL_ROUND}"
        if history_loops:
            # open(..., "w") does not create missing parent directories.
            os.makedirs(history_dir, exist_ok=True)
            # The plot takes the complete history list, so a single call is
            # enough; it used to be repeated unchanged for every loop entry.
            plot_training(cfg=config, history_loops=history_loops)
        for loop_no, history in enumerate(history_loops, start=1):
            with open(f"{history_dir}/training_history_loop{loop_no}.json", "w") as f:
                json.dump(history, f)

        log("Training complete. All data, plots, and model saved.", cfg=config)
    except KeyboardInterrupt:
        sys.exit("Interrupted by user during training.")
    except Exception as err:
        # sys.exit() alone would hide where the failure happened; keep the
        # stack trace visible before terminating with the short message.
        traceback.print_exc()
        sys.exit(f"Error during '{part}': {err}")
82109

83110

84111
if __name__ == "__main__":
@@ -118,22 +145,28 @@ def train(config: TrainingConfig, resources: dict):
118145
train_init = init(cfg)
119146

120147
# ----------------- RUN ------------------
121-
available_dataset = [10, 100, 1000, 5000, 10000, 17500, 25000]
122-
for loop_idx, dataset in enumerate(available_dataset, start=1):
123-
if dataset <= 1000:
124-
name = "SenseNano"
125-
elif 1000 < dataset <= 5000:
126-
name = "SenseMini"
127-
elif 5000 < dataset <= 10000:
128-
name = "Sense"
129-
else:
130-
name = "SenseMacro"
131-
model_round = loop_idx
132-
cfg.update({
133-
# Model / caching / logging
134-
"MODEL_NAME": f"Model_{name}.4n1", # Name of the model for identification and caching
135-
"DATASET_SIZE": dataset, # Number of samples to generate for training (not the same as for the training rounds themselves)
136-
"MODEL_ROUND": model_round # Current training round (auto-incremented)
137-
})
138-
log(message=f"Training 'Model_{name}.4n1/round_{model_round}/' with {dataset} dataset...", cfg=cfg)
139-
train(config=cfg, resources=train_init)
148+
try:
149+
available_dataset = [10, 100, 1000, 5000, 10000, 17500, 25000]
150+
for loop_idx, dataset in enumerate(available_dataset, start=1):
151+
if dataset <= 1000:
152+
name = "SenseNano"
153+
elif 1000 < dataset <= 5000:
154+
name = "SenseMini"
155+
elif 5000 < dataset <= 10000:
156+
name = "Sense"
157+
else:
158+
name = "SenseMacro"
159+
model_round = loop_idx
160+
cfg.update({
161+
# Model / caching / logging
162+
"MODEL_NAME": f"Model_{name}.4n1", # Name of the model for identification and caching
163+
"DATASET_SIZE": dataset,
164+
# Number of samples to generate for training (not the same as for the training rounds themselves)
165+
"MODEL_ROUND": model_round # Current training round (auto-incremented)
166+
})
167+
log(message=f"Training 'Model_{name}.4n1/round_{model_round}/' with {dataset} dataset...", cfg=cfg)
168+
train(config=cfg, resources=train_init)
169+
except KeyboardInterrupt:
170+
sys.exit("Interrupted by user in main.")
171+
except Exception as e:
172+
sys.exit(f"Error during training: {e}")

0 commit comments

Comments
 (0)